Files
jenna-tools/xclone.py
bipproduction 822b68c10f tambahannya
2025-12-07 09:00:54 +08:00

246 lines
8.7 KiB
Python

import io
import os
import sys
import base64
import argparse
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# =========================
# MODEL CONFIGURATION - tuned for natural-sounding output
# =========================
# Hugging Face repo and file holding the fine-tuned T3 weights that replace
# the stock ones after ChatterboxTTS is built.
MODEL_REPO: str = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT: str = "t3_cfg.safetensors"
# Device used both for the model and for loading the checkpoint tensors.
DEVICE: str = "cpu"

# Sampling defaults chosen for a more natural voice that stays close to the
# prompt speaker; each can be overridden by the matching `tts` CLI flag.
TEMPERATURE: float = 0.65
TOP_P: float = 0.88
REPETITION_PENALTY: float = 1.25

# Gain in dB applied by enhance_audio just before its final peak normalisation.
AUDIO_GAIN_DB: float = 0.8

# Directory, relative to the current working directory, that stores every
# registered voice prompt as "<name>.wav". Created at import time if missing.
PROMPT_FOLDER: str = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)
# =========================
# Audio enhancement focused on natural-sounding output
# =========================
def enhance_audio(wav, sr):
    """Post-process a generated waveform for a warmer, more natural sound.

    Processing chain:
      1. peak-normalise to 0.95;
      2. 60 Hz high-pass, then a low-pass at 10 kHz (kept below Nyquist);
      3. +1.5 dB bass shelf at 200 Hz, -1.2 dB treble shelf at 6 kHz;
      4. static soft-knee compressor (threshold 0.6, ratio 2.5, knee +/-0.1);
      5. mild tanh saturation;
      6. a very low-level pink-noise bed;
      7. a second tanh stage scaled to 0.92;
      8. AUDIO_GAIN_DB gain, then peak-normalise to 0.88.

    Args:
        wav: audio tensor with time on the last axis (presumably
            ``(channels, time)`` as returned by ``model.generate`` - TODO confirm).
        sr: sample rate in Hz.

    Returns:
        The processed tensor, same shape as ``wav``. An empty tensor is
        returned unchanged.
    """
    if wav.numel() == 0:
        # .max() below raises on an empty tensor.
        return wav

    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    # A biquad low-pass at or above Nyquist (sr / 2) is marginally stable or
    # aliased. Keep the cutoff strictly below it. This is a no-op whenever
    # 10 kHz < 0.49 * sr (22.05 kHz and above).
    lowpass_cutoff = min(10000.0, 0.49 * sr)
    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=lowpass_cutoff)
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

    # Static soft-knee compressor applied to the sample magnitude (no
    # attack/release). Behaviour by region:
    #   |x| <= T - k       : unchanged.
    #   |x| >  T + k       : T + (|x| - T) / R.
    #   |x| in (T-k, T+k]  : quadratic knee
    #                        |x| + (1/R - 1) * (|x| - T + k)^2 / (4k).
    # The knee matches both neighbours in value and slope, so the curve has no
    # jump. This is the standard quadratic knee (Giannoulis et al., JAES 2012),
    # applied here to linear magnitude. The previous knee formula overshot to
    # T + k + k/R at the top of the knee, then dropped to T + k/R: a 0.1
    # discontinuity.
    threshold = 0.6
    ratio = 2.5
    knee = 0.1
    magnitude = wav.abs()
    above_knee = magnitude > (threshold + knee)
    in_knee = (magnitude > (threshold - knee)) & ~above_knee
    hard = threshold + (magnitude - threshold) / ratio
    soft = magnitude + (1.0 / ratio - 1.0) * (magnitude - threshold + knee) ** 2 / (4.0 * knee)
    shaped = torch.where(above_knee, hard, torch.where(in_knee, soft, magnitude))
    wav = torch.sign(wav) * shaped

    # Mild tanh saturation, rescaled so small signals keep roughly unit gain.
    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    # Very low-level pink noise (additionally scaled by 0.0003 here).
    wav = wav + generate_pink_noise(wav.shape, wav.device) * 0.0003

    # Second tanh stage, scaled to 0.92.
    wav = torch.tanh(wav * 1.1) * 0.92

    # NOTE(review): both this gain and the peak normalisation below are
    # linear, so the normalisation cancels it. AUDIO_GAIN_DB therefore has
    # no effect on the output level.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88
    return wav
