tambahannya

This commit is contained in:
bipproduction
2025-12-07 09:00:54 +08:00
commit 822b68c10f
89 changed files with 16999 additions and 0 deletions

456
py/main.py Normal file
View File

@@ -0,0 +1,456 @@
import os
import base64
import uuid
import asyncio
import time
import threading
from typing import Dict
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from tts_util import TTSConfig, TTSEngine
# ===============================================
# CONFIG
# ===============================================
# Folder where registered voice-prompt WAVs are stored.
PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)
JOBS_FOLDER = "jobs"
os.makedirs(JOBS_FOLDER, exist_ok=True)
# Folder where generated WAV results are written.
OUTPUT_FOLDER = "output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# In-memory job store: job_id -> job dict (status/timestamps/result/...).
job_store: Dict[str, dict] = {}
# Auto cleanup: finished jobs older than this many seconds are purged.
JOB_EXPIRE_SECONDS = 600
# FIFO queue and turbo workers (TURBO_WORKERS concurrent consumers when on).
TURBO_MODE = True
TURBO_WORKERS = 3
WORKER_THREADPOOL_MAX = 4
# Rate limiting: RATE_LIMIT_TOKENS requests per RATE_LIMIT_WINDOW seconds per IP.
RATE_LIMIT_TOKENS = 10
RATE_LIMIT_WINDOW = 60
# Token buckets (ip -> {"tokens": float, "last": float}), guarded by token_lock.
token_buckets: Dict[str, dict] = {}
token_lock = threading.Lock()
# Executor and TTS engine shared by all workers.
thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREADPOOL_MAX)
config = TTSConfig()
tts_engine = TTSEngine(config, thread_pool)
app = FastAPI(title="Chatterbox TTS Server - Turbo + FIFO + RateLimit + WAV")
# ===============================================
# RATE LIMIT UTILITIES
# ===============================================
def get_client_ip(request: Request) -> str:
    """Resolve the originating client IP for *request*.

    Prefers the first hop of the X-Forwarded-For header (set by proxies),
    then the direct socket peer, and finally the literal "unknown".
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        first_hop = forwarded.split(",")[0]
        return first_hop.strip()
    client = request.client
    if client:
        return client.host
    return "unknown"
def allow_request_ip(ip: str) -> bool:
    """Token-bucket rate limiter: True if a request from *ip* may proceed."""
    now = time.time()
    with token_lock:
        bucket = token_buckets.get(ip)
        if bucket is None:
            # First request from this IP: new bucket with one token spent.
            token_buckets[ip] = {"tokens": RATE_LIMIT_TOKENS - 1, "last": now}
            return True
        # Refill proportionally to the time elapsed since the last refill.
        refill = ((now - bucket["last"]) / RATE_LIMIT_WINDOW) * RATE_LIMIT_TOKENS
        if refill > 0:
            bucket["tokens"] = min(RATE_LIMIT_TOKENS, bucket["tokens"] + refill)
            bucket["last"] = now
        if bucket["tokens"] < 1:
            return False
        bucket["tokens"] -= 1
        return True
# ===============================================
# BACKGROUND WORKERS
# ===============================================
# FIFO queue shared by all turbo workers (jobs are processed in arrival order).
job_queue: asyncio.Queue = asyncio.Queue()


async def worker_loop(worker_id: int):
    """Background worker for processing TTS jobs.

    Pulls job ids off the FIFO queue and synthesizes audio for each. Each
    dequeued item is acknowledged exactly once in the ``finally`` clause —
    the previous version also called ``task_done()`` right before
    ``continue`` on the not-found paths, so ``finally`` acknowledged a
    second time and raised ``ValueError: task_done() called too many
    times``, killing the worker.
    """
    while True:
        job_id = await job_queue.get()
        job = job_store.get(job_id)
        try:
            if job is None:
                # Job record was cleaned up before we got to it.
                continue
            # Mark as processing
            job["status"] = "processing"
            job["worker"] = worker_id
            job["timestamp"] = time.time()
            prompt = job.get("prompt")
            text = job.get("text")
            prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
            if not os.path.exists(prompt_path):
                job["status"] = "error"
                job["error"] = "Prompt tidak ditemukan"
                job["timestamp"] = time.time()
                continue
            # Generate audio
            out_wav = os.path.join(OUTPUT_FOLDER, f"{job_id}.wav")
            await tts_engine.generate_to_file(text, prompt_path, out_wav)
            job["status"] = "done"
            job["result"] = out_wav
            job["timestamp"] = time.time()
        except Exception as e:
            if job is not None:
                job["status"] = "error"
                job["error"] = str(e)
                job["timestamp"] = time.time()
        finally:
            # Exactly one acknowledgement per queue item.
            job_queue.task_done()
async def cleanup_worker():
    """Periodically purge expired finished jobs and stale rate-limit buckets.

    Runs forever, sleeping 30 seconds between sweeps. Removed: an unused
    local (`f = job.get("result")`) left over from commented-out file
    deletion code.
    """
    while True:
        now = time.time()
        # Jobs that finished more than JOB_EXPIRE_SECONDS ago.
        expired = [
            jid
            for jid, job in list(job_store.items())
            if job.get("status") == "done"
            and now - job.get("timestamp", 0) > JOB_EXPIRE_SECONDS
        ]
        # NOTE: the output WAV referenced by job["result"] is deliberately
        # left on disk; only the in-memory job record is dropped.
        for jid in expired:
            job_store.pop(jid, None)
        # Purge token buckets that have been idle for a long time.
        with token_lock:
            stale_ips = [
                ip
                for ip, b in token_buckets.items()
                if now - b.get("last", 0) > RATE_LIMIT_WINDOW * 10
            ]
            for ip in stale_ips:
                token_buckets.pop(ip, None)
        await asyncio.sleep(30)
# ===============================================
# STARTUP
# ===============================================
@app.on_event("startup")
async def startup():
    """Initialize TTS engine and start background workers"""
    # Load TTS model (blocking; runs once before the server accepts traffic)
    tts_engine.load_model()
    # Start cleanup worker
    asyncio.create_task(cleanup_worker())
    # Start turbo workers (several concurrent consumers of the FIFO queue,
    # or a single one when turbo mode is off)
    worker_count = TURBO_WORKERS if TURBO_MODE else 1
    for i in range(worker_count):
        asyncio.create_task(worker_loop(i + 1))
# ===============================================
# PYDANTIC MODELS
# ===============================================
class RegisterPromptBase64(BaseModel):
    # Name under which the prompt WAV will be stored (without extension).
    prompt_name: str
    # Base64-encoded audio payload.
    base64_audio: str


class DeletePrompt(BaseModel):
    # Name of the prompt to delete (without ".wav").
    prompt_name: str


class RenamePrompt(BaseModel):
    # Current prompt name (without ".wav").
    old_name: str
    # Desired new prompt name (without ".wav").
    new_name: str
# ===============================================
# PROMPT MANAGEMENT ENDPOINTS
# ===============================================
@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
"""Register a voice prompt from base64 audio"""
filename = f"{data.prompt_name}.wav"
path = os.path.join(PROMPT_FOLDER, filename)
try:
raw = base64.b64decode(data.base64_audio)
with open(path, "wb") as f:
f.write(raw)
return {"status": "ok", "file": filename}
except Exception as e:
return JSONResponse(status_code=400, content={"error": str(e)})
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
"""Register a voice prompt from uploaded file"""
prompt_name = name or os.path.splitext(prompt.filename)[0]
save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
try:
with open(save_path, "wb") as f:
f.write(await prompt.read())
return {"status": "ok", "file": f"{prompt_name}.wav"}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/list-prompt")
async def list_prompt():
"""List all registered voice prompts"""
lst = [f for f in os.listdir(PROMPT_FOLDER) if f.lower().endswith(".wav")]
return {
"count": len(lst),
"prompts": lst,
"prompt_names": [os.path.splitext(f)[0] for f in lst],
}
@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
"""Delete a voice prompt"""
path = os.path.join(PROMPT_FOLDER, f"{data.prompt_name}.wav")
if not os.path.exists(path):
return JSONResponse(
status_code=404, content={"error": "Prompt tidak ditemukan"}
)
try:
os.remove(path)
return {"status": "ok", "deleted": f"{data.prompt_name}.wav"}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
"""Rename a voice prompt"""
old = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
new = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
if not os.path.exists(old):
return JSONResponse(
status_code=404, content={"error": "Prompt lama tidak ditemukan"}
)
if os.path.exists(new):
return JSONResponse(
status_code=400, content={"error": "Nama baru sudah digunakan"}
)
try:
os.rename(old, new)
return {"status": "ok", "from": data.old_name, "to": data.new_name}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/tts-async")
async def tts_async(request: Request, text: str = Form(...), prompt: str = Form(...)):
"""Asynchronous TTS - enqueue job and return job_id"""
client_ip = get_client_ip(request)
if not allow_request_ip(client_ip):
return JSONResponse(status_code=429, content={"error": "rate limit exceeded"})
job_id = str(uuid.uuid4())
job_store[job_id] = {
"status": "pending",
"timestamp": time.time(),
"prompt": prompt,
"text": text,
"client_ip": client_ip,
}
# Enqueue (FIFO)
await job_queue.put(job_id)
return {"status": "queued", "job_id": job_id, "check": f"/result/{job_id}"}
@app.get("/result/{job_id}")
async def tts_result(job_id: str):
"""Get result of async TTS job"""
job = job_store.get(job_id)
if not job:
return JSONResponse(
status_code=404, content={"error": "Job ID tidak ditemukan"}
)
# Still processing
if job["status"] in ("pending", "processing"):
return {
"status": job["status"],
"job_id": job_id,
"worker": job.get("worker"),
"timestamp": job.get("timestamp"),
}
# Error
if job["status"] == "error":
return job
# Done - return file
result_path = job.get("result")
if not result_path or not os.path.exists(result_path):
return JSONResponse(
status_code=500,
content={"status": "error", "error": "File hasil tidak ditemukan"},
)
return JSONResponse(
status_code=200,
content={"status": "done", "job_id": job_id, "file": result_path},
)
@app.get("/list-file")
async def list_file():
"""List all files inside OUTPUT_FOLDER"""
try:
files = [
f for f in os.listdir(OUTPUT_FOLDER)
if os.path.isfile(os.path.join(OUTPUT_FOLDER, f))
]
# Hanya file WAV (sesuai output engine)
wav_files = [f for f in files if f.lower().endswith(".wav")]
# Include metadata timestamp dari job_store
detailed = []
for f in wav_files:
full_path = os.path.join(OUTPUT_FOLDER, f)
size = os.path.getsize(full_path)
# Cari job yang terkait (jika ada)
related_job = None
for jid, job in job_store.items():
if job.get("result") == full_path:
related_job = {
"job_id": jid,
"status": job.get("status"),
"timestamp": job.get("timestamp"),
"prompt": job.get("prompt"),
}
break
detailed.append(
{
"file": f,
"size_bytes": size,
"path": full_path,
"job": related_job,
}
)
return {
"count": len(wav_files),
"files": detailed,
}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
async def iterfile(file_path):
    """Yield the file's bytes in 4 KiB chunks (for StreamingResponse)."""
    with open(file_path, "rb") as fh:
        while chunk := fh.read(4096):
            yield chunk
