import io
import os
import base64

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# =========================
# MODEL CONFIGURATION - TUNED FOR NATURALNESS
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
DEVICE = "cpu"

# Default sampling parameters passed to model.generate() by the TTS endpoints.
TEMPERATURE = 0.65  # sampling temperature (lower = less random sampling)
TOP_P = 0.88  # nucleus-sampling probability mass
REPETITION_PENALTY = 1.25  # higher values discourage repeated token patterns
# Gain applied near the end of enhance_audio(). Its final peak normalization
# rescales the signal afterwards, so this value does not change output level.
AUDIO_GAIN_DB = 0.8

# NOTE(review): TOP_K, MIN_P and CFG_SCALE are not passed to model.generate()
# anywhere in this file, so they currently have no effect. Whether the
# installed chatterbox version accepts corresponding kwargs is unverified.
TOP_K = 50
MIN_P = 0.05
CFG_SCALE = 1.2

PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)

app = FastAPI(title="Chatterbox TTS Server - Enhanced")


# =========================
# Audio post-processing
# =========================
def enhance_audio(wav, sr):
    """Post-process generated speech to sound warmer and less "digital".

    Args:
        wav: waveform tensor with time on the last axis (the TTS endpoints
            pass the CPU copy of ``model.generate(...)`` output).
        sr: sample rate of ``wav`` in Hz.

    Returns:
        The processed waveform with the same shape as ``wav``. Unless it is
        empty or silent, it is peak-normalized to 0.88.
    """
    if wav.numel() == 0:
        # .max() on an empty tensor raises; nothing to process anyway.
        return wav

    # 1. Gentle initial peak normalization to 0.95.
    peak = wav.abs().max()
    if peak > 0:
        wav = wav / (peak + 1e-8) * 0.95

    # 2. Band-limit: remove rumble below 60 Hz and content above ~10 kHz.
    #    The low-pass cutoff is kept below Nyquist so the biquad remains valid
    #    at lower sample rates.
    wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=min(10000.0, 0.45 * sr))

    # 3. Slight bass-shelf boost around 200 Hz for warmth.
    wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)

    # 4. Slight treble-shelf cut around 6 kHz to soften harsh highs.
    wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

    # 5. Soft-knee compression of the dynamic range.
    wav = _soft_knee_compress(wav, threshold=0.6, ratio=2.5, knee=0.1)

    # 6. Mild tanh saturation; dividing by the same factor keeps small
    #    signals at (approximately) unity gain.
    saturation_amount = 0.08
    wav = torch.tanh(wav * (1 + saturation_amount)) / (1 + saturation_amount)

    # 7. Very low-level pink noise, intended to mask digital artifacts.
    wav = wav + generate_pink_noise(wav.shape, wav.device) * 0.0003

    # 8. Gentle tanh limiter to avoid hard clipping.
    wav = torch.tanh(wav * 1.1) * 0.92

    # 9. Output gain. Step 10 divides by the resulting peak, so this linear
    #    gain cancels out of the final level.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    # 10. Final peak normalization with headroom.
    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.88

    return wav


def _soft_knee_compress(wav, threshold, ratio, knee):
    """Compress sample magnitudes above ``threshold`` by ``ratio`` with a soft knee.

    Below ``threshold - knee`` samples pass unchanged. Above
    ``threshold + knee`` the magnitude follows
    ``threshold + (|x| - threshold) / ratio``. Inside the knee a quadratic
    joins the two with matching value and slope at both edges, so the
    transfer curve is continuous and monotonic. The sign of every sample is
    preserved. ``knee`` must be > 0.
    """
    mag = wav.abs()
    lower = threshold - knee
    upper = threshold + knee
    hard = threshold + (mag - threshold) / ratio
    # y = |x| + (1/R - 1) * (|x| - T + k)^2 / (4k):
    # equals T - k at the lower edge and T + k/R at the upper edge.
    soft = mag + (1.0 / ratio - 1.0) * (mag - lower) ** 2 / (4.0 * knee)
    out_mag = torch.where(mag > upper, hard, torch.where(mag > lower, soft, mag))
    return torch.sign(wav) * out_mag


def generate_pink_noise(shape, device):
    """Return approximately pink (1/f) noise of the given shape on ``device``.

    White Gaussian noise is filtered along the last axis with Paul Kellet's
    "refined" pink-noise approximation. That approximation is six one-pole
    sections, plus a weighted direct white term and a weighted one-sample-delayed
    white term. The sum is scaled by 0.11 and then by 0.1. Each slice along
    the leading axes gets independent noise.

    The filtering is vectorized with ``torchaudio.functional.lfilter``; there
    is no per-sample Python loop.
    """
    white = torch.randn(shape, device=device)

    # (pole, gain) of each one-pole section: y[n] = pole * y[n-1] + gain * x[n]
    sections = (
        (0.99886, 0.0555179),
        (0.99332, 0.0750759),
        (0.96900, 0.1538520),
        (0.86650, 0.3104856),
        (0.55000, 0.5329522),
        (-0.7616, -0.0168980),
    )

    pink = white * 0.5362
    # One-sample-delayed white term (zero for the first sample).
    pink[..., 1:] += white[..., :-1] * 0.115926

    for pole, gain in sections:
        a_coeffs = torch.tensor([1.0, -pole], dtype=white.dtype, device=device)
        b_coeffs = torch.tensor([gain, 0.0], dtype=white.dtype, device=device)
        # clamp=False: lfilter clamps its output to [-1, 1] by default.
        pink = pink + F.lfilter(white, a_coeffs, b_coeffs, clamp=False)

    return pink * 0.11 * 0.1


