457 lines
13 KiB
Python
457 lines
13 KiB
Python
import os
|
|
import base64
|
|
import uuid
|
|
import asyncio
|
|
import time
|
|
import threading
|
|
from typing import Dict
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from fastapi import FastAPI, UploadFile, File, Form, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from tts_util import TTSConfig, TTSEngine
|
|
|
|
|
|
# ===============================================
# CONFIG
# ===============================================
# Working directories, relative to the process working directory.
PROMPT_FOLDER = "prompt_source"  # registered voice prompts, stored as <name>.wav
JOBS_FOLDER = "jobs"  # created here; not referenced elsewhere in the visible code
OUTPUT_FOLDER = "output"  # generated audio, stored as <job_id>.wav

os.makedirs(PROMPT_FOLDER, exist_ok=True)
os.makedirs(JOBS_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# In-memory job records keyed by job id.
job_store: Dict[str, dict] = {}

# Age in seconds after which cleanup_worker drops a finished job record.
JOB_EXPIRE_SECONDS = 600

# Queue consumers: TURBO_WORKERS worker tasks when TURBO_MODE is on, else one.
TURBO_MODE = True
TURBO_WORKERS = 3
# Size of the thread pool handed to the TTS engine.
WORKER_THREADPOOL_MAX = 4

# Per-IP token bucket: capacity RATE_LIMIT_TOKENS, fully refilled over
# RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_TOKENS = 10
RATE_LIMIT_WINDOW = 60

# Bucket state per client IP ({"tokens": float, "last": timestamp});
# every access goes through token_lock.
token_buckets: Dict[str, dict] = {}
token_lock = threading.Lock()

# Shared executor and the TTS engine built on top of it.
thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREADPOOL_MAX)
config = TTSConfig()
tts_engine = TTSEngine(config, thread_pool)

app = FastAPI(title="Chatterbox TTS Server - Turbo + FIFO + RateLimit + WAV")
|
|
|
|
|
|
# ===============================================
|
|
# RATE LIMIT UTILITIES
|
|
# ===============================================
|
|
def get_client_ip(request: Request) -> str:
    """Return the address used to identify the caller of *request*.

    The first entry of the ``X-Forwarded-For`` header is preferred. When the
    header is absent, or its first entry is blank, the peer address of the
    connection is used instead; ``"unknown"`` is returned when neither is
    available.

    NOTE(review): ``X-Forwarded-For`` is supplied by the client unless a
    trusted proxy overwrites it, so a caller can pick its own rate-limit key.
    Confirm the deployment always sits behind such a proxy.
    """
    forwarded = request.headers.get("x-forwarded-for") or ""
    first_hop = forwarded.split(",")[0].strip()
    if first_hop:
        return first_hop
    # A header such as " , 10.0.0.1" used to yield "" as the client key;
    # fall through to the socket peer instead.
    if request.client:
        return request.client.host
    return "unknown"
|
|
|
|
|
|
def allow_request_ip(ip: str) -> bool:
    """Spend one rate-limit token for *ip*; return whether one was available.

    Each address owns a bucket holding at most ``RATE_LIMIT_TOKENS`` tokens
    that refills linearly, a full bucket per ``RATE_LIMIT_WINDOW`` seconds.
    The whole check-and-update runs under ``token_lock``.
    """
    current = time.time()
    with token_lock:
        state = token_buckets.get(ip)
        if state is None:
            # First request from this address: full bucket minus this request.
            token_buckets[ip] = {"tokens": RATE_LIMIT_TOKENS - 1, "last": current}
            return True

        # Credit the tokens earned since the last refill, capped at capacity.
        waited = current - state["last"]
        earned = (waited / RATE_LIMIT_WINDOW) * RATE_LIMIT_TOKENS
        if earned > 0:
            state["tokens"] = min(RATE_LIMIT_TOKENS, state["tokens"] + earned)
            state["last"] = current

        if state["tokens"] < 1:
            return False
        state["tokens"] -= 1
        return True