@app.get("/file/{file_name}")
async def get_output_file(file_name: str):
if not file_name.endswith(".wav"):
file_name = f"{file_name}.wav"
file_path = os.path.join(OUTPUT_FOLDER, file_name)
if not os.path.exists(file_path):
return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})
return StreamingResponse(
iterfile(file_path),
media_type="audio/wav",
headers={"Content-Disposition": f"attachment; filename={file_name}"},
)
# ===============================================
# FILE MANAGEMENT ENDPOINTS
# ===============================================
@app.delete("/rm/{filename}")
async def remove_file(filename: str):
    """Delete a single output file and drop any job record pointing at it.

    The route template and the extension fallback had been corrupted to a
    literal "(unknown)" placeholder; both now use the `filename` path
    parameter. Names containing path separators are rejected to prevent
    deleting files outside OUTPUT_FOLDER.
    """
    if not filename or os.path.basename(filename) != filename:
        return JSONResponse(
            status_code=400, content={"error": "Nama file tidak valid"}
        )
    # Allow callers to omit the ".wav" extension.
    if not filename.endswith(".wav"):
        filename = f"{filename}.wav"
    path = os.path.join(OUTPUT_FOLDER, filename)
    if not os.path.exists(path):
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})
    try:
        os.remove(path)
        # Remove job records whose result pointed at the removed file.
        for jid, job in list(job_store.items()):
            if job.get("result") == path:
                job_store.pop(jid, None)
        return {"status": "ok", "deleted": filename}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/cleanup")