def generate_pink_noise(shape, device):
    """Return approximately pink (1/f) noise of the given shape.

    White Gaussian noise is shaped along the LAST axis with Paul Kellet's
    refined pink-noise filter. The filter has three parts:
      - six parallel one-pole sections ``y[n] = pole * y[n-1] + gain * x[n]``;
      - a direct term ``0.5362 * x[n]``;
      - a one-sample-delayed term ``0.115926 * x[n-1]``.
    Each row (all leading axes) is filtered independently from zero initial
    state, exactly like the former per-sample loop. The recursion now runs
    inside ``torchaudio.functional.lfilter`` instead of a Python loop that
    called ``.item()`` on every sample.

    Args:
        shape: output shape with time on the last axis; any rank >= 1 is
            accepted (previously only 1-D and 2-D worked).
        device: torch device for the returned tensor.

    Returns:
        A float tensor of ``shape`` on ``device``, scaled by ``0.11 * 0.1``.
    """
    white = torch.randn(shape, device=device)
    if white.numel() == 0:
        return torch.zeros_like(white)

    # (pole, input gain) of each one-pole section.
    sections = (
        (0.99886, 0.0555179),
        (0.99332, 0.0750759),
        (0.96900, 0.1538520),
        (0.86650, 0.3104856),
        (0.55000, 0.5329522),
        (-0.7616, -0.0168980),
    )

    pink = white * 0.5362
    # Kellet's b6 state: the previous white sample times 0.115926 (zero at n=0).
    pink[..., 1:] += white[..., :-1] * 0.115926
    for pole, gain in sections:
        a_coeffs = torch.tensor([1.0, -pole], dtype=white.dtype, device=white.device)
        b_coeffs = torch.tensor([gain, 0.0], dtype=white.dtype, device=white.device)
        # clamp=False: section outputs routinely leave [-1, 1], and lfilter's
        # default clamping would distort the noise.
        pink = pink + F.lfilter(white, a_coeffs, b_coeffs, clamp=False)

    return pink * 0.11 * 0.1
# Load the model once, at import time.
# NOTE(review): this runs before main() parses arguments. Every CLI command,
# including list/delete/rename and --help, therefore pays the model
# download/load cost.
def _load_model():
    """Build ChatterboxTTS on DEVICE and swap in the fine-tuned T3 weights.

    Fetches CHECKPOINT from the MODEL_REPO Hugging Face repo and loads the
    safetensors state dict into ``t3``. It then switches ``t3`` to inference
    behaviour: ``eval()``, ``training = False`` on every submodule, and
    ``p = 0`` on every ``torch.nn.Dropout``.
    """
    print("Loading model...")
    tts_model = ChatterboxTTS.from_pretrained(device=DEVICE)
    weights_file = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
    weights = load_file(weights_file, device=DEVICE)

    t3 = tts_model.t3
    t3.to(DEVICE).load_state_dict(weights)
    t3.eval()
    # eval() already clears `training`. The explicit flags and zeroed dropout
    # probabilities are presumably a belt-and-braces guard for code that
    # reads these attributes directly.
    for submodule in t3.modules():
        if hasattr(submodule, "training"):
            submodule.training = False
        if isinstance(submodule, torch.nn.Dropout):
            submodule.p = 0

    print("Model ready with enhanced settings.")
    return tts_model


