427 lines
13 KiB
Python
427 lines
13 KiB
Python
import io
|
|
import os
|
|
import base64
|
|
from fastapi import FastAPI, UploadFile, File, Form
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
import torch
|
|
import torchaudio as ta
|
|
import torchaudio.functional as F
|
|
from chatterbox.tts import ChatterboxTTS
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file
|
|
|
|
# =========================
# MODEL CONFIGURATION - TUNED FOR NATURALNESS
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
# Inference device; defaults to CPU. Set the TTS_DEVICE environment variable
# (e.g. TTS_DEVICE=cuda) to change it without editing this file.
DEVICE = os.environ.get("TTS_DEVICE", "cpu")

# Sampling parameters passed to model.generate() by /tts and used as the
# defaults of /tts-custom; tuned for a natural voice close to the prompt.
TEMPERATURE = 0.65  # lower = more consistent with the voice prompt
TOP_P = 0.88  # nucleus sampling: keep only the most probable tokens
REPETITION_PENALTY = 1.25  # higher = fewer repetitive, "robotic" patterns
AUDIO_GAIN_DB = 0.8  # gain step in enhance_audio(), applied before its final peak normalization

# Additional quality-control parameters.
# NOTE(review): TOP_K, MIN_P and CFG_SCALE are not passed to model.generate()
# by any endpoint in this file, so changing them currently has no effect.
TOP_K = 50  # cap on candidate tokens
MIN_P = 0.05  # filter for low-probability tokens
CFG_SCALE = 1.2  # classifier-free guidance strength toward the prompt

# Folder holding registered voice-prompt .wav files. A relative path is
# resolved against the current working directory. Override with TTS_PROMPT_FOLDER.
PROMPT_FOLDER = os.environ.get("TTS_PROMPT_FOLDER", "prompt_source")
os.makedirs(PROMPT_FOLDER, exist_ok=True)

app = FastAPI(title="Chatterbox TTS Server - Enhanced")
|
|
|
|
|
|
# =========================
# Audio post-processing focused on naturalness
# =========================
def _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1):
    """Apply a static soft-knee compressor to the sample amplitudes of ``wav``.

    Samples with ``|x| <= threshold - knee`` pass through unchanged. Samples
    with ``|x| > threshold + knee`` are reduced by ``ratio`` above
    ``threshold``. In the knee region in between, a quadratic blend keeps the
    transfer curve continuous and monotonic with a continuous slope. The sign
    of every sample is preserved. ``knee`` must be > 0.
    """
    abs_wav = wav.abs()
    sign = torch.sign(wav)
    lower = threshold - knee
    upper = threshold + knee

    # Above the knee: linear gain reduction by `ratio` above the threshold.
    hard = sign * (threshold + (abs_wav - threshold) / ratio)

    # Knee: y = x + (1/ratio - 1) * (x - lower)^2 / (4 * knee).
    # At x = lower this gives y = lower (slope 1); at x = upper it gives
    # y = threshold + knee / ratio (slope 1/ratio), which joins both
    # neighbouring regions without a jump.
    soft = sign * (abs_wav + (1.0 / ratio - 1.0) * (abs_wav - lower) ** 2 / (4.0 * knee))

    out = torch.where(abs_wav > upper, hard, wav)
    return torch.where((abs_wav > lower) & (abs_wav <= upper), soft, out)


def enhance_audio(wav, sr):
    """Post-process generated speech so it sounds warmer and less "digital".

    Processing chain:
      1. Peak-normalize to 0.95.
      2. Band-limit: 60 Hz high-pass and ~10 kHz low-pass.
      3. Low-shelf boost.
      4. High-shelf cut.
      5. Soft-knee compression.
      6. Mild tanh saturation.
      7. Very low-level pink noise.
      8. Tanh limiting.
      9. Gain.
      10. Final peak normalization to 0.88.

    Args:
        wav: audio tensor; the torchaudio filters operate along the last axis.
            Presumably shaped (channels, time) as returned by
            ChatterboxTTS.generate — TODO confirm.
        sr: sample rate in Hz.

    Returns:
        The processed tensor, same shape as ``wav``.
    """
    # 1. Gentle initial normalization to a 0.95 peak.
    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    # 2. Band-limit: drop rumble below 60 Hz and content above ~10 kHz.
    #    The low-pass cutoff is kept below Nyquist so the biquad stays valid at
    #    low sample rates (it is exactly 10 kHz for sr >= ~22.3 kHz).
    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=min(10000, 0.45 * sr))

    # 3. Warmth: +1.5 dB low-shelf (bass) boost around 200 Hz.
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)

    # 4. Soften harsh treble: -1.2 dB high-shelf cut around 6 kHz.
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

    # 5. Single-band soft-knee compression of the sample amplitudes.
    wav = _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1)

    # 6. Subtle tanh saturation for an analog-like character.
    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    # 7. Very low-level pink noise intended to mask digital artifacts.
    wav = wav + generate_pink_noise(wav.shape, wav.device) * 0.0003

    # 8. Gentle tanh limiting to avoid hard clipping.
    wav = torch.tanh(wav * 1.1) * 0.92

    # 9. Gain adjustment. NOTE: step 10 rescales the signal linearly to a fixed
    #    peak afterwards, so this gain does not change the final output level.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    # 10. Final peak normalization to 0.88, leaving headroom.
    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88

    return wav
