# File listing metadata: 322 lines, 11 KiB, Python
import io
|
|
import os
|
|
import asyncio
|
|
from typing import Optional
|
|
from functools import lru_cache
|
|
|
|
import torch
|
|
import torchaudio as ta
|
|
import torchaudio.functional as F
|
|
from chatterbox.tts import ChatterboxTTS
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import multiprocessing
|
|
|
|
|
|
class TTSConfig:
    """Static settings consumed by ``TTSEngine`` and ``AudioProcessor``.

    All values are plain class attributes, so they can be read either from
    the class itself or from an instance.
    """

    # --- Model source -----------------------------------------------------
    # Hugging Face repository and checkpoint file fetched in load_model().
    MODEL_REPO: str = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT: str = "t3_cfg.safetensors"
    # Device string handed to the model loader and the checkpoint loader.
    DEVICE: str = "cpu"

    # --- Sampling parameters forwarded to model.generate() ----------------
    TEMPERATURE: float = 0.7
    TOP_P: float = 0.9
    REPETITION_PENALTY: float = 1.1

    # --- Audio post-processing --------------------------------------------
    # Gain in dB applied by AudioProcessor.enhance_audio_fast().
    AUDIO_GAIN_DB: float = 0.8

    # --- Performance switches ---------------------------------------------
    # Dynamic int8 quantization of the t3 sub-model at load time.
    USE_QUANTIZATION: bool = True
    # Wrap the t3 sub-model with torch.compile when the attribute exists.
    USE_TORCH_COMPILE: bool = True
    # Pick enhance_audio_simple over enhance_audio_fast.
    SIMPLIFY_AUDIO_ENHANCEMENT: bool = True
    # Enable the engine's audio-prompt cache dictionary.
    ENABLE_CACHING: bool = True
