Files
jenna-tools/py/main.py
bipproduction 822b68c10f tambahannya
2025-12-07 09:00:54 +08:00

457 lines
13 KiB
Python

import os
import base64
import uuid
import asyncio
import time
import threading
from typing import Dict
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from tts_util import TTSConfig, TTSEngine
# ===============================================
# CONFIG
# ===============================================
# Folder holding the registered voice prompts, stored as "<name>.wav".
PROMPT_FOLDER = "prompt_source"
os.makedirs(PROMPT_FOLDER, exist_ok=True)
# Created at import time; not referenced anywhere else in the visible source.
JOBS_FOLDER = "jobs"
os.makedirs(JOBS_FOLDER, exist_ok=True)
# Folder where generated audio is written as "<job_id>.wav".
OUTPUT_FOLDER = "output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# In-memory job store: job id (UUID string) -> job record. A record starts with
# "status", "timestamp", "prompt", "text" and "client_ip"; the worker later adds
# "worker" and either "result" or "error".
job_store: Dict[str, dict] = {}
# Seconds a finished job stays in job_store before cleanup_worker drops it.
JOB_EXPIRE_SECONDS = 600
# Number of worker_loop tasks started at startup: TURBO_WORKERS when TURBO_MODE
# is true, otherwise a single worker.
TURBO_MODE = True
TURBO_WORKERS = 3
# Size of the thread pool handed to the TTS engine.
WORKER_THREADPOOL_MAX = 4
# Rate limiting: each client IP gets a bucket of RATE_LIMIT_TOKENS tokens that
# refills at RATE_LIMIT_TOKENS per RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_TOKENS = 10
RATE_LIMIT_WINDOW = 60
# Per-IP token buckets: ip -> {"tokens": float, "last": unix timestamp}.
# token_lock guards every read and write of token_buckets.
token_buckets: Dict[str, dict] = {}
token_lock = threading.Lock()
# Shared executor and the TTS engine built on top of it.
thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREADPOOL_MAX)
config = TTSConfig()
tts_engine = TTSEngine(config, thread_pool)
app = FastAPI(title="Chatterbox TTS Server - Turbo + FIFO + RateLimit + WAV")
# ===============================================
# RATE LIMIT UTILITIES
# ===============================================
def get_client_ip(request: Request) -> str:
    """Return the best-guess client IP for *request*.

    The first entry of the ``X-Forwarded-For`` header is used when it is
    non-empty. Otherwise the socket peer address (``request.client.host``)
    is returned, and ``"unknown"`` when no peer information is available.

    NOTE(review): ``X-Forwarded-For`` is supplied by the client unless a
    trusted reverse proxy overwrites it, so rate limiting keyed on this
    value can be bypassed by spoofing the header - confirm the deployment
    sits behind such a proxy.
    """
    xff = request.headers.get("x-forwarded-for")
    if xff:
        first = xff.split(",")[0].strip()
        # A blank first entry (e.g. " , 1.2.3.4") must not become the "" key
        # shared by every such client; fall through to the socket peer.
        if first:
            return first
    if request.client:
        return request.client.host
    return "unknown"
def allow_request_ip(ip: str) -> bool:
    """Token-bucket gate: try to spend one token for *ip*.

    Returns True when a token was available (and consumes it), False when
    the bucket for this address is empty.
    """
    current = time.time()
    with token_lock:
        state = token_buckets.get(ip)
        if state is None:
            # First request from this address: full bucket minus this request.
            token_buckets[ip] = {"tokens": RATE_LIMIT_TOKENS - 1, "last": current}
            return True

        # Top the bucket up in proportion to the time since the last refill,
        # never exceeding its capacity.
        earned = ((current - state["last"]) / RATE_LIMIT_WINDOW) * RATE_LIMIT_TOKENS
        if earned > 0:
            state["tokens"] = min(RATE_LIMIT_TOKENS, state["tokens"] + earned)
            state["last"] = current

        if state["tokens"] < 1:
            return False
        state["tokens"] -= 1
        return True