|
|
|
|
|
|
# ===============================================
|
|
# BACKGROUND WORKERS
|
|
# ===============================================
|
|
# FIFO queue of job ids: /tts-async puts ids on it, worker_loop takes them off.
# NOTE(review): the queue is constructed at import time, before the server's
# event loop is running; on Python < 3.10 asyncio.Queue binds to the loop that
# is current at construction — confirm the runtime version.
job_queue: asyncio.Queue = asyncio.Queue()
|
|
|
|
|
|
async def worker_loop(worker_id: int):
    """Consume job ids from ``job_queue`` forever and synthesize their audio.

    For each id the matching record in ``job_store`` is marked
    ``"processing"``, the prompt WAV is looked up in ``PROMPT_FOLDER`` and
    ``tts_engine.generate_to_file`` is awaited to write
    ``OUTPUT_FOLDER/<job_id>.wav``. The record ends up as ``"done"`` (with
    ``result`` set to the output path) or ``"error"`` (with ``error`` set to
    a message). Ids whose record no longer exists are skipped.

    Args:
        worker_id: Number stored in the record's ``worker`` field while this
            worker handles the job.
    """
    while True:
        job_id = await job_queue.get()
        job = job_store.get(job_id)
        if job is None:
            # The record was removed before a worker picked the id up.
            job_queue.task_done()
            continue

        # Mark as processing.
        job["status"] = "processing"
        job["worker"] = worker_id
        job["timestamp"] = time.time()

        try:
            prompt = job.get("prompt")
            text = job.get("text")
            prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")

            if not os.path.exists(prompt_path):
                job["status"] = "error"
                job["error"] = "Prompt tidak ditemukan"
                job["timestamp"] = time.time()
            else:
                # Generate audio.
                out_wav = os.path.join(OUTPUT_FOLDER, f"{job_id}.wav")
                await tts_engine.generate_to_file(text, prompt_path, out_wav)

                job["status"] = "done"
                job["result"] = out_wav
                job["timestamp"] = time.time()
        except Exception as e:
            # A failing job must not take the worker down with it.
            job["status"] = "error"
            job["error"] = str(e)
            job["timestamp"] = time.time()
        finally:
            # Exactly one task_done() per get(): the missing-prompt path used
            # to call it and then `continue` through this `finally`, calling
            # it a second time and raising ValueError, which ended the worker.
            job_queue.task_done()
|
|
|
|
|
|
async def cleanup_worker():
    """Periodically drop expired job records and idle rate-limit buckets.

    Runs forever, one pass every 30 seconds:

    * Job records whose status is ``"done"`` or ``"error"`` and whose
      ``timestamp`` is older than ``JOB_EXPIRE_SECONDS`` are removed from
      ``job_store``. Failed jobs are included so they cannot accumulate
      without bound. Output files are left on disk; only the in-memory
      record is dropped.
    * Token buckets not touched for more than ``RATE_LIMIT_WINDOW * 10``
      seconds are removed from ``token_buckets`` under ``token_lock``.
    """
    while True:
        now = time.time()

        expired = [
            jid
            for jid, job in job_store.items()
            if job.get("status") in ("done", "error")
            and now - job.get("timestamp", 0) > JOB_EXPIRE_SECONDS
        ]
        for jid in expired:
            job_store.pop(jid, None)

        # Purge stale token buckets.
        with token_lock:
            stale_ips = [
                ip
                for ip, bucket in token_buckets.items()
                if now - bucket.get("last", 0) > RATE_LIMIT_WINDOW * 10
            ]
            for ip in stale_ips:
                token_buckets.pop(ip, None)

        await asyncio.sleep(30)
|
|
|
|
|
|
# ===============================================
|
|
# STARTUP
|
|
# ===============================================
|
|
# Strong references to the long-running background tasks. The event loop only
# keeps weak references to tasks, so a task whose handle is discarded can be
# garbage-collected while it is still running.
_background_tasks: list = []


@app.on_event("startup")
async def startup():
    """Load the TTS model and start the background tasks.

    Starts one ``cleanup_worker`` task and ``TURBO_WORKERS`` ``worker_loop``
    tasks (a single one when ``TURBO_MODE`` is off). Worker ids start at 1.

    NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    lifespan handlers; consider migrating when the framework is upgraded.
    """
    # Load TTS model.
    tts_engine.load_model()

    # Start cleanup worker.
    _background_tasks.append(asyncio.create_task(cleanup_worker()))

    # Start turbo workers.
    worker_count = TURBO_WORKERS if TURBO_MODE else 1
    for i in range(worker_count):
        _background_tasks.append(asyncio.create_task(worker_loop(i + 1)))
