Files
Podcast-Generator/main.py
hex2077 c2930e4340 feat(api): 新增FastAPI服务接口及完善TTS配置管理
实现FastAPI服务接口,支持播客生成任务提交、状态查询和音频下载功能
重构TTS配置管理,统一处理不同TTS服务商的API URL配置
更新README文档,添加API使用说明和项目徽章
添加定时清理输出目录功能,优化资源管理
2025-08-11 22:09:18 +08:00

287 lines
11 KiB
Python

from fastapi import FastAPI, Request, HTTPException, Depends, Form, Header
from fastapi.responses import FileResponse, JSONResponse
from typing import Optional, Dict
import uuid
import asyncio
from starlette.background import BackgroundTasks
from uuid import UUID
import hashlib
import hmac
import time
import os
import json
import argparse
from enum import Enum
import shutil
import schedule
import threading
from podcast_generator import generate_podcast_audio
class TaskStatus(str, Enum):
    """Lifecycle states recorded for a podcast generation task.

    Members mix in ``str``, so each one compares equal to its plain string
    value.
    """

    # Task has been registered but its background job has not started yet.
    PENDING = "pending"
    # Background job is currently generating the podcast.
    RUNNING = "running"
    # Generation finished; the task's result holds the output file path.
    COMPLETED = "completed"
    # Generation raised; the task's result holds the error text.
    FAILED = "failed"
# FastAPI application object; the route and lifecycle handlers in this file
# are registered on it through its decorators.
app = FastAPI()
# Global flag to signal the scheduler thread to stop
stop_scheduler_event = threading.Event()
# Global reference for the scheduler thread (None until the startup handler
# creates it)
scheduler_thread: Optional[threading.Thread] = None
# Global configuration: directory that is periodically cleaned and from which
# generated audio files are served for download
output_dir = "output"
# Define a function to clean the output directory
def clean_output_directory(directory: Optional[str] = None, max_age_seconds: float = 30 * 60) -> None:
    """Delete stale files from an output directory.

    Regular files and symlinks located directly inside the directory are
    unlinked when their modification time is more than ``max_age_seconds`` in
    the past. Subdirectories are left untouched. A failure on one entry is
    reported and the sweep continues with the next entry.

    Args:
        directory: Directory to clean. When omitted, the module-level
            ``output_dir`` is used.
        max_age_seconds: Age threshold in seconds (default: 30 minutes).
    """
    target_dir = output_dir if directory is None else directory
    print(f"Cleaning output directory: {target_dir}")

    try:
        filenames = os.listdir(target_dir)
    except OSError as e:
        # A missing or unreadable directory is reported instead of raised, so
        # that a periodic caller is not aborted by a transient condition.
        print(f"Failed to list {target_dir}. Reason: {e}")
        return

    now = time.time()
    for filename in filenames:
        file_path = os.path.join(target_dir, filename)
        try:
            # Directories (and anything else that is not a file or link) are
            # intentionally skipped.
            if not (os.path.isfile(file_path) or os.path.islink(file_path)):
                continue
            if now - os.path.getmtime(file_path) > max_age_seconds:
                os.unlink(file_path)
                print(f"Deleted old file: {file_path}")
        except OSError as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
# In-memory store for task results, keyed first by auth_id and then by task_id:
# {auth_id: {task_id: {"status": TaskStatus, "result": any, "timestamp": float}}}
# "result" is None while pending, the output file path once completed, or the
# error text when the task failed.
task_results: Dict[str, Dict[UUID, Dict]] = {}
# Configuration for signature verification. The key is read from the
# PODCAST_API_SECRET_KEY environment variable; the fallback below is a
# placeholder only.
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # Change this in production!
# Define a mapping from tts_provider names to their config file paths
# (paths are relative, so they resolve against the process working directory)
tts_provider_map = {
    "index-tts": "config/index-tts.json",
    "doubao-tts": "config/doubao-tts.json",
    "edge-tts": "config/edge-tts.json",
    "fish-audio": "config/fish-audio.json",
    "gemini-tts": "config/gemini-tts.json",
    "minimax": "config/minimax.json",
}
async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")):
    """
    FastAPI dependency that returns the value of the X-Auth-Id request header.

    Raises:
        HTTPException: 400 when the header value is empty.
    """
    if x_auth_id:
        return x_auth_id
    raise HTTPException(status_code=400, detail="Missing X-Auth-Id header.")
