Files
Podcast-Generator/main.py
hex2077 1242adb0e6 feat(api): 添加回调功能并重构API接口
- 在生成播客任务中添加回调URL支持,任务完成后会通知指定URL
- 将`generate_podcast_audio`重命名为`generate_podcast_audio_api`并返回完整结果
- 使用`asynccontextmanager`替代已弃用的`startup`/`shutdown`事件
- 改进错误处理,添加任务取消逻辑
- 更新状态接口返回更多任务详情
2025-08-13 14:41:27 +08:00

336 lines
14 KiB
Python

from fastapi import FastAPI, Request, HTTPException, Depends, Form, Header
from fastapi.responses import FileResponse, JSONResponse
from typing import Optional, Dict
import uuid
import asyncio
from starlette.background import BackgroundTasks
from uuid import UUID
import hashlib
import hmac
import time
import os
import json
import argparse
from enum import Enum
import shutil
import schedule
import threading
from contextlib import asynccontextmanager # 导入 asynccontextmanager
import httpx # 导入 httpx 库
from podcast_generator import generate_podcast_audio_api
class TaskStatus(str, Enum):
    """Lifecycle states of a podcast generation task.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values and serialize to JSON as those strings.
    """

    # Task has been registered but the background job has not started yet.
    PENDING = "pending"
    # Background job is currently generating the podcast.
    RUNNING = "running"
    # Generation finished and its results were stored on the task record.
    COMPLETED = "completed"
    # Generation raised an exception; the message is kept in the task's "result".
    FAILED = "failed"
# --- Lifespan context manager ---
# Replaces the deprecated on_event("startup") / on_event("shutdown") hooks.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and cleanup afterwards.

    Startup: ensure the output directory exists, schedule the cleanup job to
    run every 30 minutes and start the scheduler loop in a daemon thread.

    Shutdown: signal the scheduler loop to stop, cancel the cleanup job and
    wait (bounded) for the thread to exit.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    print("FastAPI app is starting up...")
    os.makedirs(output_dir, exist_ok=True)

    # The stop flag is module-global; reset it so that a second lifespan cycle
    # in the same process does not start a scheduler that exits immediately.
    stop_scheduler_event.clear()
    cleanup_job = schedule.every(30).minutes.do(clean_output_directory)

    # Daemon thread: it cannot keep the process alive if shutdown is abrupt.
    scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
    scheduler_thread.start()
    print("FastAPI app started. Output directory cleaning is scheduled.")

    try:
        # The application serves requests while suspended here.
        yield
    finally:
        # Runs on normal shutdown and also when an error is thrown in at `yield`.
        print("FastAPI app is shutting down...")
        stop_scheduler_event.set()
        # Avoid accumulating duplicate jobs across repeated lifespan cycles.
        schedule.cancel_job(cleanup_job)
        # `scheduler_thread` is still in scope (same generator frame). Join it
        # off the event loop, with a timeout so shutdown can never hang.
        await asyncio.to_thread(scheduler_thread.join, 5)
        print("Signaled scheduler to stop. Main application will now exit.")
