import asyncio
import base64
import io
import os
import threading
import uuid

import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from pydub import AudioSegment
from safetensors.torch import load_file

# =========================
# CONFIG
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
DEVICE = "cpu"

TEMPERATURE = 0.85
TOP_P = 0.92
REPETITION_PENALTY = 1.15
AUDIO_GAIN_DB = 1.5

PROMPT_FOLDER = "prompt_source"
TEMP_FOLDER = "temp_parts"

os.makedirs(PROMPT_FOLDER, exist_ok=True)
os.makedirs(TEMP_FOLDER, exist_ok=True)

# At most 5 text parts (across all requests) may be queued for / running in
# the executor at once.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created at import time
# binds to the default loop, not uvicorn's loop — confirm the runtime is 3.10+.
semaphore = asyncio.Semaphore(5)

# Serializes access to the single shared model. In chatterbox,
# generate(audio_prompt_path=...) prepares the voice conditioning and stores it
# on the model instance, so two concurrent calls with different prompts could
# mix voices (verify against the installed chatterbox version). Running one
# generation at a time also avoids oversubscribing the CPU.
_model_lock = threading.Lock()

app = FastAPI(title="Chatterbox TTS Server")


# =========================
# TEXT SPLITTER
# =========================
def split_text(text: str, max_len=100):
    """Split *text* into chunks of at most *max_len* characters.

    All runs of whitespace (spaces, tabs, newlines) are first collapsed to a
    single space. Each chunk is then cut at the last space that keeps it
    within *max_len*, so words stay intact. A single word longer than
    *max_len* is hard-cut.

    Args:
        text: input text.
        max_len: maximum chunk length in characters (must be >= 1).

    Returns:
        List of non-empty chunks; an empty list for blank input.

    Raises:
        ValueError: if max_len < 1 (the text could never shrink and the loop
            would spin forever).
    """
    if max_len < 1:
        raise ValueError("max_len must be at least 1")
    parts = []
    text = " ".join(text.split())
    while len(text) > max_len:
        # Search up to and including index max_len: a space exactly there
        # still yields a chunk of length max_len.
        cut_index = text.rfind(" ", 0, max_len + 1)
        if cut_index == -1:
            cut_index = max_len
        parts.append(text[:cut_index].strip())
        text = text[cut_index:].strip()
    if text:
        parts.append(text)
    return parts


# =========================
# Enhance audio
# =========================
def enhance_audio(wav, sr):
    """Post-process a generated waveform for playback.

    The steps, in order:
    1. Peak-normalize to 1.0.
    2. Add very low-level random noise.
    3. Compress 3:1 above 0.7.
    4. High-pass at 80 Hz.
    5. Low-pass at 8 kHz, but only when that is below Nyquist.
    6. Apply AUDIO_GAIN_DB.
    7. Rescale so the peak is 0.93.

    Args:
        wav: waveform tensor from model.generate, already moved to CPU.
            NOTE(review): assumed (channels, samples) because it is passed
            straight to torchaudio.save — confirm.
        sr: sample rate in Hz.

    Returns:
        The processed waveform tensor.
    """
    wav = wav / (wav.abs().max() + 1e-8)

    # Intentional low-level Gaussian noise (std 0.0008); this makes the output
    # non-deterministic.
    noise_level = 0.0008
    wav = wav + torch.randn_like(wav) * noise_level

    # Simple compressor: magnitude above the threshold is divided by ratio.
    threshold = 0.7
    ratio = 3.0
    mask = wav.abs() > threshold
    wav = torch.where(
        mask,
        torch.sign(wav) * (threshold + (wav.abs() - threshold) / ratio),
        wav,
    )

    wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
    # A biquad cutoff at or above Nyquist (sr / 2) is invalid, so skip the
    # low-pass for low sample rates such as 16 kHz.
    lowpass_hz = 8000
    if lowpass_hz < sr / 2:
        wav = F.lowpass_biquad(wav, sr, cutoff_freq=lowpass_hz)

    # NOTE: gain is a plain linear scale, so the peak normalization below
    # effectively cancels it; AUDIO_GAIN_DB does not change the final level.
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.93
    return wav


