"""Command-line interface for Chatterbox TTS with reusable voice prompts.

Sub-commands manage a folder of ``.wav`` voice prompts (register, list,
delete, rename) and synthesise speech with one of them (``tts``).  The TTS
model is loaded lazily, so the prompt-management commands do not pay the
model download/load cost.
"""
import argparse
import base64
import binascii
import io
import os
import shutil
import sys

import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# =========================
# MODEL CONFIGURATION - TUNED FOR NATURALNESS
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
DEVICE = "cpu"

# Sampling parameters tuned for a more natural voice that stays close to the
# source prompt.
TEMPERATURE = 0.65
TOP_P = 0.88
REPETITION_PENALTY = 1.25
AUDIO_GAIN_DB = 0.8

PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)

# (pole, input gain) of each one-pole stage of the pink-noise filter bank:
#     state[n] = pole * state[n - 1] + gain * white[n]
_PINK_FILTER_STAGES = (
    (0.99886, 0.0555179),
    (0.99332, 0.0750759),
    (0.96900, 0.1538520),
    (0.86650, 0.3104856),
    (0.55000, 0.5329522),
    (-0.7616, -0.0168980),
)

# Lazily populated by _get_model(); None until the model is first needed.
_model = None


# =========================
# Audio enhancement with a focus on naturalness
# =========================
def _soft_knee_compress(wav, threshold, ratio, knee):
    """Apply a static soft-knee compressor to sample amplitudes.

    Below ``threshold - knee`` samples pass unchanged; above
    ``threshold + knee`` the excess over ``threshold`` is divided by
    ``ratio``; in between a quadratic segment joins the two so that the
    curve is continuous in both value and slope at the knee edges.
    """
    magnitude = wav.abs()
    lower = threshold - knee
    upper = threshold + knee

    compressed = threshold + (magnitude - threshold) / ratio
    shaped = torch.where(magnitude > upper, compressed, magnitude)
    if knee > 0:
        # Quadratic knee: equals `magnitude` at `lower` and
        # `threshold + knee / ratio` at `upper`.
        in_knee = magnitude + (1.0 / ratio - 1.0) * (magnitude - lower) ** 2 / (4 * knee)
        knee_mask = (magnitude > lower) & (magnitude <= upper)
        shaped = torch.where(knee_mask, in_knee, shaped)
    return torch.sign(wav) * shaped


def enhance_audio(wav, sr):
    """Post-process a generated waveform and return the processed tensor.

    Steps, in order: peak normalisation, high-pass / low-pass filtering,
    bass boost and treble cut, soft-knee compression, mild tanh saturation,
    a very low level of pink noise, a tanh soft limiter, gain, and a final
    peak normalisation to 0.88.

    Args:
        wav: Waveform tensor; the last dimension is treated as time.
        sr: Sample rate in Hz, passed to the biquad filters.

    Returns:
        The processed waveform, same shape as ``wav``.
    """
    if wav.numel() == 0:
        # Nothing to process; max() of an empty tensor would raise.
        return wav

    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    # Band-limit, then shape the tone.
    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

    wav = _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1)

    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    pink_noise = generate_pink_noise(wav.shape, wav.device) * 0.0003
    wav = wav + pink_noise

    wav = torch.tanh(wav * 1.1) * 0.92
    # NOTE: the peak normalisation below rescales the signal, so a uniform
    # gain applied here does not change the final level.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88
    return wav


