# Listing metadata captured from the file viewer: 246 lines, 8.7 KiB, Python.
import io
|
|
import os
|
|
import sys
|
|
import base64
|
|
import argparse
|
|
|
|
import torch
|
|
import torchaudio as ta
|
|
import torchaudio.functional as F
|
|
from chatterbox.tts import ChatterboxTTS
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file
|
|
|
|
# =========================
# Model configuration, tuned for natural-sounding speech
# =========================
# Hugging Face repo and file of the T3 checkpoint that gets downloaded and
# loaded into the model's T3 component at start-up.
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
DEVICE = "cpu"

# Default sampling parameters, chosen for a more natural voice that stays
# close to the reference prompt speaker.
TEMPERATURE, TOP_P, REPETITION_PENALTY = 0.65, 0.88, 1.25

# Gain in dB handed to the post-processing chain.
AUDIO_GAIN_DB = 0.8

# Registered voice prompts live here as <name>.wav; the folder is created at import time.
PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)
|
|
|
|
# =========================
# Audio enhancement focused on natural-sounding output
# =========================
def _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1):
    """Apply a static soft-knee compressor to sample magnitudes (sign is kept).

    - |x| <= threshold - knee: the sample passes through unchanged.
    - |x| >  threshold + knee: the output is threshold + (|x| - threshold) / ratio.
    - In between, a quadratic knee of width 2*knee joins the two curves. Both
      the output and its slope are continuous at the knee edges.
    """
    mag = wav.abs()
    sign = torch.sign(wav)

    # Linear region above the knee.
    out = torch.where(
        mag > threshold + knee,
        sign * (threshold + (mag - threshold) / ratio),
        wav,
    )

    if knee > 0:
        in_knee = (mag > threshold - knee) & (mag <= threshold + knee)
        # Standard quadratic soft knee: equals |x| at the lower edge and
        # threshold + knee / ratio at the upper edge. The previous formula
        # jumped by `knee` at the upper edge and even expanded the signal
        # inside the knee.
        knee_out = sign * (
            mag + (1.0 / ratio - 1.0) * (mag - threshold + knee) ** 2 / (4.0 * knee)
        )
        out = torch.where(in_knee, knee_out, out)
    return out


def enhance_audio(wav, sr):
    """Post-process a generated waveform for a smoother, more natural sound.

    Processing chain:
    1. Peak-normalise to 0.95.
    2. Band-limit to 60 Hz - 10 kHz.
    3. Apply a +1.5 dB low shelf at 200 Hz and a -1.2 dB high shelf at 6 kHz.
    4. Soft-knee compress.
    5. Apply mild tanh saturation.
    6. Add very-low-level pink noise.
    7. Apply AUDIO_GAIN_DB.
    8. Peak-normalise to 0.88.

    Args:
        wav: waveform tensor, filtered along its last axis by the torchaudio
            biquads. It is presumably (channels, time) as returned by
            ChatterboxTTS.generate (TODO confirm).
        sr: sample rate in Hz, used by the biquad filters.

    Returns:
        The processed waveform, same shape as ``wav``.
    """
    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

    wav = _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1)

    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    # Very low-level pink noise, presumably as a dither / "air" layer.
    wav = wav + generate_pink_noise(wav.shape, wav.device) * 0.0003

    wav = torch.tanh(wav * 1.1) * 0.92
    # NOTE(review): F.gain is a pure linear scale, so the peak normalisation
    # below cancels it; AUDIO_GAIN_DB currently has no effect on the output.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88
    return wav
|
|
|
|
def generate_pink_noise(shape, device):
    """Return approximately pink (1/f) noise of the given shape.

    Uses Paul Kellet's refined pink-noise filter:
    - six one-pole IIR sections driven by white noise,
    - plus a direct white term and a one-sample-delayed white term,
    - all summed and scaled.

    Every 1-D row along the last axis is filtered independently, starting
    from zero filter state.

    Args:
        shape: output shape; the last axis is treated as time. Any rank is
            accepted.
        device: torch device for the white noise and the result.

    Returns:
        A float tensor of ``shape`` on ``device``, scaled by an extra 0.1.
    """
    white = torch.randn(shape, device=device)
    if white.numel() == 0:
        return torch.zeros_like(white)

    # The recurrence is inherently sequential. Running it on plain Python
    # floats avoids the per-sample tensor indexing / `.item()` / setitem calls
    # that made the previous loop orders of magnitude slower.
    rows = white.reshape(-1, white.shape[-1]) if white.dim() else white.reshape(1, 1)
    pink_rows = []
    for row in rows.tolist():
        b0 = b1 = b2 = b3 = b4 = b5 = b6 = 0.0
        out = []
        for w in row:
            b0 = 0.99886 * b0 + w * 0.0555179
            b1 = 0.99332 * b1 + w * 0.0750759
            b2 = 0.96900 * b2 + w * 0.1538520
            b3 = 0.86650 * b3 + w * 0.3104856
            b4 = 0.55000 * b4 + w * 0.5329522
            b5 = -0.7616 * b5 - w * 0.0168980
            # b6 still holds the previous sample's contribution here.
            out.append((b0 + b1 + b2 + b3 + b4 + b5 + b6 + w * 0.5362) * 0.11)
            b6 = w * 0.115926
        pink_rows.append(out)

    pink = torch.tensor(pink_rows, dtype=white.dtype, device=white.device)
    return pink.reshape(white.shape) * 0.1