# Pass the lifespan function when creating the FastAPI instance.
app = FastAPI(lifespan=lifespan)
# Global flag used to tell the scheduler thread to stop.
stop_scheduler_event = threading.Event()
# Global configuration: directory that is created at startup, cleaned
# periodically, and served from by the download endpoint.
output_dir = "output"
# Periodic cleanup job for the output directory.
def clean_output_directory(directory: Optional[str] = None, max_age_seconds: float = 30 * 60):
    """Removes files from the output directory that are older than 30 minutes.

    Only regular files and symlinks directly inside the directory are
    considered; sub-directories are skipped. Failures are printed rather than
    raised, so an error here cannot stop the scheduler loop that calls it.

    Args:
        directory: Directory to clean. Defaults to the module-level
            ``output_dir`` when omitted.
        max_age_seconds: Files whose last-modified time is more than this many
            seconds in the past are deleted. Defaults to 30 minutes.
    """
    target_dir = output_dir if directory is None else directory
    print(f"Cleaning output directory: {target_dir}")
    now = time.time()

    try:
        filenames = os.listdir(target_dir)
    except OSError as e:
        # Directory missing or unreadable: nothing to clean on this run.
        print(f"Failed to list {target_dir}. Reason: {e}")
        return

    for filename in filenames:
        file_path = os.path.join(target_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                # Age is measured from the last modification time.
                if now - os.path.getmtime(file_path) > max_age_seconds:
                    os.unlink(file_path)
                    print(f"Deleted old file: {file_path}")
            # Sub-directories are intentionally left untouched for now.
        except Exception as e:
            # Best effort: report and continue with the remaining entries.
            print(f"Failed to delete {file_path}. Reason: {e}")
# In-memory store of task results, keyed first by auth_id and then by task_id:
# {auth_id: {task_id: {"status": TaskStatus, "result": any, "timestamp": float, "callback_url": str or None, ...}}}
# On success the generation results are merged into the per-task dict as well.
task_results: Dict[str, Dict[UUID, Dict]] = {}
# Signature verification settings.
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # Be sure to change this in production!
# Maps each tts_provider name to the path of its configuration file.
tts_provider_map = {
    "index-tts": "config/index-tts.json",
    "doubao-tts": "config/doubao-tts.json",
    "edge-tts": "config/edge-tts.json",
    "fish-audio": "config/fish-audio.json",
    "gemini-tts": "config/gemini-tts.json",
    "minimax": "config/minimax.json",
}
async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")):
    """
    Dependency that extracts the caller's identifier from the X-Auth-Id header.

    Returns:
        The header value, unchanged.

    Raises:
        HTTPException: 400 when the header value is empty.
    """
    if x_auth_id:
        return x_auth_id
    # The header is declared as required, so only an empty value reaches here.
    raise HTTPException(status_code=400, detail="Missing X-Auth-Id header.")
async def verify_signature(request: Request):
    """
    Validate the timestamp and signature attached to a request.

    The timestamp is read from the ``X-Timestamp`` header or the ``timestamp``
    query parameter, and the signature from the ``X-Sign`` header or the
    ``sign`` query parameter (headers take precedence).

    Expected signature: hex digest of HMAC-SHA256, keyed with SECRET_KEY, over
    the UTF-8 bytes of ``timestamp + SECRET_KEY``.

    Args:
        request: The incoming request whose headers/query are inspected.

    Raises:
        HTTPException: 400 if either value is missing, the timestamp is not an
            integer, or it is more than 5 minutes away from server time;
            401 if the signature does not match.
    """
    timestamp = request.headers.get("X-Timestamp") or request.query_params.get("timestamp")
    client_sign = request.headers.get("X-Sign") or request.query_params.get("sign")
    if not timestamp or not client_sign:
        raise HTTPException(status_code=400, detail="Missing X-Timestamp or X-Sign header/query parameter.")

    # Keep the try body minimal: only the integer parse can raise ValueError.
    # A broad handler around the whole check would also catch the
    # HTTPExceptions raised below and report them as 500s.
    try:
        request_time = int(timestamp)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid timestamp format.") from None

    # Requests are valid for 5 minutes in either direction.
    if abs(int(time.time()) - request_time) > 300:
        raise HTTPException(status_code=400, detail="Request expired or timestamp is too far in the future.")

    message = f"{timestamp}{SECRET_KEY}".encode('utf-8')
    server_sign = hmac.new(SECRET_KEY.encode('utf-8'), message, hashlib.sha256).hexdigest()

    # Constant-time comparison so response timing does not leak how many
    # leading characters matched. Compared as bytes because compare_digest
    # rejects non-ASCII str arguments.
    if not hmac.compare_digest(server_sign.encode('utf-8'), client_sign.encode('utf-8')):
        raise HTTPException(status_code=401, detail="Invalid signature.")
async def _send_task_callback(
    task_id: UUID,
    callback_url: str,
    callback_data: Dict,
    max_retries: int = 3,
    retry_delay: float = 5,
) -> bool:
    """POST ``callback_data`` as JSON to ``callback_url``, retrying on failure.

    One initial attempt plus up to ``max_retries`` retries are made, waiting
    ``retry_delay`` seconds between attempts. Every failure is printed and
    never raised, so callback problems cannot affect the task itself.

    Returns:
        True if an attempt received a non-error response, False otherwise.
    """
    total_attempts = max_retries + 1
    for attempt in range(1, total_attempts + 1):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(callback_url, json=callback_data, timeout=30.0)
                response.raise_for_status()  # raise on 4xx/5xx responses
                print(f"Callback successfully sent for task {task_id} on attempt {attempt}. Status: {response.status_code}")
                return True
        except httpx.RequestError as req_err:
            print(f"Callback request failed for task {task_id} to {callback_url} on attempt {attempt}: {req_err}")
        except httpx.HTTPStatusError as http_err:
            print(f"Callback received error response for task {task_id} from {callback_url} on attempt {attempt}: {http_err.response.status_code} - {http_err.response.text}")
        except Exception as cb_err:
            print(f"An unexpected error occurred during callback for task {task_id} on attempt {attempt}: {cb_err}")

        if attempt < total_attempts:
            print(f"Retrying callback for task {task_id} in {retry_delay} seconds...")
            await asyncio.sleep(retry_delay)

    # Report the number of attempts actually made (initial try + retries).
    print(f"Callback failed for task {task_id} after {total_attempts} attempts.")
    return False


async def _generate_podcast_task(
    task_id: UUID,
    auth_id: str,
    api_key: str,
    base_url: str,
    model: str,
    input_txt_content: str,
    tts_providers_config_content: str,
    podUsers_json_content: str,
    threads: int,
    tts_provider: str,
    callback_url: Optional[str] = None
):
    """Background job: generate a podcast and record the outcome in ``task_results``.

    The task record at ``task_results[auth_id][task_id]`` must already exist.
    Its status is set to RUNNING, then to COMPLETED (with the generation
    results merged into the record) or FAILED (with the error text stored
    under "result"). If ``callback_url`` is given, the final record is POSTed
    to it whether the generation succeeded or failed.

    Args:
        task_id: Identifier of the task record to update.
        auth_id: Owner of the task; first-level key in ``task_results``.
        api_key, base_url, model, threads: Passed to the generator as an
            argparse-style namespace.
        input_txt_content, tts_providers_config_content, podUsers_json_content:
            Text inputs forwarded (stripped) to the generator.
        tts_provider: Key into ``tts_provider_map`` selecting the config file.
        callback_url: Optional URL notified once the task has finished.
    """
    task_record = task_results[auth_id][task_id]
    task_record["status"] = TaskStatus.RUNNING
    try:
        actual_config_path = tts_provider_map.get(tts_provider)
        if not actual_config_path:
            raise ValueError(f"Invalid tts_provider: {tts_provider}.")

        # The generator expects an argparse-style namespace; build it directly
        # instead of constructing a parser only to parse an empty argv.
        args = argparse.Namespace(
            api_key=api_key,
            base_url=base_url,
            model=model,
            threads=threads,
        )

        # The generator is a blocking call; run it in a worker thread so the
        # event loop stays responsive.
        podcast_generation_results = await asyncio.to_thread(
            generate_podcast_audio_api,
            args=args,
            config_path=actual_config_path,
            input_txt_content=input_txt_content.strip(),
            tts_providers_config_content=tts_providers_config_content.strip(),
            podUsers_json_content=podUsers_json_content.strip()
        )
        task_record["status"] = TaskStatus.COMPLETED
        task_record.update(podcast_generation_results)
        print(f"\nPodcast generation completed for task {task_id}. Output file: {podcast_generation_results.get('output_audio_filepath')}")
    except Exception as e:
        task_record["status"] = TaskStatus.FAILED
        task_record["result"] = str(e)
        print(f"\nPodcast generation failed for task {task_id}: {e}")
    finally:
        # Notify the caller regardless of success or failure.
        if callback_url:
            # NOTE(review): callback_url is caller-supplied, so the server will
            # POST to whatever address it names (SSRF risk) - consider
            # restricting it to an allow-list.
            print(f"Attempting to send callback for task {task_id} to {callback_url}")
            callback_data = {
                "task_id": str(task_id),
                "auth_id": auth_id,
                "task_results": task_record,
                "timestamp": time.time(),
            }
            await _send_task_callback(task_id, callback_url, callback_data)
# Signature verification is currently disabled for this route; to enable it use:
# @app.post("/generate-podcast", dependencies=[Depends(verify_signature)])
@app.post("/generate-podcast")
async def generate_podcast_submission(
    background_tasks: BackgroundTasks,
    auth_id: str = Depends(get_auth_id),
    api_key: str = Form("OpenAI API key."),
    base_url: str = Form("https://api.openai.com/v1"),
    model: str = Form("gpt-3.5-turbo"),
    input_txt_content: str = Form(...),
    tts_providers_config_content: str = Form(...),
    podUsers_json_content: str = Form(...),
    threads: int = Form(1),
    tts_provider: str = Form("index-tts"),
    callback_url: Optional[str] = Form(None)
):
    """Register a podcast generation task and queue it to run in the background.

    Only one pending or running task is allowed per ``auth_id``.

    Returns:
        A dict with a confirmation message and the new task's id.

    Raises:
        HTTPException: 400 for an unknown ``tts_provider``; 409 when this
            ``auth_id`` already has a pending or running task.
    """
    # 1. Reject unknown TTS providers up front.
    if tts_provider not in tts_provider_map:
        raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")

    # 2. Refuse the submission if this auth_id still has an unfinished task.
    unfinished_states = (TaskStatus.RUNNING, TaskStatus.PENDING)
    existing_task_id = next(
        (
            known_id
            for known_id, known_info in task_results.get(auth_id, {}).items()
            if known_info["status"] in unfinished_states
        ),
        None,
    )
    if existing_task_id is not None:
        raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {existing_task_id}) for this auth_id. Please wait for it to complete.")

    # 3. Create the task record before scheduling the job that updates it.
    task_id = uuid.uuid4()
    new_record = {
        "status": TaskStatus.PENDING,
        "result": None,
        "timestamp": time.time(),
        "callback_url": callback_url,  # kept on the record for reference
    }
    task_results.setdefault(auth_id, {})[task_id] = new_record

    # 4. Hand the work off to run after the response has been sent.
    background_tasks.add_task(
        _generate_podcast_task,
        task_id=task_id,
        auth_id=auth_id,
        api_key=api_key,
        base_url=base_url,
        model=model,
        input_txt_content=input_txt_content,
        tts_providers_config_content=tts_providers_config_content,
        podUsers_json_content=podUsers_json_content,
        threads=threads,
        tts_provider=tts_provider,
        callback_url=callback_url,
    )
    return {"message": "Podcast generation started.", "task_id": task_id}
