additions
py/main.py (new file, 456 lines)
@@ -0,0 +1,456 @@
import os
import base64
import uuid
import asyncio
import time
import threading
from typing import Dict
from concurrent.futures import ThreadPoolExecutor

from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from tts_util import TTSConfig, TTSEngine


# ===============================================
# CONFIG
# ===============================================
PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)

JOBS_FOLDER = "jobs"
os.makedirs(JOBS_FOLDER, exist_ok=True)

OUTPUT_FOLDER = "output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Job store
job_store: Dict[str, dict] = {}

# Auto cleanup
JOB_EXPIRE_SECONDS = 600

# FIFO queue and turbo workers
TURBO_MODE = True
TURBO_WORKERS = 3
WORKER_THREADPOOL_MAX = 4

# Rate limiting
RATE_LIMIT_TOKENS = 10
RATE_LIMIT_WINDOW = 60

# Token buckets
token_buckets: Dict[str, dict] = {}
token_lock = threading.Lock()

# Executor and TTS engine
thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREADPOOL_MAX)
config = TTSConfig()
tts_engine = TTSEngine(config, thread_pool)

app = FastAPI(title="Chatterbox TTS Server - Turbo + FIFO + RateLimit + WAV")


# ===============================================
# RATE LIMIT UTILITIES
# ===============================================
def get_client_ip(request: Request) -> str:
    """Get the client IP from the request, honoring X-Forwarded-For."""
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    if request.client:
        return request.client.host
    return "unknown"


def allow_request_ip(ip: str) -> bool:
    """Check whether a request from this IP is allowed (token bucket)."""
    now = time.time()
    with token_lock:
        bucket = token_buckets.get(ip)
        if bucket is None:
            token_buckets[ip] = {"tokens": RATE_LIMIT_TOKENS - 1, "last": now}
            return True

        # Refill tokens proportionally to the time elapsed since the last refill
        elapsed = now - bucket["last"]
        refill = (elapsed / RATE_LIMIT_WINDOW) * RATE_LIMIT_TOKENS
        if refill > 0:
            bucket["tokens"] = min(RATE_LIMIT_TOKENS, bucket["tokens"] + refill)
            bucket["last"] = now

        if bucket["tokens"] >= 1:
            bucket["tokens"] -= 1
            return True
        else:
            return False


# ===============================================
# BACKGROUND WORKERS
# ===============================================
job_queue: asyncio.Queue = asyncio.Queue()


async def worker_loop(worker_id: int):
    """Background worker for processing TTS jobs."""
    while True:
        job_id = await job_queue.get()
        job = job_store.get(job_id)
        if job is None:
            job_queue.task_done()
            continue

        # Mark as processing
        job["status"] = "processing"
        job["worker"] = worker_id
        job["timestamp"] = time.time()

        try:
            prompt = job.get("prompt")
            text = job.get("text")
            prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")

            if not os.path.exists(prompt_path):
                job["status"] = "error"
                job["error"] = "Prompt not found"
                job["timestamp"] = time.time()
                # task_done() is handled by the finally block below;
                # calling it here as well would raise ValueError.
                continue

            # Generate audio
            out_wav = os.path.join(OUTPUT_FOLDER, f"{job_id}.wav")
            await tts_engine.generate_to_file(text, prompt_path, out_wav)

            job["status"] = "done"
            job["result"] = out_wav
            job["timestamp"] = time.time()

        except Exception as e:
            job["status"] = "error"
            job["error"] = str(e)
            job["timestamp"] = time.time()
        finally:
            job_queue.task_done()


async def cleanup_worker():
    """Background cleanup worker for expired jobs."""
    while True:
        now = time.time()
        expired = []

        for jid, job in list(job_store.items()):
            if (
                job.get("status") == "done"
                and now - job.get("timestamp", 0) > JOB_EXPIRE_SECONDS
            ):
                f = job.get("result")
                # File deletion is intentionally disabled; only the job entry expires.
                # if f and os.path.exists(f):
                #     try:
                #         os.remove(f)
                #     except Exception:
                #         pass
                expired.append(jid)

        for jid in expired:
            job_store.pop(jid, None)

        # Purge stale token buckets
        with token_lock:
            stale_ips = []
            for ip, b in token_buckets.items():
                if now - b.get("last", 0) > RATE_LIMIT_WINDOW * 10:
                    stale_ips.append(ip)
            for ip in stale_ips:
                token_buckets.pop(ip, None)

        await asyncio.sleep(30)


# ===============================================
# STARTUP
# ===============================================
@app.on_event("startup")
async def startup():
    """Initialize the TTS engine and start background workers."""
    # Load TTS model
    tts_engine.load_model()

    # Start cleanup worker
    asyncio.create_task(cleanup_worker())

    # Start turbo workers
    worker_count = TURBO_WORKERS if TURBO_MODE else 1
    for i in range(worker_count):
        asyncio.create_task(worker_loop(i + 1))


# ===============================================
# PYDANTIC MODELS
# ===============================================
class RegisterPromptBase64(BaseModel):
    prompt_name: str
    base64_audio: str


class DeletePrompt(BaseModel):
    prompt_name: str


class RenamePrompt(BaseModel):
    old_name: str
    new_name: str


# ===============================================
# PROMPT MANAGEMENT ENDPOINTS
# ===============================================
@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Register a voice prompt from base64 audio"""
    filename = f"{data.prompt_name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        raw = base64.b64decode(data.base64_audio)
        with open(path, "wb") as f:
            f.write(raw)
        return {"status": "ok", "file": filename}
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Register a voice prompt from an uploaded file"""
    prompt_name = name or os.path.splitext(prompt.filename)[0]
    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        with open(save_path, "wb") as f:
            f.write(await prompt.read())
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.get("/list-prompt")
async def list_prompt():
    """List all registered voice prompts"""
    lst = [f for f in os.listdir(PROMPT_FOLDER) if f.lower().endswith(".wav")]
    return {
        "count": len(lst),
        "prompts": lst,
        "prompt_names": [os.path.splitext(f)[0] for f in lst],
    }


@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete a voice prompt"""
    path = os.path.join(PROMPT_FOLDER, f"{data.prompt_name}.wav")
    if not os.path.exists(path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt not found"}
        )
    try:
        os.remove(path)
        return {"status": "ok", "deleted": f"{data.prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename a voice prompt"""
    old = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if not os.path.exists(old):
        return JSONResponse(
            status_code=404, content={"error": "Old prompt not found"}
        )
    if os.path.exists(new):
        return JSONResponse(
            status_code=400, content={"error": "New name is already in use"}
        )
    try:
        os.rename(old, new)
        return {"status": "ok", "from": data.old_name, "to": data.new_name}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/tts-async")
async def tts_async(request: Request, text: str = Form(...), prompt: str = Form(...)):
    """Asynchronous TTS - enqueue a job and return its job_id"""
    client_ip = get_client_ip(request)
    if not allow_request_ip(client_ip):
        return JSONResponse(status_code=429, content={"error": "rate limit exceeded"})

    job_id = str(uuid.uuid4())
    job_store[job_id] = {
        "status": "pending",
        "timestamp": time.time(),
        "prompt": prompt,
        "text": text,
        "client_ip": client_ip,
    }

    # Enqueue (FIFO)
    await job_queue.put(job_id)

    return {"status": "queued", "job_id": job_id, "check": f"/result/{job_id}"}


@app.get("/result/{job_id}")
async def tts_result(job_id: str):
    """Get the result of an async TTS job"""
    job = job_store.get(job_id)
    if not job:
        return JSONResponse(
            status_code=404, content={"error": "Job ID not found"}
        )

    # Still processing
    if job["status"] in ("pending", "processing"):
        return {
            "status": job["status"],
            "job_id": job_id,
            "worker": job.get("worker"),
            "timestamp": job.get("timestamp"),
        }

    # Error
    if job["status"] == "error":
        return job

    # Done - return the file info
    result_path = job.get("result")
    if not result_path or not os.path.exists(result_path):
        return JSONResponse(
            status_code=500,
            content={"status": "error", "error": "Result file not found"},
        )

    return JSONResponse(
        status_code=200,
        content={"status": "done", "job_id": job_id, "file": result_path},
    )


@app.get("/list-file")
async def list_file():
    """List all files inside OUTPUT_FOLDER"""
    try:
        files = [
            f for f in os.listdir(OUTPUT_FOLDER)
            if os.path.isfile(os.path.join(OUTPUT_FOLDER, f))
        ]

        # Only WAV files (matching the engine output)
        wav_files = [f for f in files if f.lower().endswith(".wav")]

        # Include timestamp metadata from job_store
        detailed = []
        for f in wav_files:
            full_path = os.path.join(OUTPUT_FOLDER, f)
            size = os.path.getsize(full_path)

            # Find the related job (if any)
            related_job = None
            for jid, job in job_store.items():
                if job.get("result") == full_path:
                    related_job = {
                        "job_id": jid,
                        "status": job.get("status"),
                        "timestamp": job.get("timestamp"),
                        "prompt": job.get("prompt"),
                    }
                    break

            detailed.append(
                {
                    "file": f,
                    "size_bytes": size,
                    "path": full_path,
                    "job": related_job,
                }
            )

        return {
            "count": len(wav_files),
            "files": detailed,
        }

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


async def iterfile(file_path):
    """Stream a file in 4 KiB chunks for StreamingResponse."""
    with open(file_path, "rb") as f:
        chunk = f.read(4096)
        while chunk:
            yield chunk
            chunk = f.read(4096)


@app.get("/file/{file_name}")
async def get_output_file(file_name: str):
    """Download an output WAV file"""
    if not file_name.endswith(".wav"):
        file_name = f"{file_name}.wav"
    file_path = os.path.join(OUTPUT_FOLDER, file_name)
    if not os.path.exists(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    return StreamingResponse(
        iterfile(file_path),
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename={file_name}"},
    )


# ===============================================
# FILE MANAGEMENT ENDPOINTS
# ===============================================
@app.delete("/rm/{filename}")
async def remove_file(filename: str):
    """Delete a single output file"""
    if not filename.endswith(".wav"):
        filename = f"{filename}.wav"
    path = os.path.join(OUTPUT_FOLDER, filename)

    if not os.path.exists(path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    try:
        os.remove(path)
        # Remove the matching job from job_store
        for jid, job in list(job_store.items()):
            if job.get("result") == path:
                job_store.pop(jid, None)
        return {"status": "ok", "deleted": filename}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/cleanup")
async def manual_cleanup():
    """Manual cleanup - remove all output files and clear done/error jobs"""
    removed = []
    for f in os.listdir(OUTPUT_FOLDER):
        fp = os.path.join(OUTPUT_FOLDER, f)
        if os.path.isfile(fp):
            try:
                os.remove(fp)
                removed.append(f)
            except Exception:
                pass

    # Clear done/error jobs
    cleared = []
    for jid in list(job_store.keys()):
        if job_store[jid]["status"] in ("done", "error"):
            job_store.pop(jid, None)
            cleared.append(jid)

    return {"status": "ok", "removed_files": removed, "jobs_cleared": cleared}
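For reference, a minimal client sketch for the endpoints above (not part of the commit). It assumes the server runs locally on port 8000, the third-party `requests` package is installed, and `voice_sample.wav` is a placeholder path:

# Hypothetical client sketch for the main.py endpoints (not in the commit).
import time
import requests

BASE = "http://localhost:8000"  # assumption: default uvicorn host/port

# Register a voice prompt from a local WAV file
with open("voice_sample.wav", "rb") as f:
    r = requests.post(f"{BASE}/register-prompt-file",
                      files={"prompt": f}, data={"name": "demo"})
    print(r.json())

# Enqueue an async TTS job and poll /result/{job_id}
r = requests.post(f"{BASE}/tts-async",
                  data={"text": "Hello from the TTS server.", "prompt": "demo"})
job = r.json()
print(job)

while True:
    res = requests.get(f"{BASE}{job['check']}").json()
    if res["status"] in ("done", "error"):
        print(res)
        break
    time.sleep(1)

# Download the finished WAV via /file/{file_name}
if res["status"] == "done":
    audio = requests.get(f"{BASE}/file/{job['job_id']}.wav")
    with open("result.wav", "wb") as out:
        out.write(audio.content)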
py/op.py (new file, 322 lines)
@@ -0,0 +1,322 @@
import io
import asyncio
from typing import Optional

import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from concurrent.futures import ThreadPoolExecutor
import multiprocessing


class TTSConfig:
    """Configuration for TTS model and processing"""
    MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT = "t3_cfg.safetensors"
    DEVICE = "cpu"

    # Generation parameters tuned for speed
    TEMPERATURE = 0.7
    TOP_P = 0.9
    REPETITION_PENALTY = 1.1

    # Audio processing
    AUDIO_GAIN_DB = 0.8

    # Performance settings
    USE_QUANTIZATION = True
    USE_TORCH_COMPILE = True
    SIMPLIFY_AUDIO_ENHANCEMENT = True
    ENABLE_CACHING = True


class AudioProcessor:
    """Audio enhancement utilities (optimized)"""

    @staticmethod
    def generate_pink_noise_fast(shape, device):
        """Generate pink-ish noise for audio enhancement (vectorized approximation)"""
        white = torch.randn(shape, device=device)

        # Fast approximation: mix white noise with a low-passed copy
        pink = white * 0.5

        # avg_pool1d expects a [batch, channels, time] tensor
        if white.dim() == 1:
            white_3d = white.unsqueeze(0).unsqueeze(0)
        elif white.dim() == 2:
            white_3d = white.unsqueeze(0)
        else:
            white_3d = white

        # Quick low-pass filtering approximation
        kernel_size = min(3, white_3d.shape[-1])
        if kernel_size >= 2:
            filtered = torch.nn.functional.avg_pool1d(
                white_3d,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
            )
            # Trim any padding overshoot and restore the input shape
            # before mixing, so the add cannot shape-mismatch.
            filtered = filtered[..., : white_3d.shape[-1]].reshape(white.shape)
            pink = pink + filtered * 0.3

        return pink * 0.1

    @staticmethod
    def enhance_audio_fast(wav, sr):
        """Apply audio enhancements with optimized operations"""
        with torch.no_grad():
            # Normalize
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95

            # Apply filters in sequence (no-grad mode for speed)
            wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
            wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
            wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

            # Vectorized compression (faster than a sample loop)
            threshold = 0.6
            ratio = 2.5
            abs_wav = wav.abs()
            mask = abs_wav > threshold
            wav = torch.where(
                mask,
                torch.sign(wav) * (threshold + (abs_wav - threshold) / ratio),
                wav,
            )

            wav = torch.tanh(wav * 1.08)

            # Add pink noise (fast version)
            wav = wav + AudioProcessor.generate_pink_noise_fast(wav.shape, wav.device) * 0.0003
            wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)

            # Final normalization
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.88

            return wav

    @staticmethod
    def enhance_audio_simple(wav, sr):
        """Simplified audio enhancement for maximum speed"""
        with torch.no_grad():
            # Simple normalization and tanh saturation
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95

            # Basic filtering
            wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=8000)

            # Soft clipping
            wav = torch.tanh(wav * 1.1)

            # Final normalization
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.9

            return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Save a torch tensor to a WAV file"""
        # Ensure a float32 CPU tensor
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.type(torch.float32)

        # torchaudio.save requires shape [channels, samples]
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor

        # Save directly as WAV
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Convert a torch tensor to an in-memory WAV buffer"""
        buf = io.BytesIO()
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf


class TTSEngine:
    """Main TTS engine with model management (optimized)"""

    def __init__(self, config: TTSConfig, thread_pool: Optional[ThreadPoolExecutor] = None):
        self.config = config
        self.thread_pool = thread_pool or ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )
        self.model = None
        self.model_lock = asyncio.Lock()
        self.sr = None
        self.audio_prompt_cache = {} if config.ENABLE_CACHING else None

    def load_model(self):
        """Load the TTS model and checkpoint with optimizations"""
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)

        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()

        # Apply dynamic quantization for CPU speed
        if self.config.USE_QUANTIZATION:
            print("Applying dynamic quantization...")
            self.model.t3 = torch.quantization.quantize_dynamic(
                self.model.t3,
                {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
                dtype=torch.qint8
            )

        # Apply torch.compile if available (PyTorch 2.0+)
        if self.config.USE_TORCH_COMPILE and hasattr(torch, 'compile'):
            print("Compiling model with torch.compile...")
            try:
                self.model.t3 = torch.compile(self.model.t3, mode="reduce-overhead")
            except Exception as e:
                print(f"Torch compile failed: {e}, continuing without compilation")

        # Disable dropout for inference
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0

        self.sr = self.model.sr
        print("Model ready (optimized for CPU).")

    def _load_audio_prompt(self, audio_prompt_path: str):
        """Resolve an audio prompt path, with optional caching"""
        if self.config.ENABLE_CACHING and audio_prompt_path in self.audio_prompt_cache:
            return self.audio_prompt_cache[audio_prompt_path]

        # The cache currently only memoizes the path itself;
        # the actual audio loading happens inside model.generate.
        if self.config.ENABLE_CACHING:
            self.audio_prompt_cache[audio_prompt_path] = audio_prompt_path

        return audio_prompt_path

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio from text with a voice prompt"""
        async with self.model_lock:
            # Resolve (and cache) the audio prompt path
            cached_prompt = self._load_audio_prompt(audio_prompt_path)

            def blocking_generate():
                with torch.no_grad():
                    # Use all available cores for CPU inference
                    torch.set_num_threads(multiprocessing.cpu_count())

                    return self.model.generate(
                        text,
                        audio_prompt_path=cached_prompt,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )

            wav = await asyncio.get_event_loop().run_in_executor(
                self.thread_pool,
                blocking_generate
            )
            return wav

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate and enhance audio"""
        wav = await self.generate(text, audio_prompt_path)

        # Choose the enhancement method based on config
        enhance_func = (
            AudioProcessor.enhance_audio_simple
            if self.config.SIMPLIFY_AUDIO_ENHANCEMENT
            else AudioProcessor.enhance_audio_fast
        )

        # Enhance audio (CPU-bound)
        wav = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            lambda: enhance_func(wav.cpu(), self.sr)
        )

        return wav

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate audio and save it to a file"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        # Save to WAV
        await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate audio and return it as a WAV buffer"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        # Convert to buffer
        buffer = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.tensor_to_wav_buffer,
            wav,
            self.sr
        )
        return buffer

    def clear_cache(self):
        """Clear the audio prompt cache"""
        if self.audio_prompt_cache:
            self.audio_prompt_cache.clear()


# Example usage
async def main():
    """Example usage of the optimized TTS engine"""
    config = TTSConfig()
    engine = TTSEngine(config)

    # Load the model once
    engine.load_model()

    # Generate audio
    text = "Halo, ini adalah tes text to speech dalam bahasa Indonesia."
    audio_prompt = "path/to/your/voice_sample.wav"

    # Generate to file
    await engine.generate_to_file(text, audio_prompt, "output.wav")
    print("Audio generated successfully!")

    # Or generate to an in-memory buffer
    buffer = await engine.generate_to_buffer(text, audio_prompt)
    print(f"Audio buffer size: {len(buffer.getvalue())} bytes")


if __name__ == "__main__":
    asyncio.run(main())
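A small, hypothetical micro-benchmark (not in the commit) for comparing the two enhancement paths that SIMPLIFY_AUDIO_ENHANCEMENT toggles between; it assumes op.py and its dependencies are importable from the current directory:

# Hypothetical micro-benchmark: compares enhance_audio_fast and
# enhance_audio_simple on synthetic audio. Assumes op.py is importable.
import time
import torch
from op import AudioProcessor

sr = 24000
wav = torch.randn(sr * 5)  # 5 seconds of noise as a stand-in signal

for name, fn in [
    ("fast", AudioProcessor.enhance_audio_fast),
    ("simple", AudioProcessor.enhance_audio_simple),
]:
    t0 = time.perf_counter()
    out = fn(wav.clone(), sr)
    print(f"{name}: {time.perf_counter() - t0:.3f}s, peak={out.abs().max():.3f}")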
py/tts_util.py (new file, 213 lines)
@@ -0,0 +1,213 @@
import io
import asyncio

import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from concurrent.futures import ThreadPoolExecutor


class TTSConfig:
    """Configuration for TTS model and processing"""
    MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT = "t3_cfg.safetensors"
    DEVICE = "cpu"

    # Generation parameters
    TEMPERATURE = 0.65
    TOP_P = 0.88
    REPETITION_PENALTY = 1.25

    # Audio processing
    AUDIO_GAIN_DB = 0.8


class AudioProcessor:
    """Audio enhancement utilities"""

    @staticmethod
    def generate_pink_noise(shape, device):
        """Generate pink noise for audio enhancement (Paul Kellet filter)"""
        white = torch.randn(shape, device=device)
        pink = torch.zeros_like(white)
        b = torch.zeros(7)

        if len(shape) == 1:
            for j in range(shape[0]):
                w = white[j].item()
                b[0] = 0.99886 * b[0] + w * 0.0555179
                b[1] = 0.99332 * b[1] + w * 0.0750759
                b[2] = 0.96900 * b[2] + w * 0.1538520
                b[3] = 0.86650 * b[3] + w * 0.3104856
                b[4] = 0.55000 * b[4] + w * 0.5329522
                b[5] = -0.7616 * b[5] - w * 0.0168980
                pink[j] = (b[0]+b[1]+b[2]+b[3]+b[4]+b[5]+b[6] + w*0.5362) * 0.11
                b[6] = w * 0.115926
        else:
            for i in range(shape[0]):
                b = torch.zeros(7)  # reset filter state per channel
                for j in range(shape[1]):
                    w = white[i, j].item()
                    b[0] = 0.99886 * b[0] + w * 0.0555179
                    b[1] = 0.99332 * b[1] + w * 0.0750759
                    b[2] = 0.96900 * b[2] + w * 0.1538520
                    b[3] = 0.86650 * b[3] + w * 0.3104856
                    b[4] = 0.55000 * b[4] + w * 0.5329522
                    b[5] = -0.7616 * b[5] - w * 0.0168980
                    pink[i, j] = (b[0]+b[1]+b[2]+b[3]+b[4]+b[5]+b[6] + w*0.5362) * 0.11
                    b[6] = w * 0.115926

        return pink * 0.1

    @staticmethod
    def enhance_audio(wav, sr):
        """Apply audio enhancements: normalization, filtering, compression"""
        # Normalize
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / (peak + 1e-8) * 0.95

        # Apply filters
        wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
        wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
        wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
        wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

        # Compression
        threshold = 0.6
        ratio = 2.5
        abs_wav = wav.abs()
        compressed = wav.clone()
        mask = abs_wav > threshold
        compressed[mask] = torch.sign(wav[mask]) * (threshold + (abs_wav[mask] - threshold) / ratio)

        wav = compressed
        wav = torch.tanh(wav * 1.08)

        # Add pink noise
        wav = wav + AudioProcessor.generate_pink_noise(wav.shape, wav.device) * 0.0003
        wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)

        # Final normalization
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / peak * 0.88

        return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Save a torch tensor to a WAV file"""
        # Ensure a float32 CPU tensor
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.type(torch.float32)

        # torchaudio.save requires shape [channels, samples]
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor

        # Save directly as WAV
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Convert a torch tensor to an in-memory WAV buffer"""
        buf = io.BytesIO()
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf


class TTSEngine:
    """Main TTS engine with model management"""

    def __init__(self, config: TTSConfig, thread_pool: ThreadPoolExecutor):
        self.config = config
        self.thread_pool = thread_pool
        self.model = None
        self.model_lock = asyncio.Lock()
        self.sr = None

    def load_model(self):
        """Load the TTS model and checkpoint"""
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)

        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()

        # Disable dropout
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0

        self.sr = self.model.sr
        print("Model ready.")

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio from text with a voice prompt"""
        async with self.model_lock:
            def blocking_generate():
                with torch.no_grad():
                    return self.model.generate(
                        text,
                        audio_prompt_path=audio_prompt_path,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )

            wav = await asyncio.get_event_loop().run_in_executor(
                self.thread_pool,
                blocking_generate
            )
            return wav

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate and enhance audio"""
        wav = await self.generate(text, audio_prompt_path)

        # Enhance audio (CPU-bound)
        wav = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            lambda: AudioProcessor.enhance_audio(wav.cpu(), self.sr)
        )

        return wav

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate audio and save it to a file"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        # Save to WAV
        await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate audio and return it as a WAV buffer"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        # Convert to buffer
        return AudioProcessor.tensor_to_wav_buffer(wav, self.sr)
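A minimal standalone usage sketch (not in the commit), mirroring how main.py wires tts_util together; the prompt and output paths below are placeholders:

# Hypothetical standalone usage of tts_util's TTSEngine.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from tts_util import TTSConfig, TTSEngine

async def demo():
    engine = TTSEngine(TTSConfig(), ThreadPoolExecutor(max_workers=2))
    engine.load_model()  # downloads the checkpoint on first run
    await engine.generate_to_file(
        "Halo, ini adalah tes text to speech dalam bahasa Indonesia.",
        "prompt_source/demo.wav",   # assumption: a registered voice prompt
        "output/demo_output.wav",
    )

if __name__ == "__main__":
    asyncio.run(demo())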