Files
jenna-tools/py/tts_util.py
bipproduction 822b68c10f tambahannya
2025-12-07 09:00:54 +08:00

213 lines
7.3 KiB
Python

import io
import os
import asyncio
from typing import Optional
import torch
import torchaudio as ta
import torchaudio.functional as F
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from concurrent.futures import ThreadPoolExecutor
class TTSConfig:
    """Static settings for the Chatterbox TTS model and post-processing.

    Values are plain class attributes, so the class itself (or an instance)
    can be handed to ``TTSEngine`` as its ``config``.
    """

    # --- model source / placement ---
    MODEL_REPO: str = "grandhigh/Chatterbox-TTS-Indonesian"
    CHECKPOINT: str = "t3_cfg.safetensors"
    DEVICE: str = "cpu"

    # --- sampling knobs forwarded to ChatterboxTTS.generate() ---
    TEMPERATURE: float = 0.65
    TOP_P: float = 0.88
    REPETITION_PENALTY: float = 1.25

    # --- post-processing: gain (dB) applied by AudioProcessor.enhance_audio ---
    AUDIO_GAIN_DB: float = 0.8
class AudioProcessor:
    """Audio enhancement and WAV-serialisation helpers (static methods only)."""

    # Paul Kellet's "refined" pink-noise filter: six leaky first-order
    # sections b[0..5], each given as (pole, input gain), i.e.
    #     b_k[n] = pole_k * b_k[n-1] + gain_k * w[n]
    _PINK_SECTIONS = (
        (0.99886, 0.0555179),
        (0.99332, 0.0750759),
        (0.96900, 0.1538520),
        (0.86650, 0.3104856),
        (0.55000, 0.5329522),
        (-0.7616, -0.0168980),
    )

    @staticmethod
    def generate_pink_noise(shape, device):
        """Generate pink (approximately 1/f) noise for audio enhancement.

        White Gaussian noise is filtered along the LAST dimension; every
        other index (e.g. each channel row) is filtered independently from a
        zero initial state. This is the vectorised equivalent of the
        per-sample recursion
            pink[n] = (b0+...+b5 + b6_prev + 0.5362*w[n]) * 0.11,
            b6 = 0.115926 * w[n]
        so the output is identical up to float rounding, but it runs in
        native code instead of a Python loop over every sample.

        Args:
            shape: Output shape (at least 1-D).
            device: Torch device on which the noise is generated.

        Returns:
            A float32 tensor of ``shape``, additionally scaled by 0.1.
        """
        white = torch.randn(shape, device=device)
        pink = torch.zeros_like(white)
        for pole, gain in AudioProcessor._PINK_SECTIONS:
            a_coeffs = torch.tensor([1.0, -pole], dtype=white.dtype, device=white.device)
            b_coeffs = torch.tensor([gain, 0.0], dtype=white.dtype, device=white.device)
            # clamp=False: lfilter clamps to [-1, 1] by default, which the
            # reference recursion never does.
            pink = pink + F.lfilter(white, a_coeffs, b_coeffs, clamp=False)
        # b[6] is read before it is updated, so it contributes the PREVIOUS
        # sample's white value (zero for the first sample).
        delayed = torch.zeros_like(white)
        delayed[..., 1:] = white[..., :-1]
        pink = (pink + white * 0.5362 + delayed * 0.115926) * 0.11
        return pink * 0.1

    @staticmethod
    def enhance_audio(wav, sr):
        """Apply the post-processing chain to a generated waveform.

        Steps, in order: peak-normalise to 0.95; 60 Hz high-pass and 10 kHz
        low-pass; +1.5 dB bass shelf at 200 Hz and -1.2 dB treble shelf at
        6 kHz; hard-knee 2.5:1 compression above 0.6; tanh saturation; a faint
        pink-noise floor; ``TTSConfig.AUDIO_GAIN_DB`` of gain; final peak
        normalisation to 0.88.

        Args:
            wav: Waveform tensor; filters act along the last dimension.
            sr: Sample rate of ``wav`` in Hz.

        Returns:
            The processed waveform, same shape as ``wav``.
        """
        # Normalize (epsilon guards the division)
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / (peak + 1e-8) * 0.95

        # Band-limit and tone-shape.
        # NOTE(review): torchaudio's biquad helpers presumably run lfilter with
        # its default clamp=True, so the boosted signal may be hard-clipped to
        # [-1, 1] here - confirm against the installed torchaudio version.
        wav = F.highpass_biquad(wav, sr, cutoff_freq=60)
        wav = F.lowpass_biquad(wav, sr, cutoff_freq=10000)
        wav = F.bass_biquad(wav, sr, gain=1.5, central_freq=200, Q=0.7)
        wav = F.treble_biquad(wav, sr, gain=-1.2, central_freq=6000, Q=0.7)

        # Hard-knee compression: the part of |x| above the threshold is
        # divided by the ratio; the sign is preserved.
        threshold = 0.6
        ratio = 2.5
        magnitude = wav.abs()
        squashed = torch.sign(wav) * (threshold + (magnitude - threshold) / ratio)
        wav = torch.where(magnitude > threshold, squashed, wav)

        # Gentle saturation
        wav = torch.tanh(wav * 1.08)

        # Add a very low-level pink-noise floor
        wav = wav + AudioProcessor.generate_pink_noise(wav.shape, wav.device) * 0.0003

        # NOTE(review): this linear gain is cancelled exactly by the peak
        # normalisation below, so AUDIO_GAIN_DB currently has no effect on
        # the output. Kept to preserve existing behaviour; decide whether the
        # gain was meant to be applied after normalisation.
        wav = F.gain(wav, gain_db=TTSConfig.AUDIO_GAIN_DB)

        # Final normalization
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / peak * 0.88
        return wav

    @staticmethod
    def _prepare_for_save(wav_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``wav_tensor`` as a detached CPU float32 [channels, samples] tensor."""
        # .to() is a no-op when device and dtype already match.
        wav = wav_tensor.detach().to(device="cpu", dtype=torch.float32)
        # torchaudio.save requires shape [channels, samples]
        return wav.unsqueeze(0) if wav.dim() == 1 else wav

    @staticmethod
    def save_tensor_to_wav(wav_tensor: torch.Tensor, sr: int, out_wav_path: str):
        """Write ``wav_tensor`` to ``out_wav_path`` as a float32 WAV file.

        Args:
            wav_tensor: 1-D [samples] or 2-D [channels, samples] waveform,
                on any device / float dtype.
            sr: Sample rate in Hz.
            out_wav_path: Destination file path.
        """
        ta.save(out_wav_path, AudioProcessor._prepare_for_save(wav_tensor), sr, format="wav")

    @staticmethod
    def tensor_to_wav_buffer(wav_tensor: torch.Tensor, sr: int) -> io.BytesIO:
        """Encode ``wav_tensor`` as float32 WAV into an in-memory buffer.

        Uses the same CPU/float32/channel-first preparation as
        ``save_tensor_to_wav`` so both outputs are encoded identically.

        Args:
            wav_tensor: 1-D [samples] or 2-D [channels, samples] waveform.
            sr: Sample rate in Hz.

        Returns:
            A ``BytesIO`` holding the WAV data, rewound to position 0.
        """
        buf = io.BytesIO()
        ta.save(buf, AudioProcessor._prepare_for_save(wav_tensor), sr, format="wav")
        buf.seek(0)
        return buf
class TTSEngine:
    """Async front-end for a ChatterboxTTS model.

    Blocking inference, enhancement and WAV encoding are dispatched to the
    supplied thread pool so they do not block the event loop; an
    ``asyncio.Lock`` ensures only one ``generate`` call uses the model at a
    time.
    """

    def __init__(self, config: "TTSConfig", thread_pool: ThreadPoolExecutor):
        """Store configuration; the model itself is loaded by ``load_model()``.

        Args:
            config: Object (or class) exposing the ``TTSConfig`` attributes.
            thread_pool: Executor used for all blocking work.
        """
        self.config = config
        self.thread_pool = thread_pool
        self.model = None  # set by load_model()
        # Serialises model access across concurrent generate() calls.
        self.model_lock = asyncio.Lock()
        self.sr = None  # output sample rate, copied from the model in load_model()

    def load_model(self) -> None:
        """Load base Chatterbox weights, then overlay the fine-tuned T3 checkpoint.

        Blocking (downloads from the Hugging Face Hub on first use); call it
        once before any ``generate*`` method.
        """
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.config.DEVICE)
        ckpt = hf_hub_download(repo_id=self.config.MODEL_REPO, filename=self.config.CHECKPOINT)
        state = load_file(ckpt, device=self.config.DEVICE)
        self.model.t3.to(self.config.DEVICE).load_state_dict(state)
        self.model.t3.eval()
        # Disable dropout. Belt-and-braces: eval() above already clears
        # `training` on every submodule; p=0 additionally neutralises Dropout.
        for m in self.model.t3.modules():
            if hasattr(m, "training"):
                m.training = False
            if isinstance(m, torch.nn.Dropout):
                m.p = 0
        self.sr = self.model.sr
        print("Model ready.")

    async def generate(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Synthesize ``text`` in the voice of ``audio_prompt_path``.

        Args:
            text: Text to speak.
            audio_prompt_path: Path to the reference voice clip.

        Returns:
            The raw waveform returned by ``ChatterboxTTS.generate``.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        # Fail fast with a clear message instead of an AttributeError on None
        # surfacing from a worker thread.
        if self.model is None:
            raise RuntimeError("TTS model is not loaded; call load_model() first")

        def blocking_generate():
            with torch.no_grad():
                return self.model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    temperature=self.config.TEMPERATURE,
                    top_p=self.config.TOP_P,
                    repetition_penalty=self.config.REPETITION_PENALTY,
                )

        loop = asyncio.get_running_loop()
        async with self.model_lock:
            return await loop.run_in_executor(self.thread_pool, blocking_generate)

    async def generate_and_enhance(self, text: str, audio_prompt_path: str) -> torch.Tensor:
        """Generate audio, then run ``AudioProcessor.enhance_audio`` on a CPU copy.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet.
        """
        wav = await self.generate(text, audio_prompt_path)
        loop = asyncio.get_running_loop()
        # Enhancement is CPU-bound; .cpu() runs in the worker too.
        return await loop.run_in_executor(
            self.thread_pool,
            lambda: AudioProcessor.enhance_audio(wav.cpu(), self.sr),
        )

    async def generate_to_file(self, text: str, audio_prompt_path: str, output_path: str):
        """Generate + enhance audio and write it to ``output_path`` as WAV."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            self.thread_pool,
            AudioProcessor.save_tensor_to_wav,
            wav,
            self.sr,
            output_path,
        )

    async def generate_to_buffer(self, text: str, audio_prompt_path: str) -> io.BytesIO:
        """Generate + enhance audio and return it as an in-memory WAV buffer."""
        wav = await self.generate_and_enhance(text, audio_prompt_path)
        # WAV encoding is synchronous; keep it off the event loop, as
        # generate_to_file does.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_pool,
            AudioProcessor.tensor_to_wav_buffer,
            wav,
            self.sr,
        )