async def verify_signature(request: Request):
    """
    Verify the request signature supplied in headers or query parameters.

    The timestamp is read from the ``X-Timestamp`` header (falling back to the
    ``timestamp`` query parameter) and the signature from the ``X-Sign``
    header (falling back to the ``sign`` query parameter).

    Expected signature format: hex-encoded HMAC-SHA256, keyed with SECRET_KEY,
    of the string ``timestamp + SECRET_KEY``.

    Raises:
        HTTPException: 400 when either value is missing, when the timestamp is
            not an integer, or when it differs from the server clock by more
            than 300 seconds; 401 when the signature does not match.
    """
    timestamp = request.headers.get("X-Timestamp") or request.query_params.get("timestamp")
    client_sign = request.headers.get("X-Sign") or request.query_params.get("sign")
    if not timestamp or not client_sign:
        raise HTTPException(status_code=400, detail="Missing X-Timestamp or X-Sign header/query parameter.")

    # Only the integer conversion is guarded: wrapping the whole check in a
    # broad ``except Exception`` would also catch the HTTPExceptions raised
    # below and turn the intended 400/401 responses into 500s.
    try:
        request_time = int(timestamp)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid timestamp format.") from None

    if abs(int(time.time()) - request_time) > 300:
        raise HTTPException(status_code=400, detail="Request expired or timestamp is too far in the future.")

    # The message is built from the timestamp exactly as the client sent it.
    message = f"{timestamp}{SECRET_KEY}".encode('utf-8')
    server_sign = hmac.new(SECRET_KEY.encode('utf-8'), message, hashlib.sha256).hexdigest()

    # Constant-time comparison; both sides are encoded to bytes because
    # compare_digest rejects str arguments containing non-ASCII characters.
    if not hmac.compare_digest(server_sign.encode('utf-8'), client_sign.encode('utf-8')):
        raise HTTPException(status_code=401, detail="Invalid signature.")
async def _generate_podcast_task(
    task_id: UUID,
    auth_id: str,
    api_key: str,
    base_url: str,
    model: str,
    input_txt_content: str,
    tts_providers_config_content: str,
    podUsers_json_content: str,
    threads: int,
    tts_provider: str
):
    """
    Background job that generates one podcast and records the outcome.

    The task entry in ``task_results`` is marked RUNNING, then
    ``generate_podcast_audio`` is executed in a worker thread. On success the
    entry becomes COMPLETED with the returned output path as its result; on
    any exception it becomes FAILED with the exception text as its result.
    """
    record = task_results[auth_id][task_id]
    record["status"] = TaskStatus.RUNNING
    try:
        # The parser is never fed a command line; it only packages the
        # request values into the ``args`` namespace passed to the generator.
        parser = argparse.ArgumentParser(description="Generate podcast script and audio using OpenAI and local TTS.")
        parser.add_argument("--api-key", default=api_key, help="OpenAI API key.")
        parser.add_argument("--base-url", default=base_url, help="OpenAI API base URL (default: https://api.openai.com/v1).")
        parser.add_argument("--model", default=model, help="OpenAI model to use (default: gpt-3.5-turbo).")
        parser.add_argument("--threads", type=int, default=threads, help="Number of threads to use for audio generation (default: 1).")
        args = parser.parse_args([])

        actual_config_path = tts_provider_map.get(tts_provider)
        if not actual_config_path:
            # A plain ValueError is used because no HTTP response is in flight
            # here; the handler below records it as the task's error.
            raise ValueError(f"Invalid tts_provider: {tts_provider}.")

        generation_kwargs = {
            "args": args,
            "config_path": actual_config_path,
            "input_txt_content": input_txt_content.strip(),
            "tts_providers_config_content": tts_providers_config_content.strip(),
            "podUsers_json_content": podUsers_json_content.strip(),
        }
        # Run the blocking generator off the event loop.
        output_filepath = await asyncio.to_thread(generate_podcast_audio, **generation_kwargs)
    except Exception as e:
        record["status"] = TaskStatus.FAILED
        record["result"] = str(e)
        print(f"\nPodcast generation failed for task {task_id}: {e}")
    else:
        record["status"] = TaskStatus.COMPLETED
        record["result"] = output_filepath
        print(f"\nPodcast generation completed for task {task_id}. Output file: {output_filepath}")