# ===============================================
# BACKGROUND WORKERS
# ===============================================
# FIFO queue of job ids: filled by the /tts-async endpoint, drained by worker_loop.
# NOTE(review): the queue is created at import time; on Python < 3.10 an
# asyncio.Queue binds to the event loop current at construction - confirm the
# server runs on that same loop.
job_queue: asyncio.Queue = asyncio.Queue()
async def worker_loop(worker_id: int):
    """Consume job ids from ``job_queue`` forever and run the TTS for each.

    For every job the record in ``job_store`` is updated in place:
    ``status`` moves to ``"processing"`` and then to ``"done"`` (with
    ``result`` set to the output WAV path) or ``"error"`` (with ``error``
    set to a message). ``timestamp`` is refreshed on every transition.

    Args:
        worker_id: Number stored in the job record under ``"worker"``.
    """
    while True:
        job_id = await job_queue.get()
        try:
            job = job_store.get(job_id)
            if job is None:
                # The record is no longer in job_store; nothing to do.
                continue

            # Mark as processing
            job["status"] = "processing"
            job["worker"] = worker_id
            job["timestamp"] = time.time()

            try:
                prompt = job.get("prompt")
                text = job.get("text")
                prompt_path = os.path.join(PROMPT_FOLDER, f"{prompt}.wav")
                if not os.path.exists(prompt_path):
                    job["status"] = "error"
                    job["error"] = "Prompt tidak ditemukan"
                else:
                    # Generate audio (presumably written to out_wav by the
                    # engine - its implementation is not visible here).
                    out_wav = os.path.join(OUTPUT_FOLDER, f"{job_id}.wav")
                    await tts_engine.generate_to_file(text, prompt_path, out_wav)
                    job["status"] = "done"
                    job["result"] = out_wav
            except Exception as e:
                job["status"] = "error"
                job["error"] = str(e)
            job["timestamp"] = time.time()
        finally:
            # Exactly one task_done() per get(): a second call raises
            # ValueError once the unfinished-task counter reaches zero, which
            # would terminate this worker.
            job_queue.task_done()
async def cleanup_worker():
    """Periodically drop finished jobs and idle rate-limit buckets.

    Every 30 seconds:
      * jobs whose status is ``"done"`` or ``"error"`` and whose
        ``timestamp`` is older than ``JOB_EXPIRE_SECONDS`` are removed from
        ``job_store`` (pending/processing jobs are never touched);
      * token buckets not used for ``RATE_LIMIT_WINDOW * 10`` seconds are
        removed from ``token_buckets``.

    Only the in-memory job record is dropped; the generated WAV file is left
    on disk and is removed through the file-management endpoints instead.
    """
    while True:
        now = time.time()

        # Snapshot the items so the store can be mutated afterwards.
        expired = [
            jid
            for jid, job in list(job_store.items())
            if job.get("status") in ("done", "error")
            and now - job.get("timestamp", 0) > JOB_EXPIRE_SECONDS
        ]
        for jid in expired:
            job_store.pop(jid, None)

        # Purge stale token buckets
        with token_lock:
            stale_ips = [
                ip
                for ip, b in token_buckets.items()
                if now - b.get("last", 0) > RATE_LIMIT_WINDOW * 10
            ]
            for ip in stale_ips:
                token_buckets.pop(ip, None)

        await asyncio.sleep(30)
# ===============================================
# STARTUP
# ===============================================
# Strong references to the long-running background tasks. The event loop only
# keeps weak references to tasks, so a task nobody else references may be
# garbage-collected while it is still running.
_background_tasks: set = set()


@app.on_event("startup")
async def startup():
    """Load the TTS model, then start the cleanup task and the job workers.

    One ``cleanup_worker`` task is started, followed by ``TURBO_WORKERS``
    ``worker_loop`` tasks when ``TURBO_MODE`` is true (a single worker
    otherwise). Worker ids start at 1.
    """
    # Load TTS model (called synchronously, before any task is scheduled).
    tts_engine.load_model()

    worker_count = TURBO_WORKERS if TURBO_MODE else 1
    coroutines = [cleanup_worker()]
    coroutines.extend(worker_loop(i + 1) for i in range(worker_count))

    for coro in coroutines:
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        # Drop the reference once the task finishes or is cancelled.
        task.add_done_callback(_background_tasks.discard)
# ===============================================
# PYDANTIC MODELS
# ===============================================
# Request body for POST /register-prompt-base64.
class RegisterPromptBase64(BaseModel):
    # Prompt name; the audio is stored as "<prompt_name>.wav" in PROMPT_FOLDER.
    prompt_name: str
    # Base64-encoded audio; decoded and written to disk as-is.
    base64_audio: str