async def manual_cleanup():
"""Manual cleanup - remove all output files and clear done/error jobs"""
removed = []
for f in os.listdir(OUTPUT_FOLDER):
fp = os.path.join(OUTPUT_FOLDER, f)
if os.path.isfile(fp):
try:
os.remove(fp)
removed.append(f)
except:
pass
# Clear done/error jobs
cleared = []
for jid in list(job_store.keys()):
if job_store[jid]["status"] in ("done", "error"):
job_store.pop(jid, None)
cleared.append(jid)
return {"status": "ok", "removed_files": removed, "jobs_cleared": cleared}

322
py/op.py Normal file
View File

@@ -0,0 +1,322 @@
import io
import os
import asyncio
from typing import Optional
from functools import lru_cache
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
class TTSConfig:
    """Configuration for TTS model and processing"""
    # HuggingFace repo and checkpoint for the Indonesian Chatterbox model.
    MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT = "t3_cfg.safetensors"
    # Inference device.
    DEVICE = "cpu"
    # Optimized generation parameters for speed
    TEMPERATURE = 0.7
    TOP_P = 0.9
    REPETITION_PENALTY = 1.1
    # Audio processing: post-gain in dB applied during enhancement.
    AUDIO_GAIN_DB = 0.8
    # Performance settings
    # Dynamic int8 quantization of the t3 module for faster CPU inference.
    USE_QUANTIZATION = True
    # Compile t3 with torch.compile when available (PyTorch 2.0+).
    USE_TORCH_COMPILE = True
    # Use the cheap enhancement path instead of the full filter chain.
    SIMPLIFY_AUDIO_ENHANCEMENT = True
    # Remember audio prompt paths in TTSEngine._load_audio_prompt.
    ENABLE_CACHING = True