# =========================
# LOAD MODEL
# =========================
print("Loading model...")
model = ChatterboxTTS.from_pretrained(device=DEVICE)
# Overwrite the T3 weights with the fine-tuned checkpoint from MODEL_REPO.
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)
model.t3.to(DEVICE).load_state_dict(state)
model.t3.eval()
# Belt-and-braces: force the flag directly as well, in case a submodule
# overrides train() and does not propagate eval().
for m in model.t3.modules():
    if hasattr(m, "training"):
        m.training = False
print("Model ready.")


# =========================
# Helper generate per-part
# =========================
def _synthesize_to_file(text, prompt_path, temp_path):
    """Generate, post-process and save one text chunk (blocking; runs in a worker thread)."""
    with _model_lock:
        wav = model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REPETITION_PENALTY,
        )
    # Post-processing and file I/O happen outside the lock so the next chunk
    # can start generating immediately.
    wav = enhance_audio(wav.cpu(), model.sr)
    ta.save(temp_path, wav, model.sr)


async def generate_part(text, prompt_path, part_id):
    """Synthesize *text* with the voice in *prompt_path* to TEMP_FOLDER/<part_id>.wav.

    All blocking work (generation, enhancement, saving) runs in the default
    executor so the event loop stays responsive.

    Returns:
        Path of the written WAV file.
    """
    temp_path = os.path.join(TEMP_FOLDER, f"{part_id}.wav")
    async with semaphore:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, _synthesize_to_file, text, prompt_path, temp_path
        )
    return temp_path


# =========================
# Merge WAV
# =========================
def merge_wav(files):
    """Concatenate the WAV files in *files*, in order, into one AudioSegment.

    Returns an empty AudioSegment when *files* is empty.
    """
    combined = AudioSegment.empty()
    for f in files:
        combined += AudioSegment.from_wav(f)
    return combined
# =====================================================
# 2. TTS MULTI-PART + QUEUE
# =====================================================
def _prompt_path(name):
    """Return the path of prompt *name* inside PROMPT_FOLDER, or None if unsafe.

    Prompt names come straight from clients and are joined into filesystem
    paths used for reading, writing, deleting and renaming. Empty names and
    names containing a path separator, drive colon or NUL are rejected so a
    request cannot reach files outside PROMPT_FOLDER.
    """
    if not name or any(ch in name for ch in ("/", "\\", ":", "\x00")):
        return None
    return os.path.join(PROMPT_FOLDER, f"{name}.wav")


def _render_wav(files):
    """Merge the part files and encode them as an in-memory WAV (blocking)."""
    buf = io.BytesIO()
    merge_wav(files).export(buf, format="wav")
    buf.seek(0)
    return buf


def _remove_part_files(part_ids):
    """Delete TEMP_FOLDER/<part_id>.wav for every id, ignoring missing files."""
    for part_id in part_ids:
        try:
            os.remove(os.path.join(TEMP_FOLDER, f"{part_id}.wav"))
        except FileNotFoundError:
            pass


@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    """Synthesize *text* with the registered voice *prompt* and return one WAV file.

    The text is split into chunks. The chunks are generated through the
    bounded queue in generate_part, then concatenated. Temporary part files
    are always removed, including when generation fails.
    """
    prompt_path = _prompt_path(prompt)
    if prompt_path is None or not os.path.exists(prompt_path):
        return JSONResponse(status_code=404, content={"error": "Prompt tidak ditemukan"})

    parts = split_text(text)
    if not parts:
        return JSONResponse(status_code=400, content={"error": "Teks kosong"})

    # Unique prefix so concurrent requests never share temp file names.
    uid = str(uuid.uuid4())
    part_ids = [f"{uid}_{i}" for i in range(len(parts))]

    try:
        # return_exceptions=True waits for every part to finish even if one
        # fails, so no part is still writing its file when cleanup runs.
        results = await asyncio.gather(
            *(
                generate_part(segment, prompt_path, part_id)
                for segment, part_id in zip(parts, part_ids)
            ),
            return_exceptions=True,
        )
        errors = [r for r in results if isinstance(r, BaseException)]
        if errors:
            return JSONResponse(status_code=500, content={"error": str(errors[0])})

        loop = asyncio.get_running_loop()
        buf = await loop.run_in_executor(None, _render_wav, results)
    finally:
        _remove_part_files(part_ids)

    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=final_output.wav"},
    )