|
|
|
|
|
|
def _pink_filter_row(samples):
    """Run Paul Kellett's refined pink-noise filter over a list of white-noise floats."""
    b0 = b1 = b2 = b3 = b4 = b5 = b6 = 0.0
    out = []
    append = out.append
    for w in samples:
        b0 = 0.99886 * b0 + w * 0.0555179
        b1 = 0.99332 * b1 + w * 0.0750759
        b2 = 0.96900 * b2 + w * 0.1538520
        b3 = 0.86650 * b3 + w * 0.3104856
        b4 = 0.55000 * b4 + w * 0.5329522
        b5 = -0.7616 * b5 - w * 0.0168980
        append((b0 + b1 + b2 + b3 + b4 + b5 + b6 + w * 0.5362) * 0.11)
        # b6 holds a scaled copy of the previous input sample (one-sample delay).
        b6 = w * 0.115926
    return out


def generate_pink_noise(shape, device):
    """Return approximate pink (1/f) noise of the given ``shape`` on ``device``.

    White Gaussian noise from ``torch.randn`` is filtered along the last axis
    with Paul Kellett's refined pink-noise filter (a sum of leaky integrators).
    Each row along the leading axes is filtered independently, and the result
    is scaled by 0.1.

    The filter recursion runs on plain Python floats rather than indexing
    0-dim tensors per sample, which makes it orders of magnitude faster on
    long clips.
    """
    white = torch.randn(shape, device=device)
    if white.numel() == 0:
        return torch.zeros_like(white)

    row_len = white.shape[-1] if white.ndim else 1
    rows = white.reshape(-1, row_len).tolist()
    pink_rows = [_pink_filter_row(row) for row in rows]
    pink = torch.tensor(pink_rows, dtype=white.dtype, device=white.device).reshape(white.shape)
    return pink * 0.1  # Scale down
|
|
|
|
|
|
# =========================
# Load the model once, at import time
# =========================
print("Loading model...")

model = ChatterboxTTS.from_pretrained(device=DEVICE)

# Swap in the fine-tuned T3 weights downloaded from the Hugging Face Hub.
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)

model.t3.to(DEVICE).load_state_dict(state)
model.t3.eval()

# Single pass over every T3 submodule: force inference mode on each one and
# zero the probability of plain Dropout layers, for maximally consistent output.
for module in model.t3.modules():
    if hasattr(module, "training"):
        module.training = False
    if isinstance(module, torch.nn.Dropout):
        module.p = 0

print("Model ready with enhanced settings.")
|
|
|
|
|
|
# =====================================================
# 1A. REGISTER BASE64 -> WAV
# =====================================================
class RegisterPromptBase64(BaseModel):
    """Request body for /register-prompt-base64."""

    prompt_name: str  # saved as PROMPT_FOLDER/<prompt_name>.wav; must be a bare file name
    base64_audio: str  # base64 audio bytes; a "data:...;base64," prefix is also accepted


