Files
jenna-tools/xclonev2.py
bipproduction 822b68c10f tambahannya
2025-12-07 09:00:54 +08:00

268 lines
7.3 KiB
Python

import io
import os
import base64
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio as ta
import torchaudio.functional as F
from pydub import AudioSegment
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import asyncio
import uuid
# =========================
# CONFIG
# =========================
# Hugging Face repo id and checkpoint filename; the file is downloaded with
# hf_hub_download() and its weights are loaded into model.t3 further below.
MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
CHECKPOINT = "t3_cfg.safetensors"
# Device passed to ChatterboxTTS.from_pretrained() and to load_file().
DEVICE = "cpu"
# Sampling settings forwarded to model.generate() for every text part.
TEMPERATURE = 0.85
TOP_P = 0.92
REPETITION_PENALTY = 1.15
# Gain (in dB) applied by enhance_audio() before its final peak scaling.
AUDIO_GAIN_DB = 1.5
# PROMPT_FOLDER stores the registered "<name>.wav" audio prompts;
# TEMP_FOLDER stores the per-part WAV files written while serving /tts.
PROMPT_FOLDER = "prompt_source"
TEMP_FOLDER = "temp_parts"
# Both folders are created at import time so the endpoints can assume they exist.
os.makedirs(PROMPT_FOLDER, exist_ok=True)
os.makedirs(TEMP_FOLDER, exist_ok=True)
# Limits generate_part() to at most 5 concurrent runs.
semaphore = asyncio.Semaphore(5)
app = FastAPI(title="Chatterbox TTS Server")
# =========================
# TEXT SPLITTER
# =========================
def split_text(text: str, max_len: int = 100) -> list:
    """Split *text* into chunks of at most *max_len* characters.

    Each chunk is cut at the last space found within the first *max_len*
    characters; when that window contains no space the text is cut hard at
    *max_len*. Chunks are stripped of surrounding whitespace.

    Args:
        text: The text to split. Leading/trailing whitespace is ignored.
        max_len: Maximum chunk length; must be at least 1.

    Returns:
        The list of chunks in order. Empty or whitespace-only input yields
        an empty list.

    Raises:
        ValueError: If *max_len* is smaller than 1. Without this check the
            loop below would never shrink the text and would run forever.
    """
    if max_len < 1:
        raise ValueError("max_len must be a positive integer")

    parts = []
    remaining = text.strip()
    while len(remaining) > max_len:
        cut_index = remaining.rfind(" ", 0, max_len)
        if cut_index == -1:
            # No space inside the window: fall back to a hard cut.
            cut_index = max_len
        parts.append(remaining[:cut_index].strip())
        remaining = remaining[cut_index:].strip()
    if remaining:
        parts.append(remaining)
    return parts
# =========================
# Enhance audio
# =========================
def enhance_audio(wav, sr):
    """Post-process a generated waveform tensor and return the result.

    Steps, in order: peak-normalize, add low-level random noise, compress
    samples whose magnitude exceeds a threshold, high-pass at 80 Hz,
    low-pass at 8000 Hz, apply AUDIO_GAIN_DB of gain, then rescale so the
    peak magnitude is 0.93 (skipped when the signal is all zeros).

    Args:
        wav: Waveform tensor.
        sr: Sample rate passed to the biquad filters.
    """
    noise_level = 0.0008
    threshold = 0.7
    ratio = 3.0

    # The epsilon keeps the division finite for an all-zero signal.
    normalized = wav / (wav.abs().max() + 1e-8)
    noisy = normalized + torch.randn_like(normalized) * noise_level

    # Above the threshold, the excess magnitude is divided by `ratio`.
    magnitude = noisy.abs()
    squeezed = torch.sign(noisy) * (threshold + (magnitude - threshold) / ratio)
    compressed = torch.where(magnitude > threshold, squeezed, noisy)

    high_passed = F.highpass_biquad(compressed, sr, cutoff_freq=80)
    band_limited = F.lowpass_biquad(high_passed, sr, cutoff_freq=8000)
    boosted = F.gain(band_limited, gain_db=AUDIO_GAIN_DB)

    peak = boosted.abs().max().item()
    if peak > 0:
        return boosted / peak * 0.93
    return boosted
# =========================
# LOAD MODEL
# =========================
# Runs once at import time: load the base ChatterboxTTS model, then replace
# the weights of its t3 sub-module with the checkpoint from MODEL_REPO.
print("Loading model...")
model = ChatterboxTTS.from_pretrained(device=DEVICE)
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT)
state = load_file(ckpt, device=DEVICE)
model.t3.to(DEVICE)
model.t3.load_state_dict(state)
model.t3.eval()
# Explicitly clear the training flag on every sub-module as well.
for m in model.t3.modules():
    if not hasattr(m, "training"):
        continue
    m.training = False