model = _load_model()
# ======= CLI functions =======
def register_prompt_base64(prompt_name, base64_audio):
    """Decode base64-encoded audio and save it as ``PROMPT_FOLDER/<prompt_name>.wav``.

    Whitespace inside the payload (e.g. line-wrapped ``base64`` output) is
    ignored, and an optional ``data:...;base64,`` URI prefix is stripped.
    The decoded bytes are written as-is; they are not checked to be WAV data.

    Args:
        prompt_name: plain file stem; must not contain path separators.
        base64_audio: standard-alphabet base64 text.

    Exits with status 1 (message on stderr) in three cases:
      - the name would place the file outside PROMPT_FOLDER;
      - the payload is not valid base64;
      - the payload decodes to zero bytes.
    """
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        print(f"Invalid prompt name '{prompt_name}'.", file=sys.stderr)
        sys.exit(1)

    payload = base64_audio.strip()
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    payload = "".join(payload.split())
    try:
        # validate=True rejects stray characters instead of silently
        # discarding them, which previously saved corrupted audio.
        raw = base64.b64decode(payload, validate=True)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        print(f"Invalid base64 audio: {err}", file=sys.stderr)
        sys.exit(1)
    if not raw:
        print("Decoded audio is empty.", file=sys.stderr)
        sys.exit(1)

    filename = f"{prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    with open(path, "wb") as f:
        f.write(raw)
    print(f"Registered base64 prompt as {filename}")
def register_prompt_file(src_path, prompt_name=None):
    """Copy an existing audio file into PROMPT_FOLDER as ``<prompt_name>.wav``.

    Args:
        src_path: file to register. Its bytes are copied verbatim; the content
            is not converted or checked to be WAV.
        prompt_name: name to register under. Defaults to the source file's
            base name without extension. It must not contain path separators.

    Exits with status 1 (message on stderr) if ``src_path`` is not an existing
    file or the prompt name would place the copy outside PROMPT_FOLDER.
    """
    import shutil

    if not os.path.isfile(src_path):
        print(f"Source file '{src_path}' not found.", file=sys.stderr)
        sys.exit(1)
    if prompt_name is None:
        prompt_name = os.path.splitext(os.path.basename(src_path))[0]
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        print(f"Invalid prompt name '{prompt_name}'.", file=sys.stderr)
        sys.exit(1)

    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        shutil.copyfile(src_path, save_path)
    except shutil.SameFileError:
        # The source already is the registered prompt; nothing to copy.
        pass
    print(f"Registered prompt file as {prompt_name}.wav")
def list_prompt():
    """Print the number of registered prompts, then one ``- <file>`` line each.

    Only files whose name ends in ``.wav`` (case-insensitive) are counted.
    Names are printed in sorted order so the output is stable across runs and
    filesystems; ``os.listdir`` returns entries in arbitrary order.
    """
    wav_files = sorted(
        name for name in os.listdir(PROMPT_FOLDER) if name.lower().endswith(".wav")
    )
    print(f"Total prompts: {len(wav_files)}")
    for name in wav_files:
        print(f"- {name}")
def delete_prompt(prompt_name):
    """Remove the registered prompt ``PROMPT_FOLDER/<prompt_name>.wav``.

    On success prints a confirmation. If no such path exists, prints an error
    to stderr and exits with status 1.
    """
    target = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    if os.path.exists(target):
        os.remove(target)
        print(f"Deleted prompt {prompt_name}.wav")
        return
    print(f"Prompt '{prompt_name}' not found.", file=sys.stderr)
    sys.exit(1)
