# Files
# jenna-tools/xclonev3.py
# bipproduction 822b68c10f tambahannya
# 2025-12-07 09:00:54 +08:00
#
# 427 lines
# 13 KiB
# Python

import io
import os
import base64
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# =========================
# KONFIGURASI MODEL - DISESUAIKAN UNTUK NATURALNESS
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"  # Hub repo holding the fine-tuned T3 checkpoint
CHECKPOINT = "t3_cfg.safetensors"  # file inside MODEL_REPO loaded into model.t3
DEVICE = "cpu"
# Sampling parameters tuned for a more natural voice that stays closer to the
# prompt. /tts passes them to model.generate; /tts-custom uses them as defaults.
TEMPERATURE = 0.65  # Lower = more consistent with the prompt
TOP_P = 0.88  # Focus sampling on high-probability predictions
REPETITION_PENALTY = 1.25  # Higher to avoid repetitive "robotic" patterns
# Gain (dB) applied near the end of enhance_audio.
# NOTE(review): enhance_audio peak-normalizes right after applying this gain,
# which cancels it, so changing this value has no audible effect.
AUDIO_GAIN_DB = 0.8
# Additional quality parameters.
# NOTE(review): TOP_K, MIN_P and CFG_SCALE are not referenced anywhere else in
# this file (they are not passed to model.generate), so they currently have no effect.
TOP_K = 50  # Limit the number of candidate tokens
MIN_P = 0.05  # Filter out low-probability tokens
CFG_SCALE = 1.2  # Classifier-free guidance strength for prompt adherence
PROMPT_FOLDER = "prompt_source"  # registered prompt .wav files; relative to the working directory
os.makedirs(PROMPT_FOLDER, exist_ok=True)
app = FastAPI(title="Chatterbox TTS Server - Enhanced")
# =========================
# Enhance audio dengan fokus pada naturalness
# =========================
def enhance_audio(wav, sr):
    """
    Post-process generated speech so it sounds warmer and less "digital".

    Processing chain:
      1. peak-normalize to 0.95
      2. band-limit: 60 Hz high-pass, 10 kHz low-pass
      3. +1.5 dB low shelf around 200 Hz (warmth)
      4. -1.2 dB high shelf around 6 kHz (less harshness)
      5. soft-knee compression (threshold 0.6, ratio 2.5)
      6. mild tanh saturation
      7. very low-level pink noise to mask digital artifacts
      8. tanh limiting
      9. AUDIO_GAIN_DB gain
     10. peak-normalize to 0.88

    Args:
        wav: Audio tensor. The torchaudio filters used here act on the last
            dimension, treated as time; the exact layout returned by the model
            (e.g. ``(channels, samples)``) is assumed — TODO confirm.
        sr: Sample rate in Hz.

    Returns:
        The processed tensor, same shape as ``wav``. An empty input is
        returned unchanged.
    """
    if wav.numel() == 0:
        return wav

    # 1. Gentle initial peak normalization.
    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    # Keep filter frequencies below Nyquist: a biquad designed above sr / 2 is
    # invalid. For sr >= 24 kHz these caps leave the nominal values unchanged.
    max_freq = 0.45 * sr

    # 2. Band-limit: remove sub-bass rumble and the very top of the spectrum.
    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=min(10000, max_freq))

    # 3. Subtle low-shelf boost for warmth (simulates natural voice resonance).
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)

    # 4. Gentle high-shelf cut to soften digital harshness.
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=min(6000, max_freq), Q=0.7)

    # 5. Single-band soft-knee compression for a more natural dynamic range.
    wav = _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1)

    # 6. Mild tanh saturation for analog-like warmth.
    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    # 7. Add a very low level of pink noise to help mask digital artifacts.
    wav = wav + generate_pink_noise(wav.shape, wav.device) * 0.0003

    # 8. Gentle tanh limiting to prevent clipping without harshness.
    wav = torch.tanh(wav * 1.1) * 0.92

    # 9. Gain adjustment.
    # NOTE(review): step 10 rescales by the peak, which cancels any constant
    # gain applied here, so AUDIO_GAIN_DB currently has no audible effect.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    # 10. Final peak normalization, leaving headroom.
    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88
    return wav