# Signature verification is currently disabled for this route; to enable it use:
# @app.get("/podcast-status", dependencies=[Depends(verify_signature)])
@app.get("/podcast-status")
async def get_podcast_status(
    auth_id: str = Depends(get_auth_id)
):
    """Return a summary of every task recorded for the calling ``auth_id``.

    Returns:
        A dict with a message and a "tasks" list. Each entry carries the task
        id, status, selected result fields, the error text (failed tasks only)
        and the submission timestamp.
    """
    if auth_id not in task_results:
        return {"message": "No tasks found for this auth_id.", "tasks": []}

    def summarize(task_id, task_info):
        # The error text is only exposed when the task actually failed.
        failed = task_info["status"] == TaskStatus.FAILED
        return {
            "task_id": task_id,
            "status": task_info["status"],
            "podUsers": task_info.get("podUsers"),
            "output_audio_filepath": task_info.get("output_audio_filepath"),
            "overview_content": task_info.get("overview_content"),
            "podcast_script": task_info.get("podcast_script"),
            "error": task_info["result"] if failed else None,
            "timestamp": task_info["timestamp"]
        }

    all_tasks_for_auth_id = [
        summarize(task_id, task_info)
        for task_id, task_info in task_results.get(auth_id, {}).items()
    ]
    return {"message": "Tasks retrieved successfully.", "tasks": all_tasks_for_auth_id}