print("Model ready.")
# =========================
# Helper generate per-part
# =========================
async def generate_part(text, prompt_path, part_id):
    """Synthesize one text part and write it to a temporary WAV file.

    The whole pipeline (generation, enhancement, saving) runs in the default
    executor so that none of the CPU-bound or blocking file work stalls the
    event loop. The module-level semaphore caps how many parts are processed
    at the same time.

    Args:
        text: The text segment to synthesize.
        prompt_path: Path of the audio prompt passed to model.generate().
        part_id: Unique id used as the temp file's base name.

    Returns:
        The path of the written file, "<TEMP_FOLDER>/<part_id>.wav".
    """
    temp_path = os.path.join(TEMP_FOLDER, f"{part_id}.wav")

    def _render():
        # Executed in a worker thread.
        wav = model.generate(
            text,
            audio_prompt_path=prompt_path,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REPETITION_PENALTY,
        )
        wav = enhance_audio(wav.cpu(), model.sr)
        ta.save(temp_path, wav, model.sr)
        return temp_path

    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _render)
# =========================
# Merge WAV
# =========================
def merge_wav(files):
    """Concatenate the WAV files in *files*, in order, into one AudioSegment.

    An empty *files* iterable yields an empty AudioSegment.
    """
    segments = (AudioSegment.from_wav(path) for path in files)
    # Starting from an empty segment, `+` appends each file in turn.
    return sum(segments, AudioSegment.empty())
# =====================================================
# 2. TTS MULTI-PART + QUEUE
# =====================================================
@app.post("/tts")
async def tts(text: str = Form(...), prompt: str = Form(...)):
    """Synthesize *text* with the registered prompt and return one WAV file.

    The text is split into parts, each part is generated concurrently
    (bounded by the semaphore inside generate_part), and the resulting part
    files are merged into a single WAV that is streamed back.

    Responses:
        200: audio/wav attachment "final_output.wav".
        400: the text contains nothing to synthesize.
        404: the prompt does not exist (or its name is not a plain file name).

    Raises:
        The first exception raised by any part generation, after the temp
        files of the parts that did succeed have been removed.
    """
    not_found = JSONResponse(status_code=404, content={"error": "Prompt tidak ditemukan"})

    # `prompt` is client-controlled: refuse anything with a path component so
    # the lookup cannot escape PROMPT_FOLDER (e.g. "../../somewhere/file").
    if not prompt or os.path.basename(prompt) != prompt:
        return not_found
    prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
    if not os.path.exists(prompt_path):
        return not_found

    parts = split_text(text)
    if not parts:
        # Whitespace-only text would otherwise produce an empty WAV.
        return JSONResponse(status_code=400, content={"error": "Teks kosong"})

    # Unique prefix so concurrent requests never share temp file names.
    uid = str(uuid.uuid4())
    tasks = [
        generate_part(segment, prompt_path, f"{uid}_{i}")
        for i, segment in enumerate(parts)
    ]

    # Collect exceptions instead of failing fast so that every part has
    # finished (and its file is known) before cleanup runs.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    part_files = [r for r in results if isinstance(r, str)]
    try:
        for result in results:
            if isinstance(result, BaseException):
                raise result
        final_audio = merge_wav(part_files)
        buf = io.BytesIO()
        final_audio.export(buf, format="wav")
        buf.seek(0)
    finally:
        # Always remove the temp parts, including when generation or merging failed.
        for path in part_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=final_output.wav"},
    )
# =====================================================
# Prompt Management
# =====================================================
class RegisterPromptBase64(BaseModel):
    """Request body of POST /register-prompt-base64."""

    # Name the prompt is stored under, without the ".wav" extension.
    prompt_name: str
    # Base64-encoded content written to the prompt file.
    base64_audio: str