def _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1):
    """
    Apply a static soft-knee compressor to sample magnitudes.

    Magnitudes up to ``threshold - knee`` pass unchanged. Above
    ``threshold + knee`` they are reduced by ``ratio`` relative to
    ``threshold``. In between, a quadratic curve blends the two, so the
    transfer curve is continuous and monotonic. Signs are preserved.

    Args:
        wav: Audio tensor of any shape.
        threshold: Magnitude at which compression is centered.
        ratio: Compression ratio above the knee (> 1 compresses).
        knee: Half-width of the soft-knee region (> 0).

    Returns:
        Tensor of the same shape as ``wav``.
    """
    magnitude = wav.abs()
    sign = torch.sign(wav)
    lower = threshold - knee
    upper = threshold + knee
    hard = sign * (threshold + (magnitude - threshold) / ratio)
    # y = x + (1/R - 1) * (x - lower)^2 / (4 * knee). It equals x with slope 1
    # at `lower` and threshold + knee/R with slope 1/R at `upper`, so it meets
    # both neighbouring segments without a jump.
    soft = sign * (magnitude + (1.0 / ratio - 1.0) * (magnitude - lower) ** 2 / (4.0 * knee))
    return torch.where(magnitude > upper, hard, torch.where(magnitude > lower, soft, wav))
def generate_pink_noise(shape, device):
    """
    Generate approximately pink (1/f) noise of the given shape.

    White Gaussian noise is passed, independently along the last dimension,
    through Paul Kellett's refined pink-noise filter:
    - six one-pole sections,
    - a direct white-noise term,
    - a one-sample-delayed white-noise term,
    - output scaling of 0.11, followed by a further x0.1.

    The filter is linear and time-invariant, so instead of a per-sample
    Python loop it is applied as one FFT convolution with its impulse
    response. The result is the same filter output, computed in float64 and
    cast back to the noise dtype.

    Args:
        shape: Output shape; the last dimension is treated as time.
        device: Device on which to create the noise.

    Returns:
        Tensor of ``shape`` (same dtype as ``torch.randn``) with the noise.
    """
    white = torch.randn(shape, device=device)
    if white.numel() == 0:
        return torch.zeros_like(white)
    n = white.shape[-1]

    # (pole, input gain) of each section: y[t] = pole * y[t-1] + gain * x[t]
    sections = (
        (0.99886, 0.0555179),
        (0.99332, 0.0750759),
        (0.96900, 0.1538520),
        (0.86650, 0.3104856),
        (0.55000, 0.5329522),
        (-0.7616, -0.0168980),
    )
    impulse = torch.zeros(n, dtype=torch.float64, device=device)
    for pole, gain in sections:
        # cumprod of [1, p, p, ...] gives [1, p, p^2, ...] (safe for negative p).
        powers = torch.full((n,), pole, dtype=torch.float64, device=device)
        powers[0] = 1.0
        impulse += gain * torch.cumprod(powers, dim=0)
    impulse[0] += 0.5362  # direct white-noise term
    if n > 1:
        impulse[1] += 0.115926  # white noise delayed by one sample
    impulse *= 0.11

    # FFT length >= 2n - 1 gives a linear (not circular) convolution.
    fft_len = 2 * n
    spectrum = torch.fft.rfft(white.to(torch.float64), n=fft_len, dim=-1)
    spectrum = spectrum * torch.fft.rfft(impulse, n=fft_len)
    pink = torch.fft.irfft(spectrum, n=fft_len, dim=-1)[..., :n]
    return (pink * 0.1).to(white.dtype)  # Scale down
# =========================
# Load model sekali
# =========================
print("Loading model...")
model = ChatterboxTTS.from_pretrained(device=DEVICE)
# Replace the T3 weights with the fine-tuned checkpoint from MODEL_REPO.
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)
model.t3.to(DEVICE).load_state_dict(state)
# load_state_dict copies the tensors into model.t3's parameters/buffers. Drop
# this module-level reference so a second copy of the weights is not kept
# alive for the lifetime of the server.
del state
# Inference only: eval mode, then (in a single pass) force every submodule's
# training flag off and zero dropout probabilities, so dropout stays a no-op
# even if train mode is re-enabled later.
model.t3.eval()
for module in model.t3.modules():
    if hasattr(module, "training"):
        module.training = False
    if isinstance(module, torch.nn.Dropout):
        module.p = 0
print("Model ready with enhanced settings.")
# =====================================================
# 1A. REGISTER BASE64 -> WAV
# =====================================================
class RegisterPromptBase64(BaseModel):
    """JSON body for ``POST /register-prompt-base64``."""

    prompt_name: str  # saved as <prompt_name>.wav inside PROMPT_FOLDER
    base64_audio: str  # base64 text of the audio file's bytes
