feat(api): 新增FastAPI服务接口及完善TTS配置管理
实现FastAPI服务接口,支持播客生成任务提交、状态查询和音频下载功能 重构TTS配置管理,统一处理不同TTS服务商的API URL配置 更新README文档,添加API使用说明和项目徽章 添加定时清理输出目录功能,优化资源管理
This commit is contained in:
287
main.py
Normal file
287
main.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends, Form, Header
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from typing import Optional, Dict
|
||||
import uuid
|
||||
import asyncio
|
||||
from starlette.background import BackgroundTasks
|
||||
from uuid import UUID
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from enum import Enum
|
||||
import shutil
|
||||
import schedule
|
||||
import threading
|
||||
|
||||
from podcast_generator import generate_podcast_audio
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Global flag to signal the scheduler thread to stop
|
||||
stop_scheduler_event = threading.Event()
|
||||
|
||||
# Global reference for the scheduler thread
|
||||
scheduler_thread = None
|
||||
|
||||
# Global configuration
|
||||
output_dir = "output"
|
||||
|
||||
# Define a function to clean the output directory
|
||||
def clean_output_directory():
|
||||
"""Removes files from the output directory that are older than 30 minutes."""
|
||||
print(f"Cleaning output directory: {output_dir}")
|
||||
now = time.time()
|
||||
# 30 minutes in seconds
|
||||
threshold = 30 * 60
|
||||
|
||||
for filename in os.listdir(output_dir):
|
||||
file_path = os.path.join(output_dir, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
# Get last modification time
|
||||
if now - os.path.getmtime(file_path) > threshold:
|
||||
os.unlink(file_path)
|
||||
print(f"Deleted old file: {file_path}")
|
||||
elif os.path.isdir(file_path):
|
||||
# Optionally, recursively delete old subdirectories or files within them
|
||||
# For now, just skip directories
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
|
||||
# In-memory store for task results
|
||||
# {task_id: {"auth_id": auth_id, "status": TaskStatus, "result": any, "timestamp": float}}
|
||||
task_results: Dict[str, Dict[UUID, Dict]] = {}
|
||||
|
||||
# Configuration for signature verification
|
||||
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # Change this in production!
|
||||
# Define a mapping from tts_provider names to their config file paths
|
||||
tts_provider_map = {
|
||||
"index-tts": "config/index-tts.json",
|
||||
"doubao-tts": "config/doubao-tts.json",
|
||||
"edge-tts": "config/edge-tts.json",
|
||||
"fish-audio": "config/fish-audio.json",
|
||||
"gemini-tts": "config/gemini-tts.json",
|
||||
"minimax": "config/minimax.json",
|
||||
}
|
||||
|
||||
async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")):
|
||||
"""
|
||||
Dependency to get X-Auth-Id from headers.
|
||||
"""
|
||||
if not x_auth_id:
|
||||
raise HTTPException(status_code=400, detail="Missing X-Auth-Id header.")
|
||||
return x_auth_id
|
||||
|
||||
async def verify_signature(request: Request):
|
||||
"""
|
||||
Verify the 'sign' parameter in the request headers or query parameters.
|
||||
Expected signature format: SHA256(timestamp + SECRET_KEY)
|
||||
"""
|
||||
timestamp = request.headers.get("X-Timestamp") or request.query_params.get("timestamp")
|
||||
client_sign = request.headers.get("X-Sign") or request.query_params.get("sign")
|
||||
|
||||
if not timestamp or not client_sign:
|
||||
raise HTTPException(status_code=400, detail="Missing X-Timestamp or X-Sign header/query parameter.")
|
||||
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
if abs(current_time - int(timestamp)) > 300:
|
||||
raise HTTPException(status_code=400, detail="Request expired or timestamp is too far in the future.")
|
||||
|
||||
message = f"{timestamp}{SECRET_KEY}".encode('utf-8')
|
||||
server_sign = hmac.new(SECRET_KEY.encode('utf-8'), message, hashlib.sha256).hexdigest()
|
||||
|
||||
if server_sign != client_sign:
|
||||
raise HTTPException(status_code=401, detail="Invalid signature.")
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid timestamp format.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Signature verification error: {e}")
|
||||
|
||||
async def _generate_podcast_task(
|
||||
task_id: UUID,
|
||||
auth_id: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
input_txt_content: str,
|
||||
tts_providers_config_content: str,
|
||||
podUsers_json_content: str,
|
||||
threads: int,
|
||||
tts_provider: str
|
||||
):
|
||||
task_results[auth_id][task_id]["status"] = TaskStatus.RUNNING
|
||||
try:
|
||||
parser = argparse.ArgumentParser(description="Generate podcast script and audio using OpenAI and local TTS.")
|
||||
parser.add_argument("--api-key", default=api_key, help="OpenAI API key.")
|
||||
parser.add_argument("--base-url", default=base_url, help="OpenAI API base URL (default: https://api.openai.com/v1).")
|
||||
parser.add_argument("--model", default=model, help="OpenAI model to use (default: gpt-3.5-turbo).")
|
||||
parser.add_argument("--threads", type=int, default=threads, help="Number of threads to use for audio generation (default: 1).")
|
||||
args = parser.parse_args([])
|
||||
|
||||
actual_config_path = tts_provider_map.get(tts_provider)
|
||||
if not actual_config_path:
|
||||
raise ValueError(f"Invalid tts_provider: {tts_provider}.") # Changed from HTTPException to ValueError
|
||||
|
||||
output_filepath = await asyncio.to_thread(
|
||||
generate_podcast_audio,
|
||||
args=args,
|
||||
config_path=actual_config_path,
|
||||
input_txt_content=input_txt_content.strip(),
|
||||
tts_providers_config_content=tts_providers_config_content.strip(),
|
||||
podUsers_json_content=podUsers_json_content.strip()
|
||||
)
|
||||
task_results[auth_id][task_id]["status"] = TaskStatus.COMPLETED
|
||||
task_results[auth_id][task_id]["result"] = output_filepath
|
||||
print(f"\nPodcast generation completed for task {task_id}. Output file: {output_filepath}")
|
||||
except Exception as e:
|
||||
task_results[auth_id][task_id]["status"] = TaskStatus.FAILED
|
||||
task_results[auth_id][task_id]["result"] = str(e)
|
||||
print(f"\nPodcast generation failed for task {task_id}: {e}")
|
||||
|
||||
# @app.post("/generate-podcast", dependencies=[Depends(verify_signature)])
|
||||
@app.post("/generate-podcast")
|
||||
async def generate_podcast_submission(
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_id: str = Depends(get_auth_id),
|
||||
api_key: str = Form("OpenAI API key."),
|
||||
base_url: str = Form("https://api.openai.com/v1"),
|
||||
model: str = Form("gpt-3.5-turbo"),
|
||||
input_txt_content: str = Form(...),
|
||||
tts_providers_config_content: str = Form(...),
|
||||
podUsers_json_content: str = Form(...),
|
||||
threads: int = Form(1),
|
||||
tts_provider: str = Form("index-tts")
|
||||
):
|
||||
# 1. Validate tts_provider
|
||||
if tts_provider not in tts_provider_map:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")
|
||||
|
||||
# 2. Check for existing running tasks for this auth_id
|
||||
if auth_id in task_results:
|
||||
for existing_task_id, existing_task_info in task_results[auth_id].items():
|
||||
if existing_task_info["status"] == TaskStatus.RUNNING or existing_task_info["status"] == TaskStatus.PENDING:
|
||||
raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {existing_task_id}) for this auth_id. Please wait for it to complete.")
|
||||
|
||||
task_id = uuid.uuid4()
|
||||
if auth_id not in task_results:
|
||||
task_results[auth_id] = {}
|
||||
task_results[auth_id][task_id] = {
|
||||
"status": TaskStatus.PENDING,
|
||||
"result": None,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
_generate_podcast_task,
|
||||
task_id,
|
||||
auth_id,
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
input_txt_content,
|
||||
tts_providers_config_content,
|
||||
podUsers_json_content,
|
||||
threads,
|
||||
tts_provider
|
||||
)
|
||||
|
||||
return {"message": "Podcast generation started.", "task_id": task_id}
|
||||
|
||||
# @app.get("/podcast-status", dependencies=[Depends(verify_signature)])
|
||||
@app.get("/podcast-status")
|
||||
async def get_podcast_status(
|
||||
auth_id: str = Depends(get_auth_id)
|
||||
):
|
||||
if auth_id not in task_results:
|
||||
return {"message": "No tasks found for this auth_id.", "tasks": []}
|
||||
|
||||
all_tasks_for_auth_id = []
|
||||
for task_id, task_info in task_results[auth_id].items():
|
||||
all_tasks_for_auth_id.append({
|
||||
"task_id": task_id,
|
||||
"status": task_info["status"],
|
||||
"result": task_info["result"] if task_info["status"] == TaskStatus.COMPLETED else None,
|
||||
"error": task_info["result"] if task_info["status"] == TaskStatus.FAILED else None,
|
||||
"timestamp": task_info["timestamp"]
|
||||
})
|
||||
return {"message": "Tasks retrieved successfully.", "tasks": all_tasks_for_auth_id}
|
||||
|
||||
@app.get("/download-podcast/")
|
||||
async def download_podcast(file_name: str):
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail="File not found.")
|
||||
return FileResponse(file_path, media_type='audio/mpeg', filename=file_name)
|
||||
|
||||
@app.get("/get-voices")
|
||||
async def get_voices(tts_provider: str = "tts"):
|
||||
|
||||
config_path = tts_provider_map.get(tts_provider)
|
||||
if not config_path:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
voices = config_data.get("voices")
|
||||
if voices is None:
|
||||
raise HTTPException(status_code=404, detail=f"No 'voices' key found in config for {tts_provider}.")
|
||||
|
||||
return {"tts_provider": tts_provider, "voices": voices}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Config file not found for {tts_provider}: {config_path}")
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail=f"Error decoding JSON from config file for {tts_provider}: {config_path}. Please check file format.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
|
||||
|
||||
@app.get("/")
|
||||
async def read_root():
|
||||
return {"message": "FastAPI server is running!"}
|
||||
|
||||
def run_scheduler():
|
||||
"""Runs the scheduler in a loop until the stop event is set."""
|
||||
while not stop_scheduler_event.is_set():
|
||||
schedule.run_pending()
|
||||
time.sleep(1) # Check every second for new jobs or stop event
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global scheduler_thread
|
||||
# Ensure the output directory exists
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# Schedule the cleaning task to run every 30 minutes
|
||||
schedule.every(30).minutes.do(clean_output_directory)
|
||||
# Start the scheduler in a separate thread
|
||||
if scheduler_thread is None or not scheduler_thread.is_alive():
|
||||
scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
|
||||
scheduler_thread.start()
|
||||
print("FastAPI app started. Scheduled output directory cleaning.")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
global scheduler_thread
|
||||
# Signal the scheduler thread to stop
|
||||
stop_scheduler_event.set()
|
||||
# Wait for the scheduler thread to finish (optional, but good practice)
|
||||
if scheduler_thread is not None and scheduler_thread.is_alive():
|
||||
scheduler_thread.join(timeout=5) # Wait for max 5 seconds
|
||||
if scheduler_thread.is_alive():
|
||||
print("Scheduler thread did not terminate gracefully.")
|
||||
print("FastAPI app shutting down.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user