import io
import os
import base64
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

import torch
import torchaudio as ta
import torchaudio.functional as F
from pydub import AudioSegment

from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

import asyncio
import uuid

# =========================
# CONFIG
# =========================
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
DEVICE = "cpu"

# Sampling parameters for generation
TEMPERATURE = 0.85
TOP_P = 0.92
REPETITION_PENALTY = 1.15
AUDIO_GAIN_DB = 1.5

PROMPT_FOLDER = "prompt_source"
TEMP_FOLDER = "temp_parts"
os.makedirs(PROMPT_FOLDER, exist_ok=True)
os.makedirs(TEMP_FOLDER, exist_ok=True)

# Run at most 5 generation jobs in parallel
semaphore = asyncio.Semaphore(5)

app = FastAPI(title="Chatterbox TTS Server")

# =========================
# TEXT SPLITTER
# =========================
def split_text(text: str, max_len=100):
    """Split text into chunks of at most max_len characters, cutting at the last space."""
    parts = []
    text = text.strip()

    while len(text) > max_len:
        cut_index = text.rfind(" ", 0, max_len)
        if cut_index == -1:
            cut_index = max_len
        parts.append(text[:cut_index].strip())
        text = text[cut_index:].strip()

    if text:
        parts.append(text)

    return parts

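# Illustrative example (not part of the original script): a 131-character input is cut
# at the last space before the 100-character limit, yielding two trimmed chunks:
#   split_text("a" * 50 + " " + "b" * 80, max_len=100)
#   -> ["aaa...a" (50 chars), "bbb...b" (80 chars)]
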
# =========================
# Enhance audio
# =========================
def enhance_audio(wav, sr):
    # Normalize, then add a very small amount of noise (dither)
    wav = wav / (wav.abs().max() + 1e-8)
    noise_level = 0.0008
    wav = wav + torch.randn_like(wav) * noise_level

    # Simple soft compression: reduce anything above the threshold by the given ratio
    threshold = 0.7
    ratio = 3.0
    mask = wav.abs() > threshold
    wav = torch.where(
        mask, torch.sign(wav) * (threshold + (wav.abs() - threshold) / ratio), wav
    )

    # Band-pass filtering and output gain
    wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
    wav = F.lowpass_biquad(wav, sr, cutoff_freq=8000)
    wav = F.gain(wav, gain_db=AUDIO_GAIN_DB)

    # Final peak normalization with a little headroom
    peak = wav.abs().max().item()
    if peak > 0:
        wav = wav / peak * 0.93

    return wav

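# Worked example (illustrative): with threshold=0.7 and ratio=3.0, a full-scale sample
# of 1.0 is mapped to 0.7 + (1.0 - 0.7) / 3.0 = 0.8, so only peaks above 0.7 are softly
# compressed before the filters and the final normalization to 0.93.
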
# =========================
# LOAD MODEL
# =========================
print("Loading model...")

# Load the base Chatterbox model, then swap in the fine-tuned T3 weights from the Hub
model = ChatterboxTTS.from_pretrained(device=DEVICE)
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)

model.t3.to(DEVICE).load_state_dict(state)
model.t3.eval()
for m in model.t3.modules():
    if hasattr(m, "training"):
        m.training = False

print("Model ready.")

# =========================
# Helper: generate one part
# =========================
async def generate_part(text, prompt_path, part_id):
    async with semaphore:
        # model.generate is blocking, so run it in the default thread pool executor
        loop = asyncio.get_running_loop()
        wav = await loop.run_in_executor(
            None,
            lambda: model.generate(
                text,
                audio_prompt_path=prompt_path,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                repetition_penalty=REPETITION_PENALTY,
            ),
        )

        wav = enhance_audio(wav.cpu(), model.sr)

        temp_path = os.path.join(TEMP_FOLDER, f"{part_id}.wav")
        ta.save(temp_path, wav, model.sr)

        return temp_path

# =========================
# Merge WAV
# =========================
def merge_wav(files):
    combined = AudioSegment.empty()
    for f in files:
        combined += AudioSegment.from_wav(f)
    return combined

# =====================================================
# 2. TTS MULTI-PART + QUEUE
# =====================================================
@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
    if not os.path.exists(prompt_path):
        return JSONResponse(status_code=404, content={"error": "Prompt not found"})

    # split text
    parts = split_text(text)

    # generate unique prefix for temp file parts
    uid = str(uuid.uuid4())

    # create async tasks
    tasks = []
    for i, segment in enumerate(parts):
        part_id = f"{uid}_{i}"
        tasks.append(generate_part(segment, prompt_path, part_id))

    # run queue (max 5 at once)
    part_files = await asyncio.gather(*tasks)

    # merge result
    final_audio = merge_wav(part_files)

    # convert to in-memory buffer
    buf = io.BytesIO()
    final_audio.export(buf, format="wav")
    buf.seek(0)

    # cleanup temp part files
    for f in part_files:
        if os.path.exists(f):
            os.remove(f)

    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=final_output.wav"},
    )

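# Example request (illustrative, not part of the original script). Assuming the server
# runs on localhost:6007 as configured at the bottom of this file and a prompt named
# "speaker1" (hypothetical) has already been registered:
#
#   curl -X POST http://localhost:6007/tts \
#        -F "text=Selamat pagi, ini adalah contoh kalimat yang cukup panjang." \
#        -F "prompt=speaker1" \
#        -o final_output.wav
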
# =====================================================
# Prompt Management (unchanged)
# =====================================================

class RegisterPromptBase64(BaseModel):
    prompt_name: str
    base64_audio: str


@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    filename = f"{data.prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        raw = base64.b64decode(data.base64_audio)
        with open(path, "wb") as f:
            f.write(raw)
        return {"status": "ok", "file": filename}
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

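# Example client-side call (illustrative, not part of the original script). The endpoint
# expects the raw WAV bytes encoded as a base64 string; "my_voice" is a hypothetical
# prompt name and the `requests` library is assumed to be available:
#
#   import base64, requests
#   with open("my_voice.wav", "rb") as f:
#       payload = {"prompt_name": "my_voice",
#                  "base64_audio": base64.b64encode(f.read()).decode()}
#   requests.post("http://localhost:6007/register-prompt-base64", json=payload)
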
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    prompt_name = name if name else os.path.splitext(prompt.filename)[0]
    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        with open(save_path, "wb") as f:
            f.write(await prompt.read())
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.get("/list-prompt")
async def list_prompt():
    try:
        files = os.listdir(PROMPT_FOLDER)
        wav_files = [f for f in files if f.lower().endswith(".wav")]
        return {
            "count": len(wav_files),
            "prompts": wav_files,
            "prompt_names": [os.path.splitext(f)[0] for f in wav_files],
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

class DeletePrompt(BaseModel):
    prompt_name: str


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    file_path = os.path.join(PROMPT_FOLDER, f"{data.prompt_name}.wav")
    if not os.path.exists(file_path):
        return JSONResponse(status_code=404, content={"error": "Prompt not found"})
    os.remove(file_path)
    return {"status": "ok"}

class RenamePrompt(BaseModel):
    old_name: str
    new_name: str


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    old_path = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new_path = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if not os.path.exists(old_path):
        return JSONResponse(status_code=404, content={"error": "Old prompt not found"})
    if os.path.exists(new_path):
        return JSONResponse(status_code=400, content={"error": "New name already in use"})
    os.rename(old_path, new_path)
    return {"status": "ok"}

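# Example prompt-management calls (illustrative; "old" and "new" are hypothetical names):
#
#   curl -X POST http://localhost:6007/delete-prompt \
#        -H "Content-Type: application/json" -d '{"prompt_name": "old"}'
#   curl -X POST http://localhost:6007/rename-prompt \
#        -H "Content-Type: application/json" -d '{"old_name": "old", "new_name": "new"}'
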
@app.get("/")
async def root():
    return {"message": "Chatterbox TTS API ready with queue + multi-part!"}

if __name__ == "__main__":
    import uvicorn

    # NOTE: the import string must match this file's module name (here claude_clonev4.py)
    uvicorn.run("claude_clonev4:app", host="0.0.0.0", port=6007, reload=False)