class AudioProcessor:
    """Audio enhancement utilities (optimized)."""

    @staticmethod
    def generate_pink_noise_fast(shape, device):
        """Generate pink-ish noise for audio enhancement (vectorized).

        Bug fix: for 1-D input the pooled tensor was squeezed only once,
        leaving shape (1, L), which cannot be added in place to the (L,)
        accumulator (RuntimeError); it is now squeezed back to 1-D. The
        pooled output is also trimmed to the input length so an even
        kernel size (length-2 inputs) cannot cause an off-by-one shape
        mismatch.
        """
        white = torch.randn(shape, device=device)
        # Fast approximation: scaled white noise plus a low-passed copy.
        pink = white * 0.5
        # Reshape to (batch, channels, samples) for avg_pool1d.
        if white.dim() == 1:
            white_2d = white.unsqueeze(0).unsqueeze(0)
        else:
            white_2d = white.unsqueeze(0) if white.dim() == 2 else white
        # Quick low-pass filtering approximation.
        kernel_size = min(3, white_2d.shape[-1])
        if kernel_size >= 2:
            filtered = torch.nn.functional.avg_pool1d(
                white_2d,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
            )
            # padding=k//2 with an even kernel yields one extra sample; trim.
            filtered = filtered[..., : white_2d.shape[-1]]
            if white.dim() == 1:
                # Squeeze both added dims so the shape matches `pink`.
                pink += filtered.squeeze(0).squeeze(0) * 0.3
            else:
                pink += filtered.squeeze(0)
        return pink * 0.1

    @staticmethod
    def enhance_audio_fast(wav, sr):
        """Apply the full enhancement chain (normalize, EQ, compress, noise)."""
        with torch.no_grad():
            # Peak-normalize to 0.95 full scale.
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95
            # EQ chain: cut rumble/hiss, then gentle bass boost / treble cut.
            wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
            wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
            wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)
            # Vectorized compression: samples above the threshold are scaled
            # down by `ratio` (faster than a per-sample loop).
            threshold = 0.6
            ratio = 2.5
            abs_wav = wav.abs()
            mask = abs_wav > threshold
            wav = torch.where(
                mask,
                torch.sign(wav) * (threshold + (abs_wav - threshold) / ratio),
                wav,
            )
            # Soft clipping.
            wav = torch.tanh(wav * 1.08)
            # Very low-level pink noise (fast version).
            wav = wav + AudioProcessor.generate_pink_noise_fast(wav.shape, wav.device) * 0.0003
            wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)
            # Final normalization with headroom.
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.88
            return wav

    @staticmethod
    def enhance_audio_simple(wav, sr):
        """Simplified audio enhancement for maximum speed."""
        with torch.no_grad():
            # Simple normalization.
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95
            # Basic band-limiting.
            wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=8000)
            # Soft clipping.
            wav = torch.tanh(wav * 1.1)
            # Final normalization.
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.9
            return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Save a torch tensor to a WAV file."""
        # Ensure float32 CPU tensor.
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.type(torch.float32)
        # torchaudio.save requires shape [channels, samples].
        wav_out = wav_tensor.unsqueeze(0) if wav_tensor.dim() == 1 else wav_tensor
        # Save directly as WAV.
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Encode a torch tensor as WAV into an in-memory buffer."""
        buf = io.BytesIO()
        # torchaudio.save requires shape [channels, samples].
        wav_out = wav_tensor.unsqueeze(0) if wav_tensor.dim() == 1 else wav_tensor
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf
class TTSEngine:
    """Main TTS engine with model management (optimized)"""

    def __init__(self, config: TTSConfig, thread_pool: Optional[ThreadPoolExecutor] = None):
        # Engine configuration (model repo, sampling params, feature flags).
        self.config = config
        # Executor for blocking work; defaults to one worker per CPU core.
        self.thread_pool = thread_pool or ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )
        # Populated by load_model(); None until then.
        self.model = None
        # Serializes generate() calls — one request uses the model at a time.
        self.model_lock = asyncio.Lock()
        # Output sample rate, taken from the model after loading.
        self.sr = None
        # Prompt-path cache (only when ENABLE_CACHING is set; see
        # _load_audio_prompt for what is actually cached).
        self.audio_prompt_cache = {} if config.ENABLE_CACHING else None

    def load_model(self):
        """Load the TTS model and checkpoint with optimizations"""
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        # Download and apply the fine-tuned t3 checkpoint.
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)
        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()
        # Apply quantization for CPU speed
        if self.config.USE_QUANTIZATION:
            print("Applying dynamic quantization...")
            self.model.t3 = torch.quantization.quantize_dynamic(
                self.model.t3,
                {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
                dtype=torch.qint8
            )
        # Apply torch.compile if available (PyTorch 2.0+)
        if self.config.USE_TORCH_COMPILE and hasattr(torch, 'compile'):
            print("Compiling model with torch.compile...")
            try:
                self.model.t3 = torch.compile(self.model.t3, mode="reduce-overhead")
            except Exception as e:
                # Compilation is an optimization only; fall back and continue.
                print(f"Torch compile failed: {e}, continuing without compilation")
        # Disable dropout for inference (belt-and-braces on top of eval()).
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0
        self.sr = self.model.sr
        print("Model ready (optimized for CPU).")

    def _load_audio_prompt(self, audio_prompt_path: str):
        """Return the audio prompt path, recording it in the cache if enabled.

        NOTE(review): the cache stores path -> path, so it only tracks seen
        paths; the audio itself is loaded inside model.generate each call.
        """
        if self.config.ENABLE_CACHING and audio_prompt_path in self.audio_prompt_cache:
            return self.audio_prompt_cache[audio_prompt_path]
        # Load normally
        # Note: actual loading is done inside model.generate
        if self.config.ENABLE_CACHING:
            self.audio_prompt_cache[audio_prompt_path] = audio_prompt_path
        return audio_prompt_path

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio from text with voice prompt"""
        async with self.model_lock:
            # Cache audio prompt path
            cached_prompt = self._load_audio_prompt(audio_prompt_path)
            def blocking_generate():
                with torch.no_grad():
                    # Set number of threads for CPU inference
                    torch.set_num_threads(multiprocessing.cpu_count())
                    return self.model.generate(
                        text,
                        audio_prompt_path=cached_prompt,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )
            # Run the blocking model call off the event loop.
            wav = await asyncio.get_event_loop().run_in_executor(
                self.thread_pool,
                blocking_generate
            )
            return wav

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate and enhance audio"""
        wav = await self.generate(text, audio_prompt_path)
        # Choose enhancement method based on config
        enhance_func = (
            AudioProcessor.enhance_audio_simple
            if self.config.SIMPLIFY_AUDIO_ENHANCEMENT
            else AudioProcessor.enhance_audio_fast
        )
        # Enhance audio (CPU-bound, so keep it off the event loop)
        wav = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            lambda: enhance_func(wav.cpu(), self.sr)
        )
        return wav

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate audio and save to file"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        # Save to WAV (blocking file/encode work runs in the thread pool)
        await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate audio and return as WAV buffer"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        # Convert to buffer (encoding runs in the thread pool)
        buffer = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.tensor_to_wav_buffer,
            wav,
            self.sr
        )
        return buffer

    def clear_cache(self):
        """Clear audio prompt cache"""
        if self.audio_prompt_cache:
            self.audio_prompt_cache.clear()