@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Decode a base64 audio payload and save it as a named voice prompt.

    The decoded bytes are written as-is, overwriting any existing prompt of the
    same name.

    Returns:
        ``{"status": "ok", "file": "<name>.wav"}`` on success.
        HTTP 400 for an invalid name or an undecodable/empty payload.
        HTTP 500 if the file cannot be written.
    """
    prompt_name = data.prompt_name
    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x", "/tmp/x").
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    filename = f"{prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)

    payload = data.base64_audio.strip()
    # Accept data URIs such as "data:audio/wav;base64,AAAA..." (":" is not a
    # base64 character, so plain base64 never starts with "data:").
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    # Drop line breaks/whitespace from wrapped base64, then decode strictly so
    # that corrupt input is rejected instead of being silently mangled.
    payload = "".join(payload.split())

    try:
        raw = base64.b64decode(payload, validate=True)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        return JSONResponse(
            status_code=400, content={"error": f"Gagal decode / simpan audio: {e}"}
        )

    if not raw:
        return JSONResponse(
            status_code=400,
            content={"error": "Gagal decode / simpan audio: data audio kosong"},
        )

    try:
        with open(path, "wb") as f:
            f.write(raw)
    except (OSError, ValueError) as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal decode / simpan audio: {e}"}
        )

    return {"status": "ok", "file": filename}
|
|
|
|
|
|
# =====================================================
# 1B. REGISTER FILE UPLOAD (form-data)
# =====================================================
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Save an uploaded audio file as a named voice prompt.

    The prompt is stored as ``PROMPT_FOLDER/<name>.wav``. When ``name`` is not
    given, the uploaded file's base name without its extension is used. The
    bytes are written as-is, overwriting any existing prompt of the same name.

    Returns:
        ``{"status": "ok", "file": "<name>.wav"}`` on success.
        HTTP 400 for a missing or invalid name.
        HTTP 500 if the upload cannot be saved.
    """
    if name:
        prompt_name = name
    else:
        # The client-supplied filename may be missing or contain directories;
        # keep only its last path component.
        prompt_name = os.path.splitext(os.path.basename(prompt.filename or ""))[0]

    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x").
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")

    try:
        # Read the upload before opening the target, so a failed read does not
        # leave an empty/truncated prompt file behind.
        content = await prompt.read()
        with open(save_path, "wb") as f:
            f.write(content)
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal menyimpan prompt: {e}"}
        )

    return {"status": "ok", "file": f"{prompt_name}.wav"}
|
|
|
|
|
|
# =====================================================
# LIST PROMPT
# =====================================================
@app.get("/list-prompt")
async def list_prompt():
    """List registered voice prompts (files in PROMPT_FOLDER ending in .wav, any case).

    Returns the count, the file names, and the names without extension, all in
    ``os.listdir`` order. Returns HTTP 500 with an error message if the folder
    cannot be read.
    """
    try:
        prompts, prompt_names = [], []
        for entry in os.listdir(PROMPT_FOLDER):
            if not entry.lower().endswith(".wav"):
                continue
            prompts.append(entry)
            prompt_names.append(os.path.splitext(entry)[0])

        return {
            "count": len(prompts),
            "prompts": prompts,
            "prompt_names": prompt_names,
        }
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal membaca folder prompt: {e}"}
        )
|
|
|
|
|
|
# =====================================================
# DELETE PROMPT
# =====================================================
class DeletePrompt(BaseModel):
    """Request body for /delete-prompt."""

    prompt_name: str  # prompt to delete, without the ".wav" extension


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete ``<prompt_name>.wav`` from PROMPT_FOLDER.

    Returns:
        ``{"status": "ok", "deleted": "<name>.wav"}`` on success.
        HTTP 400 for an invalid name.
        HTTP 404 if the prompt does not exist.
        HTTP 500 if removal fails.
    """
    prompt_name = data.prompt_name
    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x").
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    file_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")

    if not os.path.exists(file_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )

    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Removed by someone else between the existence check and this call.
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal menghapus: {e}"})

    return {"status": "ok", "deleted": f"{prompt_name}.wav"}
|
|
|
|
|
|
# =====================================================
# RENAME PROMPT
# =====================================================
class RenamePrompt(BaseModel):
    """Request body for /rename-prompt."""

    old_name: str  # existing prompt name, without the ".wav" extension
    new_name: str  # target prompt name, without the ".wav" extension


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename ``<old_name>.wav`` to ``<new_name>.wav`` inside PROMPT_FOLDER.

    Returns:
        ``{"status": "ok", "from": old_name, "to": new_name}`` on success.
        HTTP 400 for an invalid name or if the new name is already taken.
        HTTP 404 if the old prompt does not exist.
        HTTP 500 if the rename fails.
    """
    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x").
    for candidate in (data.old_name, data.new_name):
        if not candidate or os.path.basename(candidate) != candidate:
            return JSONResponse(
                status_code=400, content={"error": "Nama prompt tidak valid"}
            )

    old_path = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new_path = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")

    if not os.path.exists(old_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt lama tidak ditemukan"}
        )

    # NOTE(review): on POSIX os.rename silently replaces an existing target, so
    # this check only guards against overwrites that are not concurrent with it.
    if os.path.exists(new_path):
        return JSONResponse(
            status_code=400, content={"error": "Nama baru sudah digunakan"}
        )

    try:
        os.rename(old_path, new_path)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal rename: {e}"})

    return {"status": "ok", "from": data.old_name, "to": data.new_name}
