tambahannya
This commit is contained in:
245
xclone.py
Normal file
245
xclone.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import base64
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torchaudio as ta
|
||||
import torchaudio.functional as F
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# =========================
|
||||
# KONFIGURASI MODEL - DISESUAIKAN UNTUK NATURALNESS
|
||||
# =========================
|
||||
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
|
||||
CHECKPOINT = "t3_cfg.safetensors"
|
||||
DEVICE = "cpu"
|
||||
|
||||
# Parameter dioptimasi untuk suara lebih natural dan mirip source
|
||||
TEMPERATURE = 0.65
|
||||
TOP_P = 0.88
|
||||
REPETITION_PENALTY = 1.25
|
||||
AUDIO_GAIN_DB = 0.8
|
||||
|
||||
PROMPT_FOLDER = "prompt_source"
|
||||
os.makedirs(PROMPT_FOLDER, exist_ok=True)
|
||||
|
||||
# =========================
|
||||
# Enhance audio dengan fokus pada naturalness
|
||||
# =========================
|
||||
def enhance_audio(wav, sr):
|
||||
peak = wav.abs().max()
|
||||
if peak > 0:
|
||||
wav = wav / (peak + 1e-8) * 0.95
|
||||
wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
|
||||
wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
|
||||
wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
|
||||
wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)
|
||||
|
||||
threshold = 0.6
|
||||
ratio = 2.5
|
||||
knee = 0.1
|
||||
abs_wav = wav.abs()
|
||||
mask_hard = abs_wav > (threshold + knee)
|
||||
mask_knee = (abs_wav > (threshold - knee)) & (abs_wav <= (threshold + knee))
|
||||
compressed = torch.where(
|
||||
mask_hard,
|
||||
torch.sign(wav) * (threshold + (abs_wav - threshold) / ratio),
|
||||
wav
|
||||
)
|
||||
knee_factor = ((abs_wav - (threshold - knee)) / (2 * knee)) ** 2
|
||||
knee_compressed = torch.sign(wav) * (
|
||||
threshold - knee + knee_factor * (2 * knee) + (abs_wav - threshold) / ratio * knee_factor
|
||||
)
|
||||
compressed = torch.where(mask_knee, knee_compressed, compressed)
|
||||
wav = compressed
|
||||
|
||||
saturation_amount = 0.08
|
||||
wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)
|
||||
|
||||
pink_noise = generate_pink_noise(wav.shape, wav.device) * 0.0003
|
||||
wav = wav + pink_noise
|
||||
|
||||
wav = torch.tanh(wav * 1.1) * 0.92
|
||||
wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)
|
||||
|
||||
peak = wav.abs().max().item()
|
||||
if peak > 0:
|
||||
wav = wav / peak * 0.88
|
||||
return wav
|
||||
|
||||
def generate_pink_noise(shape, device):
|
||||
white = torch.randn(shape, device=device)
|
||||
if len(shape) > 1:
|
||||
pink = torch.zeros_like(white)
|
||||
for i in range(shape[0]):
|
||||
b = torch.zeros(7)
|
||||
for j in range(shape[1]):
|
||||
white_val = white[i, j].item()
|
||||
b[0] = 0.99886 * b[0] + white_val * 0.0555179
|
||||
b[1] = 0.99332 * b[1] + white_val * 0.0750759
|
||||
b[2] = 0.96900 * b[2] + white_val * 0.1538520
|
||||
b[3] = 0.86650 * b[3] + white_val * 0.3104856
|
||||
b[4] = 0.55000 * b[4] + white_val * 0.5329522
|
||||
b[5] = -0.7616 * b[5] - white_val * 0.0168980
|
||||
pink[i, j] = (b[0] + b[1] + b[2] + b[3] + b[4] + b[5] + b[6] + white_val * 0.5362) * 0.11
|
||||
b[6] = white_val * 0.115926
|
||||
else:
|
||||
pink = torch.zeros_like(white)
|
||||
b = torch.zeros(7)
|
||||
for j in range(shape[0]):
|
||||
white_val = white[j].item()
|
||||
b[0] = 0.99886 * b[0] + white_val * 0.0555179
|
||||
b[1] = 0.99332 * b[1] + white_val * 0.0750759
|
||||
b[2] = 0.96900 * b[2] + white_val * 0.1538520
|
||||
b[3] = 0.86650 * b[3] + white_val * 0.3104856
|
||||
b[4] = 0.55000 * b[4] + white_val * 0.5329522
|
||||
b[5] = -0.7616 * b[5] - white_val * 0.0168980
|
||||
pink[j] = (b[0] + b[1] + b[2] + b[3] + b[4] + b[5] + b[6] + white_val * 0.5362) * 0.11
|
||||
b[6] = white_val * 0.115926
|
||||
return pink * 0.1
|
||||
|
||||
# Load model sekali
|
||||
print("Loading model...")
|
||||
|
||||
model = ChatterboxTTS.from_pretrained(device=DEVICE)
|
||||
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
|
||||
state = load_file(ckpt, device=DEVICE)
|
||||
|
||||
model.t3.to(DEVICE).load_state_dict(state)
|
||||
model.t3.eval()
|
||||
|
||||
for module in model.t3.modules():
|
||||
if hasattr(module, "training"):
|
||||
module.training = False
|
||||
|
||||
for module in model.t3.modules():
|
||||
if isinstance(module, torch.nn.Dropout):
|
||||
module.p = 0
|
||||
|
||||
print("Model ready with enhanced settings.")
|
||||
|
||||
# ======= Fungsi CLI =======
|
||||
def register_prompt_base64(prompt_name, base64_audio):
|
||||
filename = f"{prompt_name}.wav"
|
||||
path = os.path.join(PROMPT_FOLDER, filename)
|
||||
raw = base64.b64decode(base64_audio)
|
||||
with open(path, "wb") as f:
|
||||
f.write(raw)
|
||||
print(f"Registered base64 prompt as {filename}")
|
||||
|
||||
def register_prompt_file(src_path, prompt_name=None):
|
||||
if prompt_name is None:
|
||||
prompt_name = os.path.splitext(os.path.basename(src_path))[0]
|
||||
save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
|
||||
with open(src_path, "rb") as src_file:
|
||||
data = src_file.read()
|
||||
with open(save_path, "wb") as dst_file:
|
||||
dst_file.write(data)
|
||||
print(f"Registered prompt file as {prompt_name}.wav")
|
||||
|
||||
def list_prompt():
|
||||
files = os.listdir(PROMPT_FOLDER)
|
||||
wav_files = [f for f in files if f.lower().endswith(".wav")]
|
||||
print(f"Total prompts: {len(wav_files)}")
|
||||
for f in wav_files:
|
||||
print(f"- {f}")
|
||||
|
||||
def delete_prompt(prompt_name):
|
||||
file_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Prompt '{prompt_name}' not found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
os.remove(file_path)
|
||||
print(f"Deleted prompt {prompt_name}.wav")
|
||||
|
||||
def rename_prompt(old_name, new_name):
|
||||
old_path = os.path.join(PROMPT_FOLDER, f"{old_name}.wav")
|
||||
new_path = os.path.join(PROMPT_FOLDER, f"{new_name}.wav")
|
||||
if not os.path.exists(old_path):
|
||||
print(f"Old prompt '{old_name}' not found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if os.path.exists(new_path):
|
||||
print(f"New prompt name '{new_name}' already exists.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
os.rename(old_path, new_path)
|
||||
print(f"Renamed prompt '{old_name}' to '{new_name}'")
|
||||
|
||||
def tts(text, prompt, output_path="output.wav", temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):
|
||||
prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
|
||||
if not os.path.exists(prompt_path):
|
||||
print(f"Prompt '{prompt}' not found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
try:
|
||||
wav = model.generate(
|
||||
text,
|
||||
audio_prompt_path=prompt_path,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to generate audio: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
wav = enhance_audio(wav.cpu(), model.sr)
|
||||
ta.save(output_path, wav, model.sr, format="wav")
|
||||
print(f"TTS output saved to {output_path}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Chatterbox TTS CLI")
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
|
||||
# register base64
|
||||
p_register_base64 = subparsers.add_parser("register-base64", help="Register prompt from base64 string")
|
||||
p_register_base64.add_argument("prompt_name")
|
||||
p_register_base64.add_argument("base64_audio")
|
||||
|
||||
# register file
|
||||
p_register_file = subparsers.add_parser("register-file", help="Register prompt from wav file")
|
||||
p_register_file.add_argument("src_path")
|
||||
p_register_file.add_argument("--name", default=None)
|
||||
|
||||
# list prompt
|
||||
p_list = subparsers.add_parser("list", help="List all prompt wav files")
|
||||
|
||||
# delete prompt
|
||||
p_delete = subparsers.add_parser("delete", help="Delete prompt wav file")
|
||||
p_delete.add_argument("prompt_name")
|
||||
|
||||
# rename prompt
|
||||
p_rename = subparsers.add_parser("rename", help="Rename prompt wav file")
|
||||
p_rename.add_argument("old_name")
|
||||
p_rename.add_argument("new_name")
|
||||
|
||||
# tts generate
|
||||
p_tts = subparsers.add_parser("tts", help="Generate TTS wav file")
|
||||
p_tts.add_argument("text")
|
||||
p_tts.add_argument("prompt")
|
||||
p_tts.add_argument("--output", default="output.wav")
|
||||
p_tts.add_argument("--temperature", type=float, default=TEMPERATURE)
|
||||
p_tts.add_argument("--top_p", type=float, default=TOP_P)
|
||||
p_tts.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "register-base64":
|
||||
register_prompt_base64(args.prompt_name, args.base64_audio)
|
||||
elif args.command == "register-file":
|
||||
register_prompt_file(args.src_path, args.name)
|
||||
elif args.command == "list":
|
||||
list_prompt()
|
||||
elif args.command == "delete":
|
||||
delete_prompt(args.prompt_name)
|
||||
elif args.command == "rename":
|
||||
rename_prompt(args.old_name, args.new_name)
|
||||
elif args.command == "tts":
|
||||
tts(args.text, args.prompt, args.output, args.temperature, args.top_p, args.repetition_penalty)
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user