# =====================================================
# Prompt Management
# =====================================================
class RegisterPromptBase64(BaseModel):
    # Name under which the prompt is stored (<prompt_name>.wav).
    prompt_name: str
    # Base64-encoded audio bytes; a "data:...;base64," prefix is accepted.
    base64_audio: str


@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Decode base64 audio and store it as PROMPT_FOLDER/<prompt_name>.wav.

    The decoded bytes are written as-is; no audio-format validation is done.
    """
    path = _prompt_path(data.prompt_name)
    if path is None:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})
    filename = f"{data.prompt_name}.wav"

    payload = data.base64_audio
    # Without this, the lenient decoder would keep the letters of the
    # "data:audio/wav;base64," prefix and corrupt the audio.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    try:
        raw = base64.b64decode(payload)
        with open(path, "wb") as f:
            f.write(raw)
        return {"status": "ok", "file": filename}
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Store an uploaded audio file as PROMPT_FOLDER/<name>.wav.

    When *name* is omitted, the upload's filename without extension is used.
    The bytes are saved unchanged.
    """
    if name:
        prompt_name = name
    else:
        # Client-supplied filenames may include directory components.
        base = os.path.basename((prompt.filename or "").replace("\\", "/"))
        prompt_name = os.path.splitext(base)[0]

    save_path = _prompt_path(prompt_name)
    if save_path is None:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})

    try:
        content = await prompt.read()
        with open(save_path, "wb") as f:
            f.write(content)
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.get("/list-prompt")
async def list_prompt():
    """List the .wav files in PROMPT_FOLDER, with and without their extension."""
    try:
        files = os.listdir(PROMPT_FOLDER)
        wav_files = [f for f in files if f.lower().endswith(".wav")]
        return {
            "count": len(wav_files),
            "prompts": wav_files,
            "prompt_names": [os.path.splitext(f)[0] for f in wav_files],
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


class DeletePrompt(BaseModel):
    # Name of the prompt to delete (without ".wav").
    prompt_name: str


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete PROMPT_FOLDER/<prompt_name>.wav."""
    file_path = _prompt_path(data.prompt_name)
    if file_path is None:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})
    if not os.path.exists(file_path):
        return JSONResponse(status_code=404, content={"error": "Prompt tidak ditemukan"})

    os.remove(file_path)
    return {"status": "ok"}


class RenamePrompt(BaseModel):
    # Current prompt name (without ".wav").
    old_name: str
    # New prompt name (without ".wav").
    new_name: str


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename PROMPT_FOLDER/<old_name>.wav to <new_name>.wav, refusing to overwrite."""
    old_path = _prompt_path(data.old_name)
    new_path = _prompt_path(data.new_name)
    if old_path is None or new_path is None:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})

    if not os.path.exists(old_path):
        return JSONResponse(status_code=404, content={"error": "Prompt lama tidak ditemukan"})

    if os.path.exists(new_path):
        return JSONResponse(status_code=400, content={"error": "Nama baru sudah digunakan"})

    os.rename(old_path, new_path)
    return {"status": "ok"}


@app.get("/")
async def root():
    """Health-check endpoint."""
    return {"message": "Chatterbox TTS API ready with queue + multi-part!"}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than an import string: an import string makes
    # uvicorn import this module a second time (loading the model twice) and
    # breaks if the file is renamed.
    uvicorn.run(app, host="0.0.0.0", port=6007, reload=False)