|
|
|
|
|
|
# ===============================================
|
|
# PYDANTIC MODELS
|
|
# ===============================================
|
|
class RegisterPromptBase64(BaseModel):
    """Request body of ``POST /register-prompt-base64``."""

    # Name the prompt is stored under; the handler appends ".wav".
    prompt_name: str
    # Base64 text that the handler decodes and writes to the prompt file.
    base64_audio: str
|
|
|
|
|
|
class DeletePrompt(BaseModel):
    """Request body of ``POST /delete-prompt``."""

    # Name of the prompt to remove, without the ".wav" suffix.
    prompt_name: str
|
|
|
|
|
|
class RenamePrompt(BaseModel):
    """Request body of ``POST /rename-prompt``."""

    # Current prompt name, without the ".wav" suffix.
    old_name: str
    # Name the prompt should carry afterwards, without the ".wav" suffix.
    new_name: str
|
|
|
|
|
|
# ===============================================
|
|
# PROMPT MANAGEMENT ENDPOINTS
|
|
# ===============================================
|
|
@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Store a voice prompt supplied as base64 text.

    The decoded bytes are written to ``PROMPT_FOLDER/<prompt_name>.wav``,
    replacing any existing file of that name.

    Returns:
        ``{"status": "ok", "file": <filename>}`` on success, or a 400 JSON
        response when the name is not a bare file name, the payload cannot
        be decoded, or the file cannot be written.
    """
    name = data.prompt_name
    # The name comes straight from the request body and becomes part of a
    # file path: reject anything carrying a directory component so writes
    # cannot escape PROMPT_FOLDER.
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    filename = f"{name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        raw = base64.b64decode(data.base64_audio)
        with open(path, "wb") as f:
            f.write(raw)
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError subclass.
        return JSONResponse(status_code=400, content={"error": str(e)})
    return {"status": "ok", "file": filename}
|
|
|
|
|
|
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Store an uploaded file as a voice prompt.

    The prompt is saved as ``PROMPT_FOLDER/<name>.wav``. When ``name`` is not
    given, the upload's own file name (directory part and extension removed)
    is used.

    Returns:
        ``{"status": "ok", "file": <filename>}`` on success, a 400 JSON
        response when no usable bare name can be determined, or a 500 JSON
        response when reading or writing fails.
    """
    # `prompt.filename` may be missing; clients may also send a full path.
    uploaded_stem = os.path.splitext(os.path.basename(prompt.filename or ""))[0]
    prompt_name = name or uploaded_stem
    # An explicit `name` is untrusted too: refuse directory components so the
    # write cannot land outside PROMPT_FOLDER.
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        payload = await prompt.read()
        with open(save_path, "wb") as f:
            f.write(payload)
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
|
|
@app.get("/list-prompt")
async def list_prompt():
    """Return the ``.wav`` files in ``PROMPT_FOLDER`` and their base names.

    The extension match is case-insensitive. The response carries the file
    count, the file names and the same names without extension.
    """
    wav_files = []
    stems = []
    for entry in os.listdir(PROMPT_FOLDER):
        if not entry.lower().endswith(".wav"):
            continue
        wav_files.append(entry)
        stems.append(os.path.splitext(entry)[0])

    return {
        "count": len(wav_files),
        "prompts": wav_files,
        "prompt_names": stems,
    }
|
|
|
|
|
|
@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete ``PROMPT_FOLDER/<prompt_name>.wav``.

    Returns:
        ``{"status": "ok", "deleted": <filename>}`` on success, a 400 JSON
        response for a name that is not a bare file name, 404 when the
        prompt does not exist, or 500 when removal fails.
    """
    name = data.prompt_name
    # Untrusted name used in a file path: without this check a value such as
    # "../output/x" would delete files outside PROMPT_FOLDER.
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    path = os.path.join(PROMPT_FOLDER, f"{name}.wav")
    if not os.path.exists(path):
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    try:
        os.remove(path)
    except OSError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"status": "ok", "deleted": f"{name}.wav"}
|
|
|
|
|
|
@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename ``<old_name>.wav`` to ``<new_name>.wav`` inside ``PROMPT_FOLDER``.

    Returns:
        ``{"status": "ok", "from": old_name, "to": new_name}`` on success,
        a 400 JSON response when either name is not a bare file name or the
        new name is already taken, 404 when the old prompt does not exist,
        or 500 when the rename fails.
    """
    # Both names are untrusted and end up in file paths; a directory
    # component in either would move files into or out of PROMPT_FOLDER.
    for candidate in (data.old_name, data.new_name):
        if not candidate or os.path.basename(candidate) != candidate:
            return JSONResponse(
                status_code=400, content={"error": "Nama prompt tidak valid"}
            )

    old = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if not os.path.exists(old):
        return JSONResponse(
            status_code=404, content={"error": "Prompt lama tidak ditemukan"}
        )
    if os.path.exists(new):
        return JSONResponse(
            status_code=400, content={"error": "Nama baru sudah digunakan"}
        )
    try:
        os.rename(old, new)
    except OSError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"status": "ok", "from": data.old_name, "to": data.new_name}