# Example usage
async def main():
    """Example usage of optimized TTS engine"""
    config = TTSConfig()
    engine = TTSEngine(config)
    # The model is loaded a single time, up front.
    engine.load_model()
    text = "Halo, ini adalah tes text to speech dalam bahasa Indonesia."
    audio_prompt = "path/to/your/voice_sample.wav"
    # Write the synthesized speech straight to disk...
    await engine.generate_to_file(text, audio_prompt, "output.wav")
    print("Audio generated successfully!")
    # ...then demonstrate the in-memory buffer variant.
    buffer = await engine.generate_to_buffer(text, audio_prompt)
    print(f"Audio buffer size: {len(buffer.getvalue())} bytes")


if __name__ == "__main__":
    asyncio.run(main())

213
py/tts_util.py Normal file
View File

@@ -0,0 +1,213 @@
import io
import os
import asyncio
from typing import Optional
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from concurrent.futures import ThreadPoolExecutor
class TTSConfig:
    """Configuration for TTS model and processing"""
    # HuggingFace repo and checkpoint for the Indonesian Chatterbox model.
    MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT = "t3_cfg.safetensors"
    # Inference device.
    DEVICE = "cpu"
    # Generation parameters (sampling temperature / nucleus p / repetition).
    TEMPERATURE = 0.65
    TOP_P = 0.88
    REPETITION_PENALTY = 1.25
    # Audio processing: post-gain in dB applied during enhancement.
    AUDIO_GAIN_DB = 0.8