# Request body for POST /delete-prompt.
class DeletePrompt(BaseModel):
    # Name of the prompt to delete, without the ".wav" extension.
    prompt_name: str
# Request body for POST /rename-prompt.
class RenamePrompt(BaseModel):
    # Current prompt name, without the ".wav" extension.
    old_name: str
    # New prompt name, without the ".wav" extension; must not already exist.
    new_name: str
# ===============================================
# PROMPT MANAGEMENT ENDPOINTS
# ===============================================
@app.post("/register-prompt-base64")
async def register_prompt_base64(data: RegisterPromptBase64):
    """Register a voice prompt from base64-encoded audio.

    The decoded bytes are written to ``PROMPT_FOLDER/<prompt_name>.wav``,
    overwriting any existing prompt with that name.

    Returns:
        ``{"status": "ok", "file": "<prompt_name>.wav"}`` on success, or a
        400 JSON error when the name is not a bare file name or when
        decoding/writing fails.
    """
    name = data.prompt_name
    # prompt_name comes straight from the request body. Anything that is not
    # a bare file name (e.g. "../x" or an absolute path) would let the write
    # escape PROMPT_FOLDER, so refuse it.
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    filename = f"{name}.wav"
    path = os.path.join(PROMPT_FOLDER, filename)
    try:
        raw = base64.b64decode(data.base64_audio)
        with open(path, "wb") as f:
            f.write(raw)
        return {"status": "ok", "file": filename}
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.post("/register-prompt-file")
async def register_prompt_file(prompt: UploadFile = File(...), name: str = Form(None)):
    """Register a voice prompt from an uploaded file.

    The prompt name is the ``name`` form field when given, otherwise the
    upload's file name without its extension. The upload is stored as
    ``PROMPT_FOLDER/<prompt_name>.wav``, overwriting any existing prompt.

    Returns:
        ``{"status": "ok", "file": "<prompt_name>.wav"}`` on success, a 400
        JSON error when no usable bare name is available, or a 500 JSON
        error when writing fails.
    """
    # The upload may carry no filename at all; treat that as an empty name
    # instead of letting os.path.splitext raise on None.
    prompt_name = name or os.path.splitext(prompt.filename or "")[0]
    # Both sources are client-controlled: only a bare file name is accepted
    # so the write cannot escape PROMPT_FOLDER.
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    save_path = os.path.join(PROMPT_FOLDER, f"{prompt_name}.wav")
    try:
        with open(save_path, "wb") as f:
            f.write(await prompt.read())
        return {"status": "ok", "file": f"{prompt_name}.wav"}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/list-prompt")
async def list_prompt():
    """Return every ``.wav`` file in PROMPT_FOLDER, with and without extension."""
    wav_files = []
    bare_names = []
    for entry in os.listdir(PROMPT_FOLDER):
        # The extension check is case-insensitive.
        if not entry.lower().endswith(".wav"):
            continue
        wav_files.append(entry)
        bare_names.append(os.path.splitext(entry)[0])

    return {
        "count": len(wav_files),
        "prompts": wav_files,
        "prompt_names": bare_names,
    }