|
|
|
|
|
|
@app.post("/tts-async")
async def tts_async(request: Request, text: str = Form(...), prompt: str = Form(...)):
    """Queue a TTS job and return its id straight away.

    The caller is rate-limited per client address (429 when out of tokens).
    An accepted request is recorded in ``job_store`` as ``"pending"`` and its
    id is put on ``job_queue``; the response names the URL to poll.
    """
    caller_ip = get_client_ip(request)
    if not allow_request_ip(caller_ip):
        return JSONResponse(status_code=429, content={"error": "rate limit exceeded"})

    new_job_id = str(uuid.uuid4())
    record = dict(
        status="pending",
        timestamp=time.time(),
        prompt=prompt,
        text=text,
        client_ip=caller_ip,
    )
    job_store[new_job_id] = record

    # Workers take ids off the queue in insertion order.
    await job_queue.put(new_job_id)

    return {
        "status": "queued",
        "job_id": new_job_id,
        "check": f"/result/{new_job_id}",
    }
|
|
|
|
|
|
@app.get("/result/{job_id}")
async def tts_result(job_id: str):
    """Report the state of a queued TTS job.

    Responses:
        * 404 when the id is not in ``job_store``.
        * A progress dict (status, job_id, worker, timestamp) while the job
          is ``"pending"`` or ``"processing"``.
        * The stored job record itself when the job failed.
        * 200 with the output file path when the job is done and the file
          exists, otherwise 500.
    """
    record = job_store.get(job_id)
    if not record:
        return JSONResponse(
            status_code=404, content={"error": "Job ID tidak ditemukan"}
        )

    state = record["status"]

    # Not finished yet: report progress only.
    if state in ("pending", "processing"):
        return {
            "status": state,
            "job_id": job_id,
            "worker": record.get("worker"),
            "timestamp": record.get("timestamp"),
        }

    # NOTE(review): this hands back the stored record verbatim, which also
    # carries the submitter's client_ip and the submitted text.
    if state == "error":
        return record

    # Finished: the output file must still be on disk.
    wav_path = record.get("result")
    if wav_path and os.path.exists(wav_path):
        return JSONResponse(
            status_code=200,
            content={"status": "done", "job_id": job_id, "file": wav_path},
        )
    return JSONResponse(
        status_code=500,
        content={"status": "error", "error": "File hasil tidak ditemukan"},
    )