# =========================
# Load the model once, at import time
# =========================
print("Loading model...")
model = ChatterboxTTS.from_pretrained(device=DEVICE)

# Replace the T3 weights with the fine-tuned checkpoint from MODEL_REPO.
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)
model.t3.to(DEVICE).load_state_dict(state)

# eval() recursively sets ``training = False`` on every submodule.
model.t3.eval()
# Extra imports used by the endpoint helpers defined below.
import threading  # noqa: E402

from fastapi.concurrency import run_in_threadpool  # noqa: E402
from pydantic import Field  # noqa: E402

# Disable dropout as well. In eval mode dropout is already a no-op; zeroing
# ``p`` also keeps it off if t3 is ever switched back to training mode.
for module in model.t3.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = 0

print("Model ready with enhanced settings.")


# =====================================================
# PROMPT STORAGE HELPERS
# =====================================================
def _prompt_path(prompt_name):
    """Return the path of ``<prompt_name>.wav`` inside PROMPT_FOLDER, or None.

    None is returned for empty names and for names containing a directory
    component (e.g. ``../x`` or ``/tmp/x``). Such names would otherwise let
    a client read, overwrite, rename or delete ``.wav`` files outside
    PROMPT_FOLDER.
    """
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return None
    return os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")


def _invalid_prompt_name():
    """Build the 400 response returned for rejected prompt names."""
    return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})


# =====================================================
# 1A. REGISTER BASE64 -> WAV
# =====================================================
class RegisterPromptBase64(BaseModel):
    prompt_name: str
    base64_audio: str


@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Decode base64 audio and store it as ``<prompt_name>.wav``.

    A ``data:<mime>;base64,`` URL prefix is accepted and stripped. An existing
    prompt with the same name is overwritten. Decode/save failures and empty
    payloads return 400.
    """
    path = _prompt_path(data.prompt_name)
    if path is None:
        return _invalid_prompt_name()
    filename = f"{data.prompt_name}.wav"

    payload = data.base64_audio
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    try:
        raw = base64.b64decode(payload)
        if not raw:
            raise ValueError("data audio kosong")
        with open(path, "wb") as f:
            f.write(raw)
        return {"status": "ok", "file": filename}
    except Exception as e:
        return JSONResponse(
            status_code=400, content={"error": f"Gagal decode / simpan audio: {e}"}
        )


# =====================================================
# 1B. REGISTER FILE UPLOAD (form-data)
# =====================================================
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Store an uploaded audio file as ``<name>.wav``.

    When ``name`` is omitted or empty, the upload's filename (without any
    directory part or extension) is used instead.
    """
    upload_stem = os.path.splitext(os.path.basename(prompt.filename or ""))[0]
    prompt_name = name or upload_stem
    save_path = _prompt_path(prompt_name)
    if save_path is None:
        return _invalid_prompt_name()

    try:
        # Read before opening the target, so a failed upload read does not
        # leave an empty/truncated prompt file behind.
        content = await prompt.read()
        with open(save_path, "wb") as f:
            f.write(content)
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal menyimpan prompt: {e}"}
        )


# =====================================================
# LIST PROMPT
# =====================================================
@app.get("/list-prompt")
async def list_prompt():
    """List the ``.wav`` prompts in PROMPT_FOLDER, sorted by filename."""
    try:
        wav_files = sorted(
            f for f in os.listdir(PROMPT_FOLDER) if f.lower().endswith(".wav")
        )
    except OSError as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal membaca folder prompt: {e}"}
        )
    return {
        "count": len(wav_files),
        "prompts": wav_files,
        "prompt_names": [os.path.splitext(f)[0] for f in wav_files],
    }