def generate_pink_noise(shape, device):
    """Return pink-ish noise of the given shape, filtered along the last axis.

    White Gaussian noise is passed through a bank of one-pole filters
    (``_PINK_FILTER_STAGES``) whose outputs are summed with a scaled copy of
    the current and of the previous white sample.  The filters run through
    ``torchaudio.functional.lfilter`` instead of a per-sample Python loop.

    Args:
        shape: Output shape; the last dimension is time.
        device: Torch device on which to create the noise.

    Returns:
        A tensor of ``shape`` on ``device``.
    """
    white = torch.randn(shape, device=device)
    if white.numel() == 0:
        return white

    total = white * 0.5362
    for pole, gain in _PINK_FILTER_STAGES:
        a_coeffs = torch.tensor([1.0, -pole], dtype=white.dtype, device=white.device)
        b_coeffs = torch.tensor([gain, 0.0], dtype=white.dtype, device=white.device)
        # clamp=False: the intermediate filter states legitimately exceed
        # [-1, 1]; clamping them would distort the spectrum.
        total = total + F.lfilter(white, a_coeffs, b_coeffs, clamp=False)

    # Last term of the bank: the previous white sample, scaled.
    delayed = torch.zeros_like(white)
    delayed[..., 1:] = white[..., :-1] * 0.115926
    total = total + delayed

    pink = total * 0.11
    return pink * 0.1


def _get_model():
    """Load the TTS model on first use and return the cached instance."""
    global _model
    if _model is None:
        print("Loading model...")
        tts_model = ChatterboxTTS.from_pretrained(device=DEVICE)
        ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
        state = load_file(ckpt, device=DEVICE)
        tts_model.t3.to(DEVICE).load_state_dict(state)
        # eval() already switches every sub-module out of training mode.
        tts_model.t3.eval()
        for module in tts_model.t3.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0
        print("Model ready with enhanced settings.")
        _model = tts_model
    return _model


def __getattr__(name):
    """Provide ``model`` as a lazily loaded module attribute (PEP 562)."""
    if name == "model":
        return _get_model()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# ======= CLI functions =======
def _prompt_path(prompt_name):
    """Return the ``.wav`` path of a prompt inside ``PROMPT_FOLDER``.

    Exits with status 1 if the name is empty or contains path components,
    so that a name such as ``../x`` cannot reach files outside the folder.
    """
    name = str(prompt_name)
    if not name or os.path.basename(name) != name:
        print(f"Invalid prompt name '{prompt_name}'.", file=sys.stderr)
        sys.exit(1)
    return os.path.join(PROMPT_FOLDER, f"{name}.wav")


def register_prompt_base64(prompt_name, base64_audio):
    """Decode base64 audio data and store it as ``<prompt_name>.wav``.

    Exits with status 1 if the data is not valid base64.
    """
    path = _prompt_path(prompt_name)
    try:
        raw = base64.b64decode(base64_audio)
    except (binascii.Error, ValueError) as exc:
        print(f"Invalid base64 audio data: {exc}", file=sys.stderr)
        sys.exit(1)
    with open(path, "wb") as f:
        f.write(raw)
    print(f"Registered base64 prompt as {os.path.basename(path)}")


def register_prompt_file(src_path, prompt_name=None):
    """Copy an existing audio file into the prompt folder.

    Args:
        src_path: Path of the file to copy.
        prompt_name: Name to register it under; defaults to the source
            file's base name without extension.

    Exits with status 1 if the file cannot be copied.
    """
    if prompt_name is None:
        prompt_name = os.path.splitext(os.path.basename(src_path))[0]
    save_path = _prompt_path(prompt_name)
    try:
        shutil.copyfile(src_path, save_path)
    except shutil.SameFileError:
        # The source already is the registered prompt file; nothing to copy.
        pass
    except OSError as exc:
        print(f"Failed to register prompt file '{src_path}': {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Registered prompt file as {prompt_name}.wav")


def list_prompt():
    """Print the number and the (sorted) names of all registered prompts."""
    wav_files = sorted(
        f for f in os.listdir(PROMPT_FOLDER) if f.lower().endswith(".wav")
    )
    print(f"Total prompts: {len(wav_files)}")
    for f in wav_files:
        print(f"- {f}")


def delete_prompt(prompt_name):
    """Delete a registered prompt; exit with status 1 if it does not exist."""
    file_path = _prompt_path(prompt_name)
    if not os.path.exists(file_path):
        print(f"Prompt '{prompt_name}' not found.", file=sys.stderr)
        sys.exit(1)
    os.remove(file_path)
    print(f"Deleted prompt {prompt_name}.wav")