# @app.post("/generate-podcast", dependencies=[Depends(verify_signature)])
@app.post("/generate-podcast")
async def generate_podcast_submission(
    background_tasks: BackgroundTasks,
    auth_id: str = Depends(get_auth_id),
    api_key: str = Form("OpenAI API key."),
    base_url: str = Form("https://api.openai.com/v1"),
    model: str = Form("gpt-3.5-turbo"),
    input_txt_content: str = Form(...),
    tts_providers_config_content: str = Form(...),
    podUsers_json_content: str = Form(...),
    threads: int = Form(1),
    tts_provider: str = Form("index-tts")
):
    """
    Register a podcast generation task and schedule it in the background.

    Responds with 400 for an unknown ``tts_provider`` and with 409 when the
    caller (identified by ``auth_id``) already has a pending or running task.
    Otherwise a new task entry is stored as PENDING and its id is returned.

    NOTE(review): the default for ``api_key`` is the literal text
    "OpenAI API key.", which looks like a help string rather than a usable
    default - confirm whether that is intended.
    """
    # 1. Reject providers that have no config mapping.
    if tts_provider not in tts_provider_map:
        raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")

    # 2. Allow at most one unfinished task per auth_id.
    unfinished_states = (TaskStatus.RUNNING, TaskStatus.PENDING)
    for known_task_id, known_task in task_results.get(auth_id, {}).items():
        if known_task["status"] in unfinished_states:
            raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {known_task_id}) for this auth_id. Please wait for it to complete.")

    # 3. Record the new task before handing it to the background runner.
    task_id = uuid.uuid4()
    new_record = {
        "status": TaskStatus.PENDING,
        "result": None,
        "timestamp": time.time()
    }
    task_results.setdefault(auth_id, {})[task_id] = new_record

    job_arguments = (
        task_id,
        auth_id,
        api_key,
        base_url,
        model,
        input_txt_content,
        tts_providers_config_content,
        podUsers_json_content,
        threads,
        tts_provider,
    )
    background_tasks.add_task(_generate_podcast_task, *job_arguments)
    return {"message": "Podcast generation started.", "task_id": task_id}
# @app.get("/podcast-status", dependencies=[Depends(verify_signature)])
@app.get("/podcast-status")
async def get_podcast_status(
    auth_id: str = Depends(get_auth_id)
):
    """
    List every recorded task for the calling auth_id.

    For each task, ``result`` is populated only when the task completed and
    ``error`` only when it failed; both come from the stored result value.
    """
    caller_tasks = task_results.get(auth_id)
    if caller_tasks is None:
        return {"message": "No tasks found for this auth_id.", "tasks": []}

    def describe(task_id, info):
        # One stored value serves as either the result or the error text,
        # depending on how the task ended.
        status = info["status"]
        stored = info["result"]
        return {
            "task_id": task_id,
            "status": status,
            "result": stored if status == TaskStatus.COMPLETED else None,
            "error": stored if status == TaskStatus.FAILED else None,
            "timestamp": info["timestamp"]
        }

    summaries = [describe(task_id, info) for task_id, info in caller_tasks.items()]
    return {"message": "Tasks retrieved successfully.", "tasks": summaries}