@app.get("/download-podcast/")
async def download_podcast(file_name: str):
    """Serve a generated audio file from the output directory.

    ``file_name`` is caller-supplied, so the resolved path is required to stay
    inside ``output_dir``. This blocks ``../`` sequences, absolute paths and
    symlinks that point outside the directory.

    Args:
        file_name: Name of the file inside the output directory.

    Returns:
        A FileResponse streaming the file as ``audio/mpeg``.

    Raises:
        HTTPException: 404 if the name resolves outside the output directory
            or does not refer to an existing regular file.
    """
    base_dir = os.path.realpath(output_dir)
    try:
        file_path = os.path.realpath(os.path.join(base_dir, file_name))
    except ValueError:
        # e.g. an embedded NUL byte in the supplied name
        raise HTTPException(status_code=404, detail="File not found.") from None

    # join(base_dir, "") adds the trailing separator, so a sibling such as
    # "output_evil" cannot pass the prefix check.
    inside_output_dir = file_path.startswith(os.path.join(base_dir, ""))
    if not inside_output_dir or not os.path.isfile(file_path):
        # Same response for "outside" and "missing": do not reveal which.
        raise HTTPException(status_code=404, detail="File not found.")

    return FileResponse(file_path, media_type='audio/mpeg', filename=os.path.basename(file_path))