@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """
    Decode base64 audio and store it as ``<PROMPT_FOLDER>/<prompt_name>.wav``.

    An existing prompt with the same name is overwritten. A
    ``data:<mime>;base64,`` prefix (as produced by browser data URLs) is
    stripped before decoding. The decoded bytes are written as-is; nothing
    checks that they really are WAV audio.

    Errors: 400 for an invalid name or undecodable payload, 500 if the file
    cannot be written.
    """
    name = data.prompt_name
    # The name is client-controlled: accept only a bare file name so the
    # target path cannot escape PROMPT_FOLDER (e.g. "../../x").
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )
    filename = f"{name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)

    payload = data.base64_audio
    if payload.startswith("data:") and "," in payload:
        # ':' and ',' are not base64 characters, so this never eats real data.
        payload = payload.split(",", 1)[1]
    try:
        raw = base64.b64decode(payload)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        return JSONResponse(
            status_code=400, content={"error": f"Gagal decode / simpan audio: {e}"}
        )
    try:
        with open(path, "wb") as f:
            f.write(raw)
    except (OSError, ValueError) as e:  # ValueError: e.g. NUL byte in the path
        return JSONResponse(
            status_code=500, content={"error": f"Gagal decode / simpan audio: {e}"}
        )
    return {"status": "ok", "file": filename}
# =====================================================
# 1B. REGISTER FILE UPLOAD (form-data)
# =====================================================
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """
    Save an uploaded audio file as ``<PROMPT_FOLDER>/<name>.wav``.

    If ``name`` is omitted, the upload's file name without its extension is
    used. An existing prompt with the same name is overwritten. The bytes are
    stored unchanged; nothing verifies that they are WAV audio.

    Errors: 400 for an empty or invalid prompt name, 500 if saving fails.
    """
    if name:
        prompt_name = name
    else:
        # The client-supplied file name may be missing or include a
        # client-side directory; keep only its final component.
        upload_name = os.path.basename(prompt.filename or "")
        prompt_name = os.path.splitext(upload_name)[0]
    # Accept only a bare file name so the target stays inside PROMPT_FOLDER.
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )
    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        # Read the upload before opening (and truncating) the target, so a
        # failed read does not destroy an existing prompt.
        content = await prompt.read()
        with open(save_path, "wb") as f:
            f.write(content)
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal menyimpan prompt: {e}"}
        )
# =====================================================
# LIST PROMPT
# =====================================================
@app.get("/list-prompt")
async def list_prompt():
    """
    List the registered prompts.

    Returns the ``.wav`` files in PROMPT_FOLDER (extension matched
    case-insensitively), their count, and their names without extension.
    Responds with 500 if the folder cannot be read.
    """
    try:
        entries = os.listdir(PROMPT_FOLDER)
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal membaca folder prompt: {e}"}
        )
    wav_files = list(filter(lambda entry: entry.lower().endswith(".wav"), entries))
    names = [os.path.splitext(entry)[0] for entry in wav_files]
    return {"count": len(wav_files), "prompts": wav_files, "prompt_names": names}
# =====================================================
# DELETE PROMPT
# =====================================================
class DeletePrompt(BaseModel):
    """JSON body for ``POST /delete-prompt``."""

    prompt_name: str  # prompt to delete, without the ".wav" extension
@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """
    Delete ``<PROMPT_FOLDER>/<prompt_name>.wav``.

    Errors: 400 for an invalid name, 404 if the prompt does not exist,
    500 if removal fails.
    """
    name = data.prompt_name
    # The name is client-controlled: accept only a bare file name so this
    # endpoint cannot delete files outside PROMPT_FOLDER.
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )
    file_path = os.path.join(PROMPT_FOLDER, f"{name}.wav")
    # Try the removal directly instead of checking existence first, which
    # avoids a race between the check and the delete.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal menghapus: {e}"})
    return {"status": "ok", "deleted": f"{name}.wav"}
# =====================================================
# RENAME PROMPT
# =====================================================
class RenamePrompt(BaseModel):
    """JSON body for ``POST /rename-prompt``."""

    old_name: str  # existing prompt name, without ".wav"
    new_name: str  # new prompt name, without ".wav"; must not already exist
@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """
    Rename ``<old_name>.wav`` to ``<new_name>.wav`` inside PROMPT_FOLDER.

    Errors: 400 for an invalid name or if the new name is already taken,
    404 if the old prompt does not exist, 500 if the rename fails.
    """
    # Both names are client-controlled: accept only bare file names so
    # neither path can point outside PROMPT_FOLDER.
    for candidate in (data.old_name, data.new_name):
        if not candidate or os.path.basename(candidate) != candidate:
            return JSONResponse(
                status_code=400, content={"error": "Nama prompt tidak valid"}
            )
    old_path = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new_path = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if not os.path.exists(old_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt lama tidak ditemukan"}
        )
    if os.path.exists(new_path):
        return JSONResponse(
            status_code=400, content={"error": "Nama baru sudah digunakan"}
        )
    try:
        os.rename(old_path, new_path)
        return {"status": "ok", "from": data.old_name, "to": data.new_name}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal rename: {e}"})