|
|
|
|
|
|
# =====================================================
# 2. TTS - ENHANCED with optimized parameters
# =====================================================
@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    """Synthesize ``text`` in the voice of the registered prompt ``prompt``.

    Sampling uses the module-level TEMPERATURE, TOP_P and REPETITION_PENALTY.
    The audio is post-processed with enhance_audio() and returned as a WAV
    attachment named "output.wav".

    Returns HTTP 400 for an invalid prompt name, 404 if the prompt does not
    exist, and 500 if generation fails.
    """
    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x").
    if not prompt or os.path.basename(prompt) != prompt:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")

    if not os.path.exists(prompt_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )

    # NOTE(review): model.generate() is a synchronous call inside an
    # ``async def`` endpoint, so it blocks the event loop (and every other
    # request, including /health) until it returns. As a side effect,
    # generations never run concurrently. Before moving this to a thread pool,
    # confirm ChatterboxTTS.generate is thread-safe or guard it with a lock.
    try:
        wav = model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=TEMPERATURE,  # lower = more consistent with the prompt
            top_p=TOP_P,  # more focused sampling
            repetition_penalty=REPETITION_PENALTY,  # discourage robotic repetition
        )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal generate audio: {e}"}
        )

    wav = enhance_audio(wav.cpu(), model.sr)

    # Encode into an in-memory WAV file.
    buffer = io.BytesIO()
    ta.save(buffer, wav, model.sr, format="wav")
    buffer.seek(0)

    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=output.wav"},
    )
|
|
|
|
|
|
# =====================================================
# 3. TTS with custom parameters (advanced)
# =====================================================
class TTSCustomParams(BaseModel):
    """Request body for /tts-custom; sampling fields default to the module settings."""

    text: str  # text to synthesize
    prompt: str  # registered prompt name, without the ".wav" extension
    temperature: float = TEMPERATURE  # must be > 0
    top_p: float = TOP_P  # must be in (0, 1]
    repetition_penalty: float = REPETITION_PENALTY  # must be > 0


@app.post("/tts-custom")
async def tts_custom(params: TTSCustomParams):
    """Synthesize speech with caller-chosen sampling parameters, for experimentation.

    The result is post-processed with enhance_audio() and returned as a WAV
    attachment named "output_custom.wav".

    Returns HTTP 400 for an invalid prompt name or out-of-range sampling
    parameters, 404 if the prompt does not exist, and 500 if generation fails.
    """
    # Reject names that would resolve outside PROMPT_FOLDER (e.g. "../x").
    if not params.prompt or os.path.basename(params.prompt) != params.prompt:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    # Written as a positive check so NaN values are rejected too.
    if not (
        params.temperature > 0
        and 0 < params.top_p <= 1
        and params.repetition_penalty > 0
    ):
        return JSONResponse(
            status_code=400,
            content={
                "error": "Parameter tidak valid: temperature > 0, "
                "0 < top_p <= 1, repetition_penalty > 0"
            },
        )

    prompt_path = os.path.join(PROMPT_FOLDER, f"{params.prompt}.wav")

    if not os.path.exists(prompt_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )

    # NOTE(review): synchronous, CPU-heavy call inside an async endpoint; it
    # blocks the event loop until generation finishes.
    try:
        wav = model.generate(
            params.text,
            audio_prompt_path=prompt_path,
            temperature=params.temperature,
            top_p=params.top_p,
            repetition_penalty=params.repetition_penalty,
        )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal generate audio: {e}"}
        )

    wav = enhance_audio(wav.cpu(), model.sr)

    # Encode into an in-memory WAV file.
    buffer = io.BytesIO()
    ta.save(buffer, wav, model.sr, format="wav")
    buffer.seek(0)

    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=output_custom.wav"},
    )
|
|
|
|
|
|
# =========================
# Root
# =========================
@app.get("/")
async def root():
    """Describe the API: name, version and the advertised voice-quality improvements."""
    improvements = [
        "Optimized temperature & sampling for source similarity",
        "Enhanced audio processing for natural voice",
        "Reduced robotic artifacts",
        "Better emotion & intonation preservation",
        "Analog-style warmth processing",
    ]
    return {
        "message": "Chatterbox TTS API - Enhanced Voice Quality",
        "version": "2.0",
        "improvements": improvements,
    }
|
|
|
|
|
|
# =========================
# Health check
# =========================
@app.get("/health")
async def health():
    """Report liveness, the inference device and the default sampling settings.

    ``model_loaded`` is always True: the model is loaded while this module is
    imported, before the app can serve requests.
    """
    settings = dict(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=REPETITION_PENALTY,
    )
    return dict(status="healthy", model_loaded=True, device=DEVICE, settings=settings)
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object instead of the import string
    # "claude_clonev4:app". With a string, uvicorn imports this file a second
    # time under that module name (this run is "__main__"), which repeats the
    # module-level model load. It also fails outright if the file is renamed.
    # Passing an app object is supported because reload is disabled.
    uvicorn.run(app, host="0.0.0.0", port=6007, reload=False)