|
|
|
|
# Load the model once, at import time.
def _load_model():
    """Build ChatterboxTTS on DEVICE and load the CHECKPOINT weights into its T3 part.

    The safetensors state dict is local to this function. It is therefore
    released after `load_state_dict` copies it into T3, instead of being
    kept alive as a module global.

    Returns:
        The ChatterboxTTS instance with T3 in eval mode and its
        nn.Dropout probabilities set to 0.
    """
    print("Loading model...")

    tts_model = ChatterboxTTS.from_pretrained(device=DEVICE)

    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
    state = load_file(ckpt_path, device=DEVICE)
    tts_model.t3.to(DEVICE).load_state_dict(state)

    # eval() already sets `training = False` on every sub-module recursively,
    # so no separate per-module pass is needed.
    tts_model.t3.eval()

    # Belt and braces: zero nn.Dropout probabilities so T3 stays dropout-free
    # even if something later switches it back to train mode.
    for module in tts_model.t3.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

    print("Model ready with enhanced settings.")
    return tts_model


# NOTE(review): this runs on import, so `list`/`delete`/`rename` and --help
# also pay for the download + load; consider loading lazily for `tts` only.
model = _load_model()
|
|
|
|
# ======= CLI functions =======
def register_prompt_base64(prompt_name, base64_audio):
    """Decode base64 audio and save it as ``<PROMPT_FOLDER>/<prompt_name>.wav``.

    The decoded bytes are written as-is, with no check that they are WAV
    data. An existing prompt with the same name is overwritten.

    If ``base64_audio`` cannot be decoded, an error is printed to stderr and
    the process exits with status 1, like the other CLI commands do for bad
    input. Nothing is written in that case.
    """
    filename = f"{prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        raw = base64.b64decode(base64_audio)
    except ValueError as err:  # binascii.Error subclasses ValueError
        print(f"Invalid base64 audio: {err}", file=sys.stderr)
        sys.exit(1)
    with open(path, "wb") as f:
        f.write(raw)
    print(f"Registered base64 prompt as {filename}")
|
|
|
|
def register_prompt_file(src_path, prompt_name=None):
    """Copy an audio file into PROMPT_FOLDER as ``<prompt_name>.wav``.

    Args:
        src_path: file to register. Its bytes are copied unchanged, with no
            format check or conversion.
        prompt_name: name to register under. Defaults to the source file's
            base name without its extension.

    If ``src_path`` is not an existing file, an error is printed to stderr and
    the process exits with status 1. An existing prompt with the same name is
    overwritten.
    """
    import shutil

    if not os.path.isfile(src_path):
        print(f"Source file '{src_path}' not found.", file=sys.stderr)
        sys.exit(1)
    if prompt_name is None:
        prompt_name = os.path.splitext(os.path.basename(src_path))[0]
    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        # Streams the copy in chunks instead of reading the whole file into memory.
        shutil.copyfile(src_path, save_path)
    except shutil.SameFileError:
        pass  # the source already is the registered prompt; nothing to copy
    print(f"Registered prompt file as {prompt_name}.wav")
|
|
|
|
def list_prompt():
    """Print the number of ``.wav`` prompts in PROMPT_FOLDER, then one per line.

    The ``.wav`` suffix match is case-insensitive. Names are printed in sorted
    order so the listing is stable between runs, since ``os.listdir`` returns
    entries in arbitrary order.
    """
    wav_files = sorted(f for f in os.listdir(PROMPT_FOLDER) if f.lower().endswith(".wav"))
    print(f"Total prompts: {len(wav_files)}")
    for f in wav_files:
        print(f"- {f}")