@app.get("/get-voices")
async def get_voices(tts_provider: str = "tts"):
    """Return the ``voices`` entry from the selected TTS provider's config file.

    Args:
        tts_provider: Key into ``tts_provider_map``. The default ``"tts"`` is
            not a key of that map, so callers must pass a provider explicitly.

    Returns:
        A dict with the provider name and its configured voices.

    Raises:
        HTTPException: 400 for an unknown provider; 404 if the config file or
            its ``voices`` key is missing; 500 if the file is not valid JSON
            or another unexpected error occurs.
    """
    config_path = tts_provider_map.get(tts_provider)
    if not config_path:
        raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)

        voices = config_data.get("voices")
        if voices is None:
            raise HTTPException(status_code=404, detail=f"No 'voices' key found in config for {tts_provider}.")

        return {"tts_provider": tts_provider, "voices": voices}
    except HTTPException:
        # Deliberate HTTP errors raised above must pass through unchanged;
        # otherwise the generic handler below would turn the 404 into a 500.
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Config file not found for {tts_provider}: {config_path}")
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail=f"Error decoding JSON from config file for {tts_provider}: {config_path}. Please check file format.")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
@app.get("/")
async def read_root():
    """Liveness endpoint: confirm that the server is up."""
    status_payload = {"message": "FastAPI server is running!"}
    return status_payload
def run_scheduler():
    """Run pending scheduled jobs about once per second until the stop event is set.

    Intended as the target of the background scheduler thread. An exception
    raised by a job is printed and the loop keeps going, so one failing run
    cannot end all future runs.
    """
    while not stop_scheduler_event.is_set():
        try:
            schedule.run_pending()
        except Exception as e:
            # run_pending() propagates job exceptions; do not let them end the loop.
            print(f"Scheduled job failed: {e}")
        # Wait up to one second, but wake immediately when the stop event is
        # set so shutdown is not delayed by a sleep.
        stop_scheduler_event.wait(1)
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run as a script.
    import uvicorn

    # Listen on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)