@app.get("/download-podcast/")
async def download_podcast(file_name: str):
    """
    Serve a generated audio file from the output directory.

    ``file_name`` comes straight from the query string, so it is resolved and
    checked to lie inside ``output_dir`` before anything is read. This blocks
    path traversal through ``..`` segments, absolute paths, or symlinks that
    point outside the directory.

    Raises:
        HTTPException: 404 when the name resolves outside the output
            directory, cannot be resolved, or is not an existing regular file.
    """
    try:
        base_dir = os.path.realpath(output_dir)
        file_path = os.path.realpath(os.path.join(base_dir, file_name))
        inside_output_dir = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # Raised for names the OS cannot represent (e.g. embedded NUL bytes)
        # or, on Windows, for paths on a different drive than the output dir.
        raise HTTPException(status_code=404, detail="File not found.") from None

    # isfile (rather than exists) also rejects directories, including the
    # output directory itself when file_name is empty.
    if not inside_output_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found.")

    return FileResponse(file_path, media_type='audio/mpeg', filename=os.path.basename(file_name))
@app.get("/get-voices")
async def get_voices(tts_provider: str = "tts"):
    """
    Return the ``voices`` entry from the config file of a TTS provider.

    NOTE(review): the default value "tts" is not a key of
    ``tts_provider_map``, so a request that omits ``tts_provider`` receives
    the 400 response - confirm whether a real provider should be the default.

    Raises:
        HTTPException: 400 for an unknown provider; 404 when the config file
            is missing or has no 'voices' key; 500 when the file cannot be
            parsed or read.
    """
    config_path = tts_provider_map.get(tts_provider)
    if not config_path:
        raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")

    # Only file access and parsing are guarded. The 'voices' check happens
    # after the try block so that its 404 is not caught by the generic handler
    # and rewritten into a 500.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=f"Config file not found for {tts_provider}: {config_path}") from e
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=500, detail=f"Error decoding JSON from config file for {tts_provider}: {config_path}. Please check file format.") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    # A config whose top level is not an object cannot contain a 'voices' key.
    voices = config_data.get("voices") if isinstance(config_data, dict) else None
    if voices is None:
        raise HTTPException(status_code=404, detail=f"No 'voices' key found in config for {tts_provider}.")
    return {"tts_provider": tts_provider, "voices": voices}
@app.get("/")
async def read_root():
    """Return a fixed message confirming that the server is up."""
    return dict(message="FastAPI server is running!")
def run_scheduler():
    """Run pending scheduled jobs in a loop until the stop event is set.

    An exception escaping a scheduled job is reported and the loop keeps
    going, so one failing run does not permanently end the thread executing
    this function.
    """
    while not stop_scheduler_event.is_set():
        try:
            schedule.run_pending()
        except Exception as e:
            # Top-level boundary of the scheduler loop: report and carry on.
            print(f"Scheduled job failed: {e}")
        # Wait on the event instead of sleeping: still checks roughly once a
        # second, but returns immediately once a stop has been requested.
        stop_scheduler_event.wait(timeout=1)
@app.on_event("startup")
async def startup_event():
    """
    Prepare the output directory and start periodic cleaning on app startup.

    Creates ``output_dir`` if needed, registers ``clean_output_directory`` to
    run every 30 minutes, and launches the scheduler loop in a daemon thread
    unless one is already alive.
    """
    global scheduler_thread

    # The directory must exist before the first cleaning pass lists it.
    os.makedirs(output_dir, exist_ok=True)

    cleaning_job = schedule.every(30).minutes
    cleaning_job.do(clean_output_directory)

    thread_is_running = scheduler_thread is not None and scheduler_thread.is_alive()
    if not thread_is_running:
        scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
        scheduler_thread.start()

    print("FastAPI app started. Scheduled output directory cleaning.")
@app.on_event("shutdown")
async def shutdown_event():
    """
    Ask the scheduler thread to stop and give it up to 5 seconds to exit.

    A message is printed if the thread is still alive after the wait.
    """
    # Raise the stop flag first so the scheduler loop can observe it.
    stop_scheduler_event.set()

    worker = scheduler_thread
    needs_join = worker is not None and worker.is_alive()
    if needs_join:
        # Bounded wait so shutdown cannot hang on a stuck job.
        worker.join(timeout=5)
        if worker.is_alive():
            print("Scheduler thread did not terminate gracefully.")

    print("FastAPI app shutting down.")
# Script entry point: when the file is executed directly, serve the app with
# uvicorn on all network interfaces (0.0.0.0), port 8000.
if __name__ == "__main__":
    # Imported here, so uvicorn is only needed when running the file directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)