@app.post("/delete-prompt")
async def delete_prompt(data: DeletePrompt):
    """Delete the voice prompt ``PROMPT_FOLDER/<prompt_name>.wav``.

    Returns:
        ``{"status": "ok", "deleted": "<prompt_name>.wav"}`` on success, a
        400 JSON error for a name that is not a bare file name, a 404 JSON
        error when the prompt does not exist, or a 500 JSON error when the
        removal fails for another reason.
    """
    name = data.prompt_name
    # Client-supplied name: refuse path components so the delete cannot
    # reach outside PROMPT_FOLDER.
    if not name or os.path.basename(name) != name:
        return JSONResponse(
            status_code=400, content={"error": "Nama prompt tidak valid"}
        )

    path = os.path.join(PROMPT_FOLDER, f"{name}.wav")
    try:
        # Remove directly instead of checking first; the file may disappear
        # between an existence check and the removal.
        os.remove(path)
    except FileNotFoundError:
        return JSONResponse(
            status_code=404, content={"error": "Prompt tidak ditemukan"}
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    return {"status": "ok", "deleted": f"{name}.wav"}
@app.post("/rename-prompt")
async def rename_prompt(data: RenamePrompt):
    """Rename a voice prompt inside PROMPT_FOLDER.

    Returns:
        ``{"status": "ok", "from": old_name, "to": new_name}`` on success, a
        400 JSON error when either name is not a bare file name or the new
        name is already taken, a 404 JSON error when the old prompt does not
        exist, or a 500 JSON error when the rename itself fails.
    """
    # Both names are client-supplied: refuse path components so neither the
    # source nor the destination can point outside PROMPT_FOLDER.
    for candidate in (data.old_name, data.new_name):
        if not candidate or os.path.basename(candidate) != candidate:
            return JSONResponse(
                status_code=400, content={"error": "Nama prompt tidak valid"}
            )

    old = os.path.join(PROMPT_FOLDER, f"{data.old_name}.wav")
    new = os.path.join(PROMPT_FOLDER, f"{data.new_name}.wav")
    if not os.path.exists(old):
        return JSONResponse(
            status_code=404, content={"error": "Prompt lama tidak ditemukan"}
        )
    if os.path.exists(new):
        return JSONResponse(
            status_code=400, content={"error": "Nama baru sudah digunakan"}
        )
    try:
        os.rename(old, new)
        return {"status": "ok", "from": data.old_name, "to": data.new_name}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/tts-async")
async def tts_async(request: Request, text: str = Form(...), prompt: str = Form(...)):
    """Queue a TTS job and hand back its id.

    The caller is rate-limited per client IP (429 when over the limit). On
    success a new job record is stored with status ``"pending"`` and its id
    is appended to the FIFO job queue; the response carries the id and the
    URL to poll for the result.
    """
    caller_ip = get_client_ip(request)
    if not allow_request_ip(caller_ip):
        return JSONResponse(status_code=429, content={"error": "rate limit exceeded"})

    job_id = str(uuid.uuid4())
    record = {
        "status": "pending",
        "timestamp": time.time(),
        "prompt": prompt,
        "text": text,
        "client_ip": caller_ip,
    }
    job_store[job_id] = record

    # Workers pick ids up in insertion order.
    await job_queue.put(job_id)

    poll_url = f"/result/{job_id}"
    return {"status": "queued", "job_id": job_id, "check": poll_url}
@app.get("/result/{job_id}")
async def tts_result(job_id: str):
    """Report the state of an asynchronous TTS job.

    Responses:
      * 404 when the job id is unknown;
      * a progress summary while the job is pending or processing;
      * the full job record when the job failed;
      * 200 with the output file path when the job is done, or 500 when the
        job is done but its output file is missing.
    """
    job = job_store.get(job_id)
    if not job:
        return JSONResponse(
            status_code=404, content={"error": "Job ID tidak ditemukan"}
        )

    state = job["status"]

    # Failed job: the stored record is returned as-is.
    if state == "error":
        return job

    # Not finished yet.
    if state in ("pending", "processing"):
        return {
            "status": state,
            "job_id": job_id,
            "worker": job.get("worker"),
            "timestamp": job.get("timestamp"),
        }

    # Finished: the output file must still be on disk.
    result_path = job.get("result")
    if result_path and os.path.exists(result_path):
        return JSONResponse(
            status_code=200,
            content={"status": "done", "job_id": job_id, "file": result_path},
        )
    return JSONResponse(
        status_code=500,
        content={"status": "error", "error": "File hasil tidak ditemukan"},
    )
@app.get("/list-file")
async def list_file():
    """List the ``.wav`` files in OUTPUT_FOLDER with size and related job.

    Each entry holds the file name, its size in bytes, its path and - when a
    job in ``job_store`` has that path as its ``result`` - a summary of that
    job (``None`` otherwise).

    Returns:
        ``{"count": <number of entries>, "files": [...]}``, or a 500 JSON
        error when the folder cannot be listed.
    """
    try:
        # Index the jobs by result path once instead of rescanning job_store
        # for every file. setdefault keeps the first job per path, matching a
        # first-match linear search.
        job_by_result = {}
        for jid, job in job_store.items():
            result = job.get("result")
            if result is not None:
                job_by_result.setdefault(
                    result,
                    {
                        "job_id": jid,
                        "status": job.get("status"),
                        "timestamp": job.get("timestamp"),
                        "prompt": job.get("prompt"),
                    },
                )

        detailed = []
        for f in os.listdir(OUTPUT_FOLDER):
            full_path = os.path.join(OUTPUT_FOLDER, f)
            # Only regular WAV files (extension check is case-insensitive).
            if not os.path.isfile(full_path) or not f.lower().endswith(".wav"):
                continue
            try:
                size = os.path.getsize(full_path)
            except OSError:
                # The file vanished between listing and stat; skip it rather
                # than failing the whole request.
                continue
            detailed.append(
                {
                    "file": f,
                    "size_bytes": size,
                    "path": full_path,
                    "job": job_by_result.get(full_path),
                }
            )

        return {
            "count": len(detailed),
            "files": detailed,
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
async def iterfile(file_path, chunk_size: int = 4096):
    """Yield the contents of *file_path* as successive ``bytes`` chunks.

    Args:
        file_path: Path of the file to read (opened in binary mode).
        chunk_size: Maximum number of bytes per yielded chunk.

    Yields:
        Non-empty ``bytes`` objects; iteration ends at end of file, and the
        file is closed when the generator finishes or is closed.

    NOTE: the reads are ordinary blocking file reads performed inside an
    async generator.
    """
    with open(file_path, "rb") as f:
        # iter(callable, sentinel) keeps calling read() until it returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            yield chunk
@app.get("/file/{file_name}")
async def get_output_file(file_name: str):
    """Stream a generated WAV file from OUTPUT_FOLDER as an attachment.

    ``.wav`` is appended to *file_name* when it is missing.

    Returns:
        A streaming ``audio/wav`` response, or a 404 JSON error when the
        name is not a bare file name or no such regular file exists.
    """
    if not file_name.endswith(".wav"):
        file_name = f"{file_name}.wav"
    file_path = os.path.join(OUTPUT_FOLDER, file_name)
    # Only bare names are served, so the request cannot reach outside
    # OUTPUT_FOLDER. isfile() also rejects a directory named "*.wav", which
    # would otherwise fail later, inside the response stream.
    if os.path.basename(file_name) != file_name or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})
    return StreamingResponse(
        iterfile(file_path),
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename={file_name}"},
    )