def rename_prompt(old_name, new_name):
    """Rename prompt ``old_name`` to ``new_name`` inside PROMPT_FOLDER.

    Prints an error to stderr and exits with status 1 if the source prompt is
    missing, or if a prompt named ``new_name`` already exists. The target is
    checked before renaming so an existing prompt is not replaced.
    """
    def _wav_path(name):
        return os.path.join(PROMPT_FOLDER, f"{name}.wav")

    source, target = _wav_path(old_name), _wav_path(new_name)

    problem = None
    if not os.path.exists(source):
        problem = f"Old prompt '{old_name}' not found."
    elif os.path.exists(target):
        problem = f"New prompt name '{new_name}' already exists."
    if problem is not None:
        print(problem, file=sys.stderr)
        sys.exit(1)

    os.rename(source, target)
    print(f"Renamed prompt '{old_name}' to '{new_name}'")
def tts(text, prompt, output_path="output.wav", temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):
    """Synthesise ``text`` in the voice of a registered prompt and save it as WAV.

    Args:
        text: text to speak, passed unchanged to ``model.generate``.
        prompt: name of a registered prompt, i.e. ``PROMPT_FOLDER/<prompt>.wav``.
        output_path: destination file. Missing parent directories are created
            before generation, so a bad output path fails fast instead of
            after a long synthesis run.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_p: nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: forwarded to ``model.generate``.

    Exits with status 1 (message on stderr) if the prompt is not registered
    or generation raises. The generated audio is run through
    ``enhance_audio`` and written at ``model.sr``.
    """
    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
    if not os.path.exists(prompt_path):
        print(f"Prompt '{prompt}' not found.", file=sys.stderr)
        sys.exit(1)

    output_dir = os.path.dirname(output_path)
    if output_dir:
        # ta.save is not expected to create missing directories. Create them
        # up front rather than failing after generation.
        os.makedirs(output_dir, exist_ok=True)

    try:
        wav = model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    except Exception as e:
        # CLI boundary: report any generation failure as a clean error exit.
        print(f"Failed to generate audio: {e}", file=sys.stderr)
        sys.exit(1)

    wav = enhance_audio(wav.cpu(), model.sr)
    ta.save(output_path, wav, model.sr, format="wav")
    print(f"TTS output saved to {output_path}")
def main():
    """Parse the command line and run the requested prompt/TTS command.

    Subcommands: register-base64, register-file, list, delete, rename, tts.
    When no subcommand is given, prints help and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Chatterbox TTS CLI")
    commands = parser.add_subparsers(dest="command")

    sub = commands.add_parser("register-base64", help="Register prompt from base64 string")
    sub.add_argument("prompt_name")
    sub.add_argument("base64_audio")

    sub = commands.add_parser("register-file", help="Register prompt from wav file")
    sub.add_argument("src_path")
    sub.add_argument("--name", default=None)

    commands.add_parser("list", help="List all prompt wav files")

    sub = commands.add_parser("delete", help="Delete prompt wav file")
    sub.add_argument("prompt_name")

    sub = commands.add_parser("rename", help="Rename prompt wav file")
    sub.add_argument("old_name")
    sub.add_argument("new_name")

    sub = commands.add_parser("tts", help="Generate TTS wav file")
    sub.add_argument("text")
    sub.add_argument("prompt")
    sub.add_argument("--output", default="output.wav")
    sub.add_argument("--temperature", type=float, default=TEMPERATURE)
    sub.add_argument("--top_p", type=float, default=TOP_P)
    sub.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY)

    args = parser.parse_args()

    # Each handler reads only the attributes its own subparser defines, so
    # the lambdas defer attribute access until the chosen one runs.
    handlers = {
        "register-base64": lambda: register_prompt_base64(args.prompt_name, args.base64_audio),
        "register-file": lambda: register_prompt_file(args.src_path, args.name),
        "list": list_prompt,
        "delete": lambda: delete_prompt(args.prompt_name),
        "rename": lambda: rename_prompt(args.old_name, args.new_name),
        "tts": lambda: tts(
            args.text,
            args.prompt,
            args.output,
            args.temperature,
            args.top_p,
            args.repetition_penalty,
        ),
    }

    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        sys.exit(1)
    handler()


if __name__ == "__main__":
    main()