@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Decode a base64 payload and store it as "<prompt_name>.wav".

    Responses:
        200: {"status": "ok", "file": "<prompt_name>.wav"}.
        400: invalid prompt name, undecodable base64, or a write failure.
    """
    # The name is client-controlled: refuse anything with a path component so
    # the write cannot land outside PROMPT_FOLDER.
    if not data.prompt_name or os.path.basename(data.prompt_name) != data.prompt_name:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})

    filename = f"{data.prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        # Decode first so a bad payload never creates or truncates a file.
        raw = base64.b64decode(data.base64_audio)
        with open(path, "wb") as f:
            f.write(raw)
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError subclass.
        return JSONResponse(status_code=400, content={"error": str(e)})
    return {"status": "ok", "file": filename}
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Store an uploaded file as "<name>.wav" in PROMPT_FOLDER.

    When *name* is not given, the upload's own file name (directory part and
    extension removed) is used instead.

    Responses:
        200: {"status": "ok", "file": "<name>.wav"}.
        400: no usable prompt name could be determined.
        500: the file could not be written.
    """
    if name:
        prompt_name = name
    else:
        # The upload may carry no file name at all, or one that includes a
        # client-side directory; keep only the bare stem.
        prompt_name = os.path.splitext(os.path.basename(prompt.filename or ""))[0]

    # Refuse names with a path component so the write stays inside PROMPT_FOLDER.
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(status_code=400, content={"error": "Nama prompt tidak valid"})

    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    # Read the upload before opening the target so a failed read cannot
    # leave an existing prompt truncated.
    content = await prompt.read()
    try:
        with open(save_path, "wb") as f:
            f.write(content)
    except OSError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"status": "ok", "file": f"{prompt_name}.wav"}
@app.get("/list-prompt")
async def list_prompt():
    """List the ".wav" files in PROMPT_FOLDER.

    Returns the number of files, their file names, and the same names with
    the extension removed. Any failure is reported as a 500 JSON error.
    """
    try:
        wav_files = []
        prompt_names = []
        for entry in os.listdir(PROMPT_FOLDER):
            # Extension match is case-insensitive.
            if not entry.lower().endswith(".wav"):
                continue
            wav_files.append(entry)
            prompt_names.append(os.path.splitext(entry)[0])
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    return {
        "count": len(wav_files),
        "prompts": wav_files,
        "prompt_names": prompt_names,
    }
class DeletePrompt(BaseModel):
    """Request body of POST /delete-prompt."""

    # Name of the prompt to delete, without the ".wav" extension.
    prompt_name: str


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete "<prompt_name>.wav" from PROMPT_FOLDER.

    Responses:
        200: {"status": "ok"}.
        404: the prompt does not exist (or its name is not a plain file name).
    """
    not_found = JSONResponse(status_code=404, content={"error": "Prompt tidak ditemukan"})

    # The name is client-controlled: refuse anything with a path component so
    # files outside PROMPT_FOLDER can never be deleted.
    if not data.prompt_name or os.path.basename(data.prompt_name) != data.prompt_name:
        return not_found

    file_path = os.path.join(PROMPT_FOLDER, f"{data.prompt_name}.wav")
    try:
        # Remove directly instead of checking first, so a file that vanishes
        # between check and removal still yields a clean 404.
        os.remove(file_path)
    except FileNotFoundError:
        return not_found
    return {"status": "ok"}
class RenamePrompt(BaseModel):
    """Request body of POST /rename-prompt."""

    # Current and desired prompt names, both without the ".wav" extension.
    old_name: str
    new_name: str


def _is_plain_prompt_name(name):
    """Return True when *name* is non-empty and has no path component."""
    return bool(name) and os.path.basename(name) == name


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename "<old_name>.wav" to "<new_name>.wav" inside PROMPT_FOLDER.

    Responses:
        200: {"status": "ok"}.
        400: the new name is invalid or already in use.
        404: the old prompt does not exist (or its name is not a plain file name).
    """
    # Both names are client-controlled; neither may point outside PROMPT_FOLDER.
    old_path = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    if not _is_plain_prompt_name(data.old_name) or not os.path.exists(old_path):
        return JSONResponse(status_code=404, content={"error": "Prompt lama tidak ditemukan"})

    if not _is_plain_prompt_name(data.new_name):
        return JSONResponse(status_code=400, content={"error": "Nama baru tidak valid"})
    new_path = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if os.path.exists(new_path):
        return JSONResponse(status_code=400, content={"error": "Nama baru sudah digunakan"})

    os.rename(old_path, new_path)
    return {"status": "ok"}
@app.get("/")
async def root():
    """Return a static message confirming the API is up."""
    return dict(message="Chatterbox TTS API ready with queue + multi-part!")
if __name__ == "__main__":
    import uvicorn

    # Hand uvicorn the app object rather than a "module:app" import string.
    # An import string only works while the file keeps that exact module
    # name, and it makes uvicorn import the module a second time, which
    # would run the model loading above twice. With reload disabled, an
    # app instance is accepted.
    uvicorn.run(app, host="0.0.0.0", port=6007, reload=False)