# =====================================================
# DELETE PROMPT
# =====================================================
class DeletePrompt(BaseModel):
    prompt_name: str


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete ``<prompt_name>.wav`` from PROMPT_FOLDER (404 if it does not exist)."""
    file_path = _prompt_path(data.prompt_name)
    if file_path is None:
        return _invalid_prompt_name()

    try:
        os.remove(file_path)
    except FileNotFoundError:
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal menghapus: {e}"})
    return {"status": "ok", "deleted": f"{data.prompt_name}.wav"}


# =====================================================
# RENAME PROMPT
# =====================================================
class RenamePrompt(BaseModel):
    old_name: str
    new_name: str


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename ``<old_name>.wav`` to ``<new_name>.wav`` inside PROMPT_FOLDER.

    Returns 404 if the old prompt is missing and 400 if the new name is
    invalid or already taken.
    """
    old_path = _prompt_path(data.old_name)
    new_path = _prompt_path(data.new_name)
    if old_path is None or new_path is None:
        return _invalid_prompt_name()

    if not os.path.exists(old_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt lama tidak ditemukan"}
        )
    if os.path.exists(new_path):
        return JSONResponse(
            status_code=400, content={"error": "Nama baru sudah digunakan"}
        )

    try:
        os.rename(old_path, new_path)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Gagal rename: {e}"})
    return {"status": "ok", "from": data.old_name, "to": data.new_name}


# =====================================================
# 2. TTS - shared synthesis helpers
# =====================================================
# Every request uses the single module-level ``model`` and passes its own
# ``audio_prompt_path``, so generation is serialized with this lock.
# NOTE(review): presumably generate() caches the prompt conditioning on the
# model object, which would let concurrent calls mix voices - confirm against
# the installed chatterbox version before relaxing this.
_GENERATE_LOCK = threading.Lock()


def _generate(text, prompt_path, temperature, top_p, repetition_penalty):
    """Run model.generate() under the generation lock (blocking call)."""
    with _GENERATE_LOCK:
        return model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )


def _encode_wav(wav):
    """Post-process a generated waveform and encode it as an in-memory WAV file."""
    wav = enhance_audio(wav.cpu(), model.sr)
    buffer = io.BytesIO()
    ta.save(buffer, wav, model.sr, format="wav")
    buffer.seek(0)
    return buffer


async def _synthesize(text, prompt, temperature, top_p, repetition_penalty, download_name):
    """Shared implementation of /tts and /tts-custom.

    Errors are returned as JSON responses:
      * 400 for an invalid prompt name or empty text,
      * 404 if the prompt does not exist,
      * 500 if generation fails.
    On success the result is a WAV StreamingResponse offered as
    ``download_name``. The blocking model and DSP work runs in the thread
    pool, so the event loop keeps serving other requests (e.g. /health)
    while audio is being generated.
    """
    prompt_path = _prompt_path(prompt)
    if prompt_path is None:
        return _invalid_prompt_name()
    if not os.path.exists(prompt_path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    if not text.strip():
        return JSONResponse(
            status_code=400, content={"error": "Teks tidak boleh kosong"}
        )

    try:
        wav = await run_in_threadpool(
            _generate, text, prompt_path, temperature, top_p, repetition_penalty
        )
    except Exception as e:
        return JSONResponse(
            status_code=500, content={"error": f"Gagal generate audio: {e}"}
        )

    buffer = await run_in_threadpool(_encode_wav, wav)
    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename={download_name}"},
    )


@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    """Synthesize ``text`` in the voice of a registered prompt using the default settings."""
    return await _synthesize(
        text, prompt, TEMPERATURE, TOP_P, REPETITION_PENALTY, "output.wav"
    )


# =====================================================
# 3. TTS with custom parameters (advanced)
# =====================================================
class TTSCustomParams(BaseModel):
    text: str
    prompt: str
    # Bounds are checked up front, so nonsensical sampling values get a 422
    # instead of failing inside the model.
    temperature: float = Field(TEMPERATURE, gt=0)
    top_p: float = Field(TOP_P, gt=0, le=1)
    repetition_penalty: float = Field(REPETITION_PENALTY, gt=0)


@app.post("/tts-custom")
async def tts_custom(params: TTSCustomParams):
    """Like /tts, but with caller-supplied sampling parameters (for experimentation)."""
    return await _synthesize(
        params.text,
        params.prompt,
        params.temperature,
        params.top_p,
        params.repetition_penalty,
        "output_custom.wav",
    )


# =========================
# Root
# =========================
@app.get("/")
async def root():
    """Return a short description of the service."""
    return {
        "message": "Chatterbox TTS API - Enhanced Voice Quality",
        "version": "2.0",
        "improvements": [
            "Optimized temperature & sampling for source similarity",
            "Enhanced audio processing for natural voice",
            "Reduced robotic artifacts",
            "Better emotion & intonation preservation",
            "Analog-style warmth processing",
        ],
    }


# =========================
# Health check
# =========================
@app.get("/health")
async def health():
    """Report liveness plus the device and default sampling settings."""
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": DEVICE,
        "settings": {
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
            "repetition_penalty": REPETITION_PENALTY,
        },
    }


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than a "module:app" import string. With
    # reload disabled this avoids importing this file a second time under
    # another module name, which would load the model twice. It also keeps
    # working if the file is renamed.
    uvicorn.run(app, host="0.0.0.0", port=6007, reload=False)