def rename_prompt(old_name, new_name):
    """Rename a registered prompt.

    Exits with status 1 if the old prompt is missing or the new name is
    already taken.
    """
    old_path = _prompt_path(old_name)
    new_path = _prompt_path(new_name)
    if not os.path.exists(old_path):
        print(f"Old prompt '{old_name}' not found.", file=sys.stderr)
        sys.exit(1)
    if os.path.exists(new_path):
        print(f"New prompt name '{new_name}' already exists.", file=sys.stderr)
        sys.exit(1)
    os.rename(old_path, new_path)
    print(f"Renamed prompt '{old_name}' to '{new_name}'")


def tts(text, prompt, output_path="output.wav", temperature=TEMPERATURE,
        top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):
    """Synthesise ``text`` with a registered voice prompt and save it as WAV.

    Args:
        text: Text to speak.
        prompt: Name of a registered prompt (without ``.wav``).
        output_path: Where to write the resulting audio.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling threshold passed to the model.
        repetition_penalty: Repetition penalty passed to the model.

    Exits with status 1 if the prompt is missing or generation fails.
    """
    prompt_path = _prompt_path(prompt)
    if not os.path.exists(prompt_path):
        print(f"Prompt '{prompt}' not found.", file=sys.stderr)
        sys.exit(1)

    # Load the model only after the cheap checks have passed.
    tts_model = _get_model()
    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    except Exception as e:
        print(f"Failed to generate audio: {e}", file=sys.stderr)
        sys.exit(1)

    wav = enhance_audio(wav.cpu(), tts_model.sr)
    ta.save(output_path, wav, tts_model.sr, format="wav")
    print(f"TTS output saved to {output_path}")


def main():
    """Parse the command line and run the selected sub-command."""
    parser = argparse.ArgumentParser(description="Chatterbox TTS CLI")
    subparsers = parser.add_subparsers(dest="command")

    # register base64
    p_register_base64 = subparsers.add_parser(
        "register-base64", help="Register prompt from base64 string"
    )
    p_register_base64.add_argument("prompt_name")
    p_register_base64.add_argument("base64_audio")

    # register file
    p_register_file = subparsers.add_parser(
        "register-file", help="Register prompt from wav file"
    )
    p_register_file.add_argument("src_path")
    p_register_file.add_argument("--name", default=None)

    # list prompt
    subparsers.add_parser("list", help="List all prompt wav files")

    # delete prompt
    p_delete = subparsers.add_parser("delete", help="Delete prompt wav file")
    p_delete.add_argument("prompt_name")

    # rename prompt
    p_rename = subparsers.add_parser("rename", help="Rename prompt wav file")
    p_rename.add_argument("old_name")
    p_rename.add_argument("new_name")

    # tts generate
    p_tts = subparsers.add_parser("tts", help="Generate TTS wav file")
    p_tts.add_argument("text")
    p_tts.add_argument("prompt")
    p_tts.add_argument("--output", default="output.wav")
    p_tts.add_argument("--temperature", type=float, default=TEMPERATURE)
    p_tts.add_argument("--top_p", type=float, default=TOP_P)
    p_tts.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY)

    args = parser.parse_args()

    handlers = {
        "register-base64": lambda: register_prompt_base64(
            args.prompt_name, args.base64_audio
        ),
        "register-file": lambda: register_prompt_file(args.src_path, args.name),
        "list": list_prompt,
        "delete": lambda: delete_prompt(args.prompt_name),
        "rename": lambda: rename_prompt(args.old_name, args.new_name),
        "tts": lambda: tts(
            args.text,
            args.prompt,
            args.output,
            args.temperature,
            args.top_p,
            args.repetition_penalty,
        ),
    }

    handler = handlers.get(args.command)
    if handler is None:
        # No sub-command given: show usage and signal failure.
        parser.print_help()
        sys.exit(1)
    handler()


if __name__ == "__main__":
    main()