# ===============================================
# FILE MANAGEMENT ENDPOINTS
# ===============================================
@app.delete("/rm/{filename}")
async def remove_file(filename: str):
    """Delete one output file and forget the jobs that produced it.

    ``.wav`` is appended to *filename* when it is missing. Every job in
    ``job_store`` whose ``result`` is the deleted path is removed as well.

    Returns:
        ``{"status": "ok", "deleted": <filename>}`` on success, a 404 JSON
        error when the name is not a bare file name or the file does not
        exist, or a 500 JSON error when the removal fails otherwise.
    """
    if not filename.endswith(".wav"):
        filename = f"{filename}.wav"
    # Only bare names are accepted, so the delete cannot reach outside
    # OUTPUT_FOLDER.
    if os.path.basename(filename) != filename:
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})

    path = os.path.join(OUTPUT_FOLDER, filename)
    try:
        # Remove directly; the file may disappear between an existence check
        # and the removal.
        os.remove(path)
    except FileNotFoundError:
        return JSONResponse(status_code=404, content={"error": "File tidak ditemukan"})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Remove from job_store (iterate over a snapshot while popping).
    for jid, job in list(job_store.items()):
        if job.get("result") == path:
            job_store.pop(jid, None)
    return {"status": "ok", "deleted": filename}
@app.post("/cleanup")
async def manual_cleanup():
    """Remove every file in OUTPUT_FOLDER and forget all done/error jobs.

    File removal is best-effort: a file that cannot be removed is skipped
    and simply not reported in ``removed_files``. Pending and processing
    jobs stay in ``job_store``.

    Returns:
        ``{"status": "ok", "removed_files": [...], "jobs_cleared": [...]}``.
    """
    removed = []
    for f in os.listdir(OUTPUT_FOLDER):
        fp = os.path.join(OUTPUT_FOLDER, f)
        if not os.path.isfile(fp):
            continue
        try:
            os.remove(fp)
        except OSError:
            # Best-effort: leave files that cannot be removed (already gone,
            # permission denied, ...) and carry on with the rest.
            continue
        removed.append(f)

    # Clear done/error jobs (collect first, then pop, to avoid mutating the
    # dict while iterating it).
    cleared = [
        jid
        for jid, job in list(job_store.items())
        if job["status"] in ("done", "error")
    ]
    for jid in cleared:
        job_store.pop(jid, None)

    return {"status": "ok", "removed_files": removed, "jobs_cleared": cleared}