# =====================================================
# 2. TTS - ENHANCED dengan parameter optimal
# =====================================================
@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    """
    Synthesize ``text`` in the voice of the registered prompt ``prompt``.

    Uses the server-wide sampling settings (TEMPERATURE, TOP_P,
    REPETITION_PENALTY). The result is post-processed with ``enhance_audio``
    and streamed back as a WAV attachment.

    Errors: 400 for an invalid prompt name, 404 if the prompt does not exist,
    500 if generation fails.
    """
    # `prompt` is client-controlled: accept only a bare file name so it
    # cannot resolve to a file outside PROMPT_FOLDER.
    if not prompt or os.path.basename(prompt) != prompt:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )
    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
    if not os.path.exists(prompt_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    # NOTE(review): model.generate is a synchronous call inside an async
    # handler, so it blocks the event loop for the whole generation (within
    # one worker this effectively serializes requests). Before moving it to a
    # thread pool, confirm the shared model is safe to call concurrently.
    try:
        wav = model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=TEMPERATURE,  # lower = more consistent with the prompt
            top_p=TOP_P,  # more focused sampling
            repetition_penalty=REPETITION_PENALTY,  # avoid robotic repetition
        )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal generate audio: {e}"}
        )
    wav = enhance_audio(wav.cpu(), model.sr)
    # Encode into an in-memory WAV buffer.
    buffer = io.BytesIO()
    ta.save(buffer, wav, model.sr, format="wav")
    buffer.seek(0)
    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=output.wav"},
    )
# =====================================================
# 3. TTS dengan parameter custom (advanced)
# =====================================================
class TTSCustomParams(BaseModel):
    """JSON body for ``POST /tts-custom``; sampling fields default to the server-wide settings."""

    text: str  # text to synthesize
    prompt: str  # registered prompt name, without ".wav"
    temperature: float = TEMPERATURE  # passed to model.generate
    top_p: float = TOP_P  # passed to model.generate
    repetition_penalty: float = REPETITION_PENALTY  # passed to model.generate
@app.post("/tts-custom")
async def tts_custom(params: TTSCustomParams):
    """
    Endpoint for experimenting with different sampling parameters.

    Like ``/tts``, but temperature, top_p and repetition_penalty come from the
    JSON body (defaulting to the server-wide settings).

    Errors: 400 for an invalid prompt name, 404 if the prompt does not exist,
    500 if generation fails.
    """
    # `params.prompt` is client-controlled: accept only a bare file name so it
    # cannot resolve to a file outside PROMPT_FOLDER.
    if not params.prompt or os.path.basename(params.prompt) != params.prompt:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )
    prompt_path = os.path.join(PROMPT_FOLDER, f"{params.prompt}.wav")
    if not os.path.exists(prompt_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    # NOTE(review): synchronous generation inside an async handler blocks the
    # event loop; see the same note on /tts before offloading it to threads.
    try:
        wav = model.generate(
            params.text,
            audio_prompt_path=prompt_path,
            temperature=params.temperature,
            top_p=params.top_p,
            repetition_penalty=params.repetition_penalty,
        )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal generate audio: {e}"}
        )
    wav = enhance_audio(wav.cpu(), model.sr)
    # Encode into an in-memory WAV buffer.
    buffer = io.BytesIO()
    ta.save(buffer, wav, model.sr, format="wav")
    buffer.seek(0)
    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=output_custom.wav"},
    )
# =========================
# Root
# =========================
@app.get("/")
async def root():
    """Describe the service: its name, version and a summary of its tuning."""
    improvements = [
        "Optimized temperature & sampling for source similarity",
        "Enhanced audio processing for natural voice",
        "Reduced robotic artifacts",
        "Better emotion & intonation preservation",
        "Analog-style warmth processing",
    ]
    return dict(
        message="Chatterbox TTS API - Enhanced Voice Quality",
        version="2.0",
        improvements=improvements,
    )
# =========================
# Health check
# =========================
@app.get("/health")
async def health():
    """
    Report liveness, the inference device and the default sampling settings.

    ``model_loaded`` is always True: the model is loaded at module import,
    before the app can serve any request.
    """
    sampling = {
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "repetition_penalty": REPETITION_PENALTY,
    }
    return {"status": "healthy", "model_loaded": True, "device": DEVICE, "settings": sampling}
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than a "module:app" import string. A string
    # hard-codes this file's module name and breaks when the file is renamed.
    # When run as a script, importing the module by name would also execute
    # the top-level model loading a second time. An app object is supported
    # because reload is disabled.
    uvicorn.run(app, host="0.0.0.0", port=6007, reload=False)