|
|
|
|
|
|
class AudioProcessor:
    """Audio enhancement and WAV serialization helpers.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    """

    @staticmethod
    def generate_pink_noise_fast(shape, device):
        """Return low-amplitude, pink-ish noise of the requested shape.

        The noise is white Gaussian noise blended with a moving average of
        itself along the last axis (a cheap low-pass), then scaled by 0.1.

        Args:
            shape: Shape of the tensor to generate (any rank, including 0).
            device: Device on which the noise is created.

        Returns:
            A tensor with exactly ``shape`` on ``device``.
        """
        white = torch.randn(shape, device=device)
        pink = white * 0.5

        # Nothing to smooth for scalars, empty tensors or a single sample.
        if white.dim() == 0 or white.numel() == 0 or white.shape[-1] < 2:
            return pink * 0.1

        length = white.shape[-1]
        kernel_size = min(3, length)

        # Collapse all leading axes into the channel axis so avg_pool1d sees
        # (batch=1, channels, length) regardless of the input rank.
        rows = white.reshape(1, -1, length)
        filtered = torch.nn.functional.avg_pool1d(
            rows,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )
        # An even kernel (length == 2) yields one extra output sample; trim so
        # the smoothed signal lines up with the input again.
        filtered = filtered[..., :length].reshape(white.shape)

        # The 0.3 weight applies to every input rank (the original conditional
        # expression dropped it for anything that was not 1-D).
        pink = pink + filtered * 0.3
        return pink * 0.1

    @staticmethod
    def _peak_normalize(wav, target, eps=0.0):
        """Scale ``wav`` so its absolute peak equals ``target`` (no-op if silent)."""
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / (peak + eps) * target
        return wav

    @staticmethod
    def _prepare_for_save(wav_tensor):
        """Return a CPU float32 tensor with a leading channel axis for saving."""
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.to(torch.float32)
        # A 1-D tensor is treated as a single channel: [samples] -> [1, samples].
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        return wav_tensor

    @staticmethod
    def enhance_audio_fast(wav, sr):
        """Apply the full enhancement chain to a waveform.

        Steps: peak normalization, high-pass/low-pass/bass/treble biquads,
        static compression above a 0.6 threshold, tanh soft clipping, a tiny
        amount of pink noise, gain from ``TTSConfig.AUDIO_GAIN_DB`` and a
        final peak normalization to 0.88.

        Args:
            wav: Waveform tensor.
            sr: Sample rate passed to the biquad filters.

        Returns:
            The processed waveform; an empty input is returned unchanged.
        """
        with torch.no_grad():
            # max() is undefined on an empty tensor, so bail out early.
            if wav.numel() == 0:
                return wav

            wav = AudioProcessor._peak_normalize(wav, 0.95, eps=1e-8)

            wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
            wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
            wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

            # Vectorized compressor: the part of each sample above the
            # threshold is divided by the ratio, the sign is preserved.
            threshold = 0.6
            ratio = 2.5
            magnitude = wav.abs()
            compressed = torch.sign(wav) * (threshold + (magnitude - threshold) / ratio)
            wav = torch.where(magnitude > threshold, compressed, wav)

            wav = torch.tanh(wav * 1.08)

            noise = AudioProcessor.generate_pink_noise_fast(wav.shape, wav.device)
            wav = wav + noise * 0.0003
            wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)

            wav = AudioProcessor._peak_normalize(wav, 0.88)

        return wav

    @staticmethod
    def enhance_audio_simple(wav, sr):
        """Apply a reduced enhancement chain.

        Steps: peak normalization, 80 Hz high-pass, 8 kHz low-pass, tanh soft
        clipping and a final peak normalization to 0.9.

        Args:
            wav: Waveform tensor.
            sr: Sample rate passed to the biquad filters.

        Returns:
            The processed waveform; an empty input is returned unchanged.
        """
        with torch.no_grad():
            # max() is undefined on an empty tensor, so bail out early.
            if wav.numel() == 0:
                return wav

            wav = AudioProcessor._peak_normalize(wav, 0.95, eps=1e-8)

            wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=8000)

            wav = torch.tanh(wav * 1.1)

            wav = AudioProcessor._peak_normalize(wav, 0.9)

        return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Write a waveform tensor to ``out_wav_path`` as a WAV file.

        The tensor is moved to CPU, cast to float32 and given a channel axis
        if it is 1-D before being handed to ``torchaudio.save``.
        """
        wav_out = AudioProcessor._prepare_for_save(wav_tensor)
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Encode a waveform tensor as WAV into an in-memory buffer.

        The tensor receives the same CPU/float32/channel-axis preparation as
        in ``save_tensor_to_wav``.

        Returns:
            A ``BytesIO`` positioned at offset 0, ready to be read.
        """
        buf = io.BytesIO()
        wav_out = AudioProcessor._prepare_for_save(wav_tensor)
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf
|
|
|
|
|
|
class TTSEngine:
    """TTS engine that owns the model and runs blocking work on a thread pool.

    ``load_model()`` must be called once before any of the ``generate*``
    coroutines. Generation is serialized through ``model_lock`` so that only
    one ``model.generate`` call is in flight at a time.
    """

    def __init__(self, config: "TTSConfig", thread_pool: Optional[ThreadPoolExecutor] = None):
        """Store the configuration and set up executor, lock and cache.

        Args:
            config: Settings object; its upper-case attributes are read here
                and by the other methods.
            thread_pool: Executor for blocking work. When omitted, a pool with
                one worker per CPU is created.
        """
        self.config = config
        self.thread_pool = thread_pool or ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )
        self.model = None
        self.model_lock = asyncio.Lock()
        self.sr = None
        # The cache only exists when caching is enabled at construction time.
        self.audio_prompt_cache = {} if config.ENABLE_CACHING else None

    def load_model(self):
        """Load the base model, apply the fine-tuned checkpoint and optimize it.

        This is a blocking call. It downloads the checkpoint named in the
        config, loads it into the model's ``t3`` sub-module, then optionally
        applies dynamic quantization and ``torch.compile``. On success
        ``self.model`` and ``self.sr`` are set.
        """
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)

        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()

        # Dynamic int8 quantization of linear/recurrent layers.
        if self.config.USE_QUANTIZATION:
            print("Applying dynamic quantization...")
            self.model.t3 = torch.quantization.quantize_dynamic(
                self.model.t3,
                {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
                dtype=torch.qint8,
            )

        # torch.compile only exists on PyTorch 2.0+.
        # NOTE(review): compilation is presumably deferred until the first
        # forward call, so this try/except may not see every compile failure
        # -- confirm against the PyTorch version in use.
        if self.config.USE_TORCH_COMPILE and hasattr(torch, 'compile'):
            print("Compiling model with torch.compile...")
            try:
                self.model.t3 = torch.compile(self.model.t3, mode="reduce-overhead")
            except Exception as e:
                print(f"Torch compile failed: {e}, continuing without compilation")

        # Force inference mode on every sub-module and zero out dropout.
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0

        self.sr = self.model.sr
        print("Model ready (optimized for CPU).")

    def _load_audio_prompt(self, audio_prompt_path: str):
        """Return the prompt to pass to ``model.generate``, recording it in the cache.

        The cache currently maps each path to itself; the audio file is read
        inside ``model.generate``, not here.
        """
        cache = self.audio_prompt_cache
        # Caching is active only when both the flag is set and the dictionary
        # exists (it is None when caching was disabled at construction).
        if cache is None or not self.config.ENABLE_CACHING:
            return audio_prompt_path

        if audio_prompt_path not in cache:
            cache[audio_prompt_path] = audio_prompt_path
        return cache[audio_prompt_path]

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Synthesize ``text`` using the voice sample at ``audio_prompt_path``.

        Args:
            text: Text to synthesize.
            audio_prompt_path: Path of the voice prompt audio file.

        Returns:
            Whatever ``model.generate`` returns for the request.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded; call load_model() before generating audio."
            )

        async with self.model_lock:
            cached_prompt = self._load_audio_prompt(audio_prompt_path)

            def blocking_generate():
                with torch.no_grad():
                    # Use every core for intra-op parallelism during inference.
                    torch.set_num_threads(multiprocessing.cpu_count())

                    return self.model.generate(
                        text,
                        audio_prompt_path=cached_prompt,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )

            # get_running_loop() is the correct call inside a coroutine.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(self.thread_pool, blocking_generate)

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio, then post-process it on the thread pool.

        The enhancement function is chosen from
        ``config.SIMPLIFY_AUDIO_ENHANCEMENT``; the waveform is moved to CPU
        before it is enhanced.
        """
        wav = await self.generate(text, audio_prompt_path)

        enhance_func = (
            AudioProcessor.enhance_audio_simple
            if self.config.SIMPLIFY_AUDIO_ENHANCEMENT
            else AudioProcessor.enhance_audio_fast
        )

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_pool,
            lambda: enhance_func(wav.cpu(), self.sr),
        )

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate enhanced audio and write it to ``output_path`` as WAV."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path,
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate enhanced audio and return it as an in-memory WAV buffer."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_pool,
            AudioProcessor.tensor_to_wav_buffer,
            wav,
            self.sr,
        )

    def clear_cache(self):
        """Empty the audio prompt cache (no-op when caching is disabled)."""
        if self.audio_prompt_cache:
            self.audio_prompt_cache.clear()
|
|
|
|
|
|
# Example usage
|
|
async def main():
    """Demonstrate the engine: synthesize one sentence to a file and to a buffer."""
    engine = TTSEngine(TTSConfig())

    # The model is loaded a single time and reused for both requests below.
    engine.load_model()

    sentence = "Halo, ini adalah tes text to speech dalam bahasa Indonesia."
    voice_sample = "path/to/your/voice_sample.wav"

    # First request: write the result straight to disk.
    await engine.generate_to_file(sentence, voice_sample, "output.wav")
    print("Audio generated successfully!")

    # Second request: keep the encoded WAV in memory and report its size.
    wav_buffer = await engine.generate_to_buffer(sentence, voice_sample)
    encoded = wav_buffer.getvalue()
    print(f"Audio buffer size: {len(encoded)} bytes")
|
|
|
|
|
|
# Script entry point: run the async example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())