|
|
|
|
def delete_prompt(prompt_name):
    """Remove ``<PROMPT_FOLDER>/<prompt_name>.wav``.

    If no such prompt exists, an error is printed to stderr and the process
    exits with status 1.
    """
    target = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    if os.path.exists(target):
        os.remove(target)
        print(f"Deleted prompt {prompt_name}.wav")
    else:
        print(f"Prompt '{prompt_name}' not found.", file=sys.stderr)
        sys.exit(1)
|
|
|
|
def rename_prompt(old_name, new_name):
    """Rename prompt ``<old_name>.wav`` to ``<new_name>.wav`` inside PROMPT_FOLDER.

    The process exits with status 1, after printing to stderr, if the old
    prompt is missing or the new name is already taken. Existing prompts are
    never overwritten.
    """
    src, dst = (os.path.join(PROMPT_FOLDER, f"{name}.wav") for name in (old_name, new_name))

    problem = None
    if not os.path.exists(src):
        problem = f"Old prompt '{old_name}' not found."
    elif os.path.exists(dst):
        problem = f"New prompt name '{new_name}' already exists."

    if problem is not None:
        print(problem, file=sys.stderr)
        sys.exit(1)

    os.rename(src, dst)
    print(f"Renamed prompt '{old_name}' to '{new_name}'")
|
|
|
|
def tts(text, prompt, output_path="output.wav", temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):
    """Synthesise ``text`` in the voice of a registered prompt and write a WAV file.

    Args:
        text: text to speak.
        prompt: name of a prompt registered in PROMPT_FOLDER, without ``.wav``.
        output_path: destination WAV path.
        temperature, top_p, repetition_penalty: sampling parameters forwarded
            to ``model.generate``.

    The generated audio is moved to CPU and passed through ``enhance_audio``
    before saving. The process exits with status 1 if the prompt does not
    exist or generation raises.
    """
    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
    if not os.path.exists(prompt_path):
        print(f"Prompt '{prompt}' not found.", file=sys.stderr)
        sys.exit(1)

    sampling = dict(
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    try:
        generated = model.generate(text, audio_prompt_path=prompt_path, **sampling)
    except Exception as exc:
        print(f"Failed to generate audio: {exc}", file=sys.stderr)
        sys.exit(1)

    sample_rate = model.sr
    processed = enhance_audio(generated.cpu(), sample_rate)
    ta.save(output_path, processed, sample_rate, format="wav")
    print(f"TTS output saved to {output_path}")
|
|
|
|
def main():
    """Parse the command line and dispatch to the matching prompt/TTS command.

    Prints the help text and exits with status 1 when no sub-command is given.
    """
    parser = argparse.ArgumentParser(description="Chatterbox TTS CLI")
    commands = parser.add_subparsers(dest="command")

    sub = commands.add_parser("register-base64", help="Register prompt from base64 string")
    sub.add_argument("prompt_name")
    sub.add_argument("base64_audio")

    sub = commands.add_parser("register-file", help="Register prompt from wav file")
    sub.add_argument("src_path")
    sub.add_argument("--name", default=None)

    commands.add_parser("list", help="List all prompt wav files")

    sub = commands.add_parser("delete", help="Delete prompt wav file")
    sub.add_argument("prompt_name")

    sub = commands.add_parser("rename", help="Rename prompt wav file")
    sub.add_argument("old_name")
    sub.add_argument("new_name")

    sub = commands.add_parser("tts", help="Generate TTS wav file")
    sub.add_argument("text")
    sub.add_argument("prompt")
    sub.add_argument("--output", default="output.wav")
    sub.add_argument("--temperature", type=float, default=TEMPERATURE)
    sub.add_argument("--top_p", type=float, default=TOP_P)
    sub.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY)

    args = parser.parse_args()

    # Sub-command name -> handler taking the parsed namespace.
    handlers = {
        "register-base64": lambda a: register_prompt_base64(a.prompt_name, a.base64_audio),
        "register-file": lambda a: register_prompt_file(a.src_path, a.name),
        "list": lambda a: list_prompt(),
        "delete": lambda a: delete_prompt(a.prompt_name),
        "rename": lambda a: rename_prompt(a.old_name, a.new_name),
        "tts": lambda a: tts(a.text, a.prompt, a.output, a.temperature, a.top_p, a.repetition_penalty),
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        sys.exit(1)
    handler(args)


if __name__ == "__main__":
    main()
|