|
|
|
|
@app.get("/list-file")
async def list_file():
    """List the ``.wav`` files in ``OUTPUT_FOLDER`` with size and job details.

    Each entry holds the file name, its size in bytes, its path and, when a
    record in ``job_store`` has that path as its ``result``, a summary of
    that job (id, status, timestamp, prompt); otherwise ``job`` is ``None``.

    Returns:
        ``{"count": <int>, "files": [...]}``, or a 500 JSON response when
        the directory cannot be read.
    """
    try:
        # Index jobs by result path once instead of rescanning job_store for
        # every file; setdefault keeps the first job seen for a given path.
        job_by_result = {}
        for jid, job in job_store.items():
            result = job.get("result")
            if result is not None:
                job_by_result.setdefault(result, (jid, job))

        detailed = []
        for f in os.listdir(OUTPUT_FOLDER):
            full_path = os.path.join(OUTPUT_FOLDER, f)
            # Only WAV files, matching what the engine writes.
            if not (os.path.isfile(full_path) and f.lower().endswith(".wav")):
                continue

            # Attach metadata from the related job, if there is one.
            related_job = None
            match = job_by_result.get(full_path)
            if match is not None:
                jid, job = match
                related_job = {
                    "job_id": jid,
                    "status": job.get("status"),
                    "timestamp": job.get("timestamp"),
                    "prompt": job.get("prompt"),
                }

            detailed.append(
                {
                    "file": f,
                    "size_bytes": os.path.getsize(full_path),
                    "path": full_path,
                    "job": related_job,
                }
            )

        return {
            "count": len(detailed),
            "files": detailed,
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
|
|
async def iterfile(file_path, chunk_size: int = 4096):
    """Asynchronously yield the contents of *file_path* in binary chunks.

    Reads are handed to the event loop's default executor so that a slow
    disk does not stall other coroutines while a file is being streamed.

    Args:
        file_path: Path of the file to read.
        chunk_size: Maximum number of bytes per yielded chunk (default 4096).

    Yields:
        Consecutive ``bytes`` chunks; nothing for an empty file.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    loop = asyncio.get_running_loop()
    with open(file_path, "rb") as f:
        while True:
            chunk = await loop.run_in_executor(None, f.read, chunk_size)
            if not chunk:
                break
            yield chunk
|
|
|
|
|
|
@app.get("/file/{file_name}")
async def get_output_file(file_name: str):
    """Stream a generated WAV file from ``OUTPUT_FOLDER`` as a download.

    ``.wav`` is appended when *file_name* does not already end with it.

    Returns:
        A streaming ``audio/wav`` response, a 400 JSON response when the
        name carries a directory component, or 404 when no such file exists.
    """
    if not file_name.endswith(".wav"):
        file_name = f"{file_name}.wav"

    # Defence in depth: the name is caller-controlled and joined into a
    # path, so only bare file names are served.
    if os.path.basename(file_name) != file_name:
        return JSONResponse(status_code=400, content={"error": "Nama file tidak valid"})

    file_path = os.path.join(OUTPUT_FOLDER, file_name)
    # isfile rather than exists: a directory must not reach open().
    if not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})

    return StreamingResponse(
        iterfile(file_path),
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename={file_name}"},
    )
|
|
|
|
|
|
# ===============================================
|
|
# FILE MANAGEMENT ENDPOINTS
|
|
# ===============================================
|
|
@app.delete("/rm/{filename}")
async def remove_file(filename: str):
    """Delete one file from ``OUTPUT_FOLDER`` and forget the jobs that made it.

    ``.wav`` is appended when *filename* does not already end with it. Every
    record in ``job_store`` whose ``result`` is the removed path is dropped.

    Returns:
        ``{"status": "ok", "deleted": <filename>}`` on success, a 400 JSON
        response when the name carries a directory component, 404 when the
        file does not exist, or 500 when removal fails.
    """
    if not filename.endswith(".wav"):
        filename = f"{filename}.wav"

    # Caller-controlled name joined into a path: only bare file names may be
    # deleted, so the request cannot reach outside OUTPUT_FOLDER.
    if os.path.basename(filename) != filename:
        return JSONResponse(status_code=400, content={"error": "Nama file tidak valid"})

    path = os.path.join(OUTPUT_FOLDER, filename)
    if not os.path.exists(path):
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})

    try:
        os.remove(path)
    except OSError as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Remove from job_store; ids are collected first so the dict is not
    # modified while it is being iterated.
    orphaned = [jid for jid, job in job_store.items() if job.get("result") == path]
    for jid in orphaned:
        job_store.pop(jid, None)
    return {"status": "ok", "deleted": filename}
|
|
|
|
@app.post("/cleanup")
async def manual_cleanup():
    """Remove every file in ``OUTPUT_FOLDER`` and clear finished jobs.

    All regular files in the folder are deleted, whatever their extension.
    Records in ``job_store`` whose status is ``"done"`` or ``"error"`` are
    dropped; pending and processing jobs are kept.

    Returns:
        ``{"status": "ok", "removed_files": [...], "jobs_cleared": [...]}``
        listing the file names actually removed and the job ids cleared.
    """
    removed = []
    for f in os.listdir(OUTPUT_FOLDER):
        fp = os.path.join(OUTPUT_FOLDER, f)
        if not os.path.isfile(fp):
            continue
        try:
            os.remove(fp)
        except OSError:
            # Best effort: a file that cannot be removed is simply left out
            # of the report. Only filesystem errors are tolerated; the bare
            # `except` used before also hid programming errors.
            continue
        removed.append(f)

    # Clear done/error jobs (ids collected first, then popped).
    cleared = [
        jid
        for jid, job in job_store.items()
        if job.get("status") in ("done", "error")
    ]
    for jid in cleared:
        job_store.pop(jid, None)

    return {"status": "ok", "removed_files": removed, "jobs_cleared": cleared}
|