class AudioProcessor:
    """Audio enhancement utilities"""

    @staticmethod
    def generate_pink_noise(shape, device):
        """Generate pink noise for audio enhancement.

        Runs a per-sample recursive filter over white noise (the
        coefficients appear to be Paul Kellett's classic pink-noise
        approximation — TODO confirm). Note this is a Python-level loop,
        so it is slow for long signals.
        """
        white = torch.randn(shape, device=device)
        pink = torch.zeros_like(white)
        # b holds the filter's 7 recursive state values.
        b = torch.zeros(7)
        if len(shape) == 1:
            for j in range(shape[0]):
                w = white[j].item()
                b[0] = 0.99886 * b[0] + w * 0.0555179
                b[1] = 0.99332 * b[1] + w * 0.0750759
                b[2] = 0.96900 * b[2] + w * 0.1538520
                b[3] = 0.86650 * b[3] + w * 0.3104856
                b[4] = 0.55000 * b[4] + w * 0.5329522
                b[5] = -0.7616 * b[5] - w * 0.0168980
                pink[j] = (b[0]+b[1]+b[2]+b[3]+b[4]+b[5]+b[6] + w*0.5362) * 0.11
                b[6] = w * 0.115926
        else:
            # Multi-row input: the filter state restarts for every row.
            for i in range(shape[0]):
                b = torch.zeros(7)
                for j in range(shape[1]):
                    w = white[i, j].item()
                    b[0] = 0.99886 * b[0] + w * 0.0555179
                    b[1] = 0.99332 * b[1] + w * 0.0750759
                    b[2] = 0.96900 * b[2] + w * 0.1538520
                    b[3] = 0.86650 * b[3] + w * 0.3104856
                    b[4] = 0.55000 * b[4] + w * 0.5329522
                    b[5] = -0.7616 * b[5] - w * 0.0168980
                    pink[i, j] = (b[0]+b[1]+b[2]+b[3]+b[4]+b[5]+b[6] + w*0.5362) * 0.11
                    b[6] = w * 0.115926
        return pink * 0.1

    @staticmethod
    def enhance_audio(wav, sr):
        """Apply audio enhancements: normalization, filtering, compression"""
        # Normalize to 0.95 full scale.
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / (peak + 1e-8) * 0.95
        # EQ: cut rumble below 60 Hz and hiss above 10 kHz, then gentle
        # bass boost and treble cut.
        wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
        wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
        wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
        wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)
        # Compression: samples above the threshold are scaled down by `ratio`.
        threshold = 0.6
        ratio = 2.5
        abs_wav = wav.abs()
        compressed = wav.clone()
        mask = abs_wav > threshold
        compressed[mask] = torch.sign(wav[mask]) * (threshold + (abs_wav[mask] - threshold) / ratio)
        wav = compressed
        # Soft clipping.
        wav = torch.tanh(wav * 1.08)
        # Add very low-level pink noise.
        wav = wav + AudioProcessor.generate_pink_noise(wav.shape, wav.device) * 0.0003
        wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)
        # Final normalization with headroom.
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / peak * 0.88
        return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Save a torch tensor to WAV file"""
        # Ensure float32 CPU tensor
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.type(torch.float32)
        # torchaudio.save requires shape [channels, samples]
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor
        # Save directly as WAV
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Convert torch tensor to WAV buffer"""
        buf = io.BytesIO()
        # torchaudio.save requires shape [channels, samples]
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf
class TTSEngine:
    """Main TTS engine with model management"""

    def __init__(self, config: TTSConfig, thread_pool: ThreadPoolExecutor):
        # Engine configuration (repo, checkpoint, sampling params).
        self.config = config
        # Executor used to keep blocking model/audio work off the event loop.
        self.thread_pool = thread_pool
        # Populated by load_model(); None until then.
        self.model = None
        # Serializes generate() calls — one request uses the model at a time.
        self.model_lock = asyncio.Lock()
        # Output sample rate, taken from the model after loading.
        self.sr = None

    def load_model(self):
        """Load the TTS model and checkpoint"""
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        # Download and apply the fine-tuned t3 checkpoint.
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)
        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()
        # Disable dropout (belt-and-braces on top of eval()).
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0
        self.sr = self.model.sr
        print("Model ready.")

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio from text with voice prompt"""
        async with self.model_lock:
            def blocking_generate():
                with torch.no_grad():
                    return self.model.generate(
                        text,
                        audio_prompt_path=audio_prompt_path,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )
            # Run the blocking model call in the thread pool.
            wav = await asyncio.get_event_loop().run_in_executor(
                self.thread_pool,
                blocking_generate
            )
            return wav

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate and enhance audio"""
        wav = await self.generate(text, audio_prompt_path)
        # Enhance audio (CPU-bound, so keep it off the event loop)
        wav = await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            lambda: AudioProcessor.enhance_audio(wav.cpu(), self.sr)
        )
        return wav

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate audio and save to file"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        # Save to WAV (blocking file/encode work runs in the thread pool)
        await asyncio.get_event_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate audio and return as WAV buffer"""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        # Convert to buffer
        # NOTE(review): unlike generate_to_file, this encodes on the event
        # loop thread rather than in the thread pool.
        return AudioProcessor.tensor_to_wav_buffer(wav, self.sr)