import io
import asyncio
from typing import Optional
from concurrent.futures import ThreadPoolExecutor
import multiprocessing

import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class TTSConfig:
    """Configuration for TTS model and processing."""

    MODEL_REPO = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT = "t3_cfg.safetensors"
    DEVICE = "cpu"

    # Generation parameters tuned for speed
    TEMPERATURE = 0.7
    TOP_P = 0.9
    REPETITION_PENALTY = 1.1

    # Audio processing
    AUDIO_GAIN_DB = 0.8

    # Performance settings
    USE_QUANTIZATION = True
    USE_TORCH_COMPILE = True
    SIMPLIFY_AUDIO_ENHANCEMENT = True
    ENABLE_CACHING = True


class AudioProcessor:
    """Audio enhancement utilities (optimized)."""

    @staticmethod
    def generate_pink_noise_fast(shape, device):
        """Generate pink-ish noise for audio enhancement (vectorized approximation)."""
        white = torch.randn(shape, device=device)
        pink = white * 0.5

        # avg_pool1d expects [batch, channels, samples]
        if white.dim() == 1:
            white_2d = white.unsqueeze(0).unsqueeze(0)
        elif white.dim() == 2:
            white_2d = white.unsqueeze(0)
        else:
            white_2d = white

        # Quick low-pass filtering approximation for a pink-ish spectrum
        kernel_size = min(3, white_2d.shape[-1])
        if kernel_size >= 2:
            filtered = torch.nn.functional.avg_pool1d(
                white_2d, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
            )
            # Trim any padding overhang and restore the input shape before mixing
            filtered = filtered[..., : white_2d.shape[-1]].reshape(white.shape)
            pink = pink + filtered * 0.3

        return pink * 0.1

    @staticmethod
    def enhance_audio_fast(wav, sr):
        """Apply the full enhancement chain with vectorized operations."""
        with torch.no_grad():
            # Normalize
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95

            # Apply filters in sequence (no-grad mode for speed)
            wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
            wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
            wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

            # Vectorized compression (faster than a sample loop)
            threshold = 0.6
            ratio = 2.5
            abs_wav = wav.abs()
            mask = abs_wav > threshold
            wav = torch.where(
                mask,
                torch.sign(wav) * (threshold + (abs_wav - threshold) / ratio),
                wav,
            )
            wav = torch.tanh(wav * 1.08)

            # Add pink noise (fast version)
            wav = wav + AudioProcessor.generate_pink_noise_fast(wav.shape, wav.device) * 0.0003
            wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)

            # Final normalization
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.88

        return wav

    @staticmethod
    def enhance_audio_simple(wav, sr):
        """Simplified audio enhancement for maximum speed."""
        with torch.no_grad():
            # Simple normalization
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / (peak + 1e-8) * 0.95

            # Basic filtering
            wav = F.highpass_biquad(wav, sr, cutoff_freq=80)
            wav = F.lowpass_biquad(wav, sr, cutoff_freq=8000)

            # Soft clipping via tanh saturation
            wav = torch.tanh(wav * 1.1)

            # Final normalization
            peak = wav.abs().max()
            if peak > 0:
                wav = wav / peak * 0.9

        return wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Save a torch tensor to a WAV file."""
        # Ensure float32 CPU tensor
        if wav_tensor.device.type != "cpu":
            wav_tensor = wav_tensor.cpu()
        if wav_tensor.dtype != torch.float32:
            wav_tensor = wav_tensor.type(torch.float32)

        # torchaudio.save requires shape [channels, samples]
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor

        # Save directly as WAV
        ta.save(out_wav_path, wav_out, sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Convert a torch tensor to an in-memory WAV buffer."""
        buf = io.BytesIO()
        if wav_tensor.dim() == 1:
            wav_out = wav_tensor.unsqueeze(0)
        else:
            wav_out = wav_tensor
        ta.save(buf, wav_out, sr, format="wav")
        buf.seek(0)
        return buf
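
# --- Example: the enhancement chain on synthetic audio ---
# A minimal sketch, not invoked by the pipeline, that runs
# enhance_audio_simple over a bare sine wave; the 440 Hz tone and the
# 16 kHz sample rate are arbitrary illustration values.
def _demo_enhancement():
    sr = 16000
    t = torch.linspace(0.0, 1.0, sr)
    wav = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # [1, samples]
    out = AudioProcessor.enhance_audio_simple(wav, sr)
    # The final normalization pins the output peak at ~0.9
    print(f"peak before: {wav.abs().max().item():.3f}, "
          f"after: {out.abs().max().item():.3f}")
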
class TTSEngine:
    """Main TTS engine with model management (optimized)."""

    def __init__(self, config: TTSConfig, thread_pool: Optional[ThreadPoolExecutor] = None):
        self.config = config
        self.thread_pool = thread_pool or ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )
        self.model = None
        self.model_lock = asyncio.Lock()
        self.sr = None
        self.audio_prompt_cache = {} if config.ENABLE_CACHING else None

    def load_model(self):
        """Load the TTS model and checkpoint with optimizations."""
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)
        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()

        # Dynamic quantization speeds up CPU inference
        if self.config.USE_QUANTIZATION:
            print("Applying dynamic quantization...")
            self.model.t3 = torch.quantization.quantize_dynamic(
                self.model.t3,
                {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
                dtype=torch.qint8,
            )

        # Apply torch.compile if available (PyTorch 2.0+)
        if self.config.USE_TORCH_COMPILE and hasattr(torch, "compile"):
            print("Compiling model with torch.compile...")
            try:
                self.model.t3 = torch.compile(self.model.t3, mode="reduce-overhead")
            except Exception as e:
                print(f"torch.compile failed: {e}, continuing without compilation")

        # Disable dropout for inference (quantization/compilation may wrap modules)
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0

        self.sr = self.model.sr
        print("Model ready (optimized for CPU).")

    def _load_audio_prompt(self, audio_prompt_path: str):
        """Resolve an audio prompt path, with optional caching.

        Note: the cache currently stores only the path itself as a
        placeholder; the prompt audio is actually loaded inside
        model.generate().
        """
        if self.config.ENABLE_CACHING and audio_prompt_path in self.audio_prompt_cache:
            return self.audio_prompt_cache[audio_prompt_path]
        if self.config.ENABLE_CACHING:
            self.audio_prompt_cache[audio_prompt_path] = audio_prompt_path
        return audio_prompt_path

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio from text with a voice prompt."""
        async with self.model_lock:
            cached_prompt = self._load_audio_prompt(audio_prompt_path)

            def blocking_generate():
                with torch.no_grad():
                    # Use all cores for CPU inference
                    torch.set_num_threads(multiprocessing.cpu_count())
                    return self.model.generate(
                        text,
                        audio_prompt_path=cached_prompt,
                        temperature=self.config.TEMPERATURE,
                        top_p=self.config.TOP_P,
                        repetition_penalty=self.config.REPETITION_PENALTY,
                    )

            wav = await asyncio.get_running_loop().run_in_executor(
                self.thread_pool, blocking_generate
            )
            return wav

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio, then run the configured enhancement chain."""
        wav = await self.generate(text, audio_prompt_path)

        # Choose enhancement method based on config
        enhance_func = (
            AudioProcessor.enhance_audio_simple
            if self.config.SIMPLIFY_AUDIO_ENHANCEMENT
            else AudioProcessor.enhance_audio_fast
        )

        # Enhancement is CPU-bound, so run it in the thread pool
        wav = await asyncio.get_running_loop().run_in_executor(
            self.thread_pool, lambda: enhance_func(wav.cpu(), self.sr)
        )
        return wav

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate audio and save it to a file."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        await asyncio.get_running_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav, self.sr, output_path,
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate audio and return it as an in-memory WAV buffer."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        buffer = await asyncio.get_running_loop().run_in_executor(
            self.thread_pool,
            AudioProcessor.tensor_to_wav_buffer,
            wav, self.sr,
        )
        return buffer

    def clear_cache(self):
        """Clear the audio prompt cache."""
        if self.audio_prompt_cache:
            self.audio_prompt_cache.clear()
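
# --- Example: serving several requests concurrently ---
# A minimal sketch, assuming an already-loaded engine; the sentences and
# the prompt path are placeholders. asyncio.gather() can submit all
# requests at once: model_lock serializes the actual model calls, while
# enhancement and WAV encoding overlap in the thread pool.
async def _demo_concurrent(engine: TTSEngine, prompt_path: str):
    texts = ["Selamat pagi.", "Apa kabar?", "Sampai jumpa."]
    buffers = await asyncio.gather(
        *(engine.generate_to_buffer(t, prompt_path) for t in texts)
    )
    for text, buf in zip(texts, buffers):
        print(f"{text!r} -> {len(buf.getvalue())} bytes")
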
# Example usage
async def main():
    """Example usage of the optimized TTS engine."""
    config = TTSConfig()
    engine = TTSEngine(config)

    # Load the model once
    engine.load_model()

    text = "Halo, ini adalah tes text to speech dalam bahasa Indonesia."
    audio_prompt = "path/to/your/voice_sample.wav"

    # Generate to file
    await engine.generate_to_file(text, audio_prompt, "output.wav")
    print("Audio generated successfully!")

    # Or generate to an in-memory buffer
    buffer = await engine.generate_to_buffer(text, audio_prompt)
    print(f"Audio buffer size: {len(buffer.getvalue())} bytes")


if __name__ == "__main__":
    asyncio.run(main())
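
# --- Example: exposing the engine over HTTP ---
# A minimal sketch, assuming FastAPI is installed; the app layout, the
# /tts route, and the prompt_path argument are illustrative and not part
# of this module. generate_to_buffer() pairs naturally with an HTTP
# response since it never touches the filesystem.
def _build_demo_app(prompt_path: str):
    from fastapi import FastAPI, Response  # local import: optional dependency

    app = FastAPI()
    engine = TTSEngine(TTSConfig())
    engine.load_model()

    @app.get("/tts")
    async def tts(text: str):
        buf = await engine.generate_to_buffer(text, prompt_path)
        return Response(content=buf.getvalue(), media_type="audio/wav")

    return app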