feat(api): 添加回调功能并重构API接口

- 在生成播客任务中添加回调URL支持,任务完成后会通知指定URL
- 将`generate_podcast_audio`重命名为`generate_podcast_audio_api`并返回完整结果
- 使用`asynccontextmanager`替代已弃用的`startup`/`shutdown`事件
- 改进错误处理,添加任务取消逻辑
- 更新状态接口返回更多任务详情
This commit is contained in:
hex2077
2025-08-13 14:41:27 +08:00
parent c2930e4340
commit 1242adb0e6
2 changed files with 139 additions and 68 deletions

171
main.py
View File

@@ -15,8 +15,10 @@ from enum import Enum
import shutil
import schedule
import threading
from contextlib import asynccontextmanager # 导入 asynccontextmanager
import httpx # 导入 httpx 库
from podcast_generator import generate_podcast_audio
from podcast_generator import generate_podcast_audio_api
class TaskStatus(str, Enum):
PENDING = "pending"
@@ -24,18 +26,52 @@ class TaskStatus(str, Enum):
COMPLETED = "completed"
FAILED = "failed"
app = FastAPI()
# --- 新的 Lifespan 上下文管理器 ---
# 这是替代已弃用的 on_event("startup") 和 on_event("shutdown") 的新方法
@asynccontextmanager
async def lifespan(app: FastAPI):
# 在应用启动时运行的代码 (等同于 startup_event)
print("FastAPI app is starting up...")
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# 安排清理任务每30分钟运行一次
schedule.every(30).minutes.do(clean_output_directory)
# 在单独的线程中启动调度器
scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
scheduler_thread.start()
print("FastAPI app started. Output directory cleaning is scheduled.")
# `yield` 语句是分割点,应用在这里运行
yield
# 在应用关闭时运行的代码 (等同于 shutdown_event)
print("FastAPI app is shutting down...")
# 发送信号让调度器线程停止
stop_scheduler_event.set()
# 等待调度器线程结束(可选,但推荐)
# 注意:在 lifespan 中,我们无法直接访问在启动部分创建的 scheduler_thread 局部变量
# 因此,我们仍然需要一个全局事件标志来通信。
# 线程本身是守护线程(daemon=True),如果主程序退出它也会被强制终止,
# 但优雅地停止是更好的实践。
print("Signaled scheduler to stop. Main application will now exit.")
# Global flag to signal the scheduler thread to stop
# 在创建 FastAPI 实例时,传入 lifespan 函数
app = FastAPI(lifespan=lifespan)
# 全局标志,用于通知调度器线程停止
stop_scheduler_event = threading.Event()
# Global reference for the scheduler thread
scheduler_thread = None
# Global configuration
# 全局配置
output_dir = "output"
# Define a function to clean the output directory
# 定义一个函数来清理输出目录
def clean_output_directory():
"""Removes files from the output directory that are older than 30 minutes."""
print(f"Cleaning output directory: {output_dir}")
@@ -47,24 +83,24 @@ def clean_output_directory():
file_path = os.path.join(output_dir, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
# Get last modification time
# 获取最后修改时间
if now - os.path.getmtime(file_path) > threshold:
os.unlink(file_path)
print(f"Deleted old file: {file_path}")
elif os.path.isdir(file_path):
# Optionally, recursively delete old subdirectories or files within them
# For now, just skip directories
# 可选地,递归删除旧的子目录或其中的文件
# 目前只跳过目录
pass
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
# In-memory store for task results
# 内存中存储任务结果
# {task_id: {"auth_id": auth_id, "status": TaskStatus, "result": any, "timestamp": float}}
task_results: Dict[str, Dict[UUID, Dict]] = {}
# Configuration for signature verification
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # Change this in production!
# Define a mapping from tts_provider names to their config file paths
# 签名验证配置
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # 在生产环境中请务必修改!
# 定义从 tts_provider 名称到其配置文件路径的映射
tts_provider_map = {
"index-tts": "config/index-tts.json",
"doubao-tts": "config/doubao-tts.json",
@@ -76,7 +112,7 @@ tts_provider_map = {
async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")):
"""
Dependency to get X-Auth-Id from headers.
从头部获取 X-Auth-Id 的依赖项。
"""
if not x_auth_id:
raise HTTPException(status_code=400, detail="Missing X-Auth-Id header.")
@@ -84,8 +120,8 @@ async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")):
async def verify_signature(request: Request):
"""
Verify the 'sign' parameter in the request headers or query parameters.
Expected signature format: SHA256(timestamp + SECRET_KEY)
验证请求头或查询参数中的 'sign' 参数。
期望的签名格式: SHA256(timestamp + SECRET_KEY)
"""
timestamp = request.headers.get("X-Timestamp") or request.query_params.get("timestamp")
client_sign = request.headers.get("X-Sign") or request.query_params.get("sign")
@@ -95,7 +131,7 @@ async def verify_signature(request: Request):
try:
current_time = int(time.time())
if abs(current_time - int(timestamp)) > 300:
if abs(current_time - int(timestamp)) > 300: # 请求在5分钟内有效
raise HTTPException(status_code=400, detail="Request expired or timestamp is too far in the future.")
message = f"{timestamp}{SECRET_KEY}".encode('utf-8')
@@ -118,7 +154,8 @@ async def _generate_podcast_task(
tts_providers_config_content: str,
podUsers_json_content: str,
threads: int,
tts_provider: str
tts_provider: str,
callback_url: Optional[str] = None # 新增回调地址参数
):
task_results[auth_id][task_id]["status"] = TaskStatus.RUNNING
try:
@@ -131,10 +168,10 @@ async def _generate_podcast_task(
actual_config_path = tts_provider_map.get(tts_provider)
if not actual_config_path:
raise ValueError(f"Invalid tts_provider: {tts_provider}.") # Changed from HTTPException to ValueError
raise ValueError(f"Invalid tts_provider: {tts_provider}.")
output_filepath = await asyncio.to_thread(
generate_podcast_audio,
podcast_generation_results = await asyncio.to_thread(
generate_podcast_audio_api,
args=args,
config_path=actual_config_path,
input_txt_content=input_txt_content.strip(),
@@ -142,12 +179,44 @@ async def _generate_podcast_task(
podUsers_json_content=podUsers_json_content.strip()
)
task_results[auth_id][task_id]["status"] = TaskStatus.COMPLETED
task_results[auth_id][task_id]["result"] = output_filepath
print(f"\nPodcast generation completed for task {task_id}. Output file: {output_filepath}")
task_results[auth_id][task_id].update(podcast_generation_results)
print(f"\nPodcast generation completed for task {task_id}. Output file: {podcast_generation_results.get('output_audio_filepath')}")
except Exception as e:
task_results[auth_id][task_id]["status"] = TaskStatus.FAILED
task_results[auth_id][task_id]["result"] = str(e)
print(f"\nPodcast generation failed for task {task_id}: {e}")
finally: # 无论成功或失败,都尝试调用回调
if callback_url:
print(f"Attempting to send callback for task {task_id} to {callback_url}")
callback_data = {
"task_id": str(task_id),
"auth_id": auth_id,
"task_results": task_results[auth_id][task_id],
"timestamp": time.time(),
}
MAX_RETRIES = 3 # 定义最大重试次数
RETRY_DELAY = 5 # 定义重试间隔(秒)
for attempt in range(MAX_RETRIES + 1): # 尝试次数从0到MAX_RETRIES
try:
async with httpx.AsyncClient() as client:
response = await client.post(callback_url, json=callback_data, timeout=30.0)
response.raise_for_status() # 对 4xx/5xx 响应抛出异常
print(f"Callback successfully sent for task {task_id} on attempt {attempt + 1}. Status: {response.status_code}")
break # 成功发送,跳出循环
except httpx.RequestError as req_err:
print(f"Callback request failed for task {task_id} to {callback_url} on attempt {attempt + 1}: {req_err}")
except httpx.HTTPStatusError as http_err:
print(f"Callback received error response for task {task_id} from {callback_url} on attempt {attempt + 1}: {http_err.response.status_code} - {http_err.response.text}")
except Exception as cb_err:
print(f"An unexpected error occurred during callback for task {task_id} on attempt {attempt + 1}: {cb_err}")
if attempt < MAX_RETRIES:
print(f"Retrying callback for task {task_id} in {RETRY_DELAY} seconds...")
await asyncio.sleep(RETRY_DELAY)
else:
print(f"Callback failed for task {task_id} after {MAX_RETRIES} attempts.")
# @app.post("/generate-podcast", dependencies=[Depends(verify_signature)])
@app.post("/generate-podcast")
@@ -161,16 +230,17 @@ async def generate_podcast_submission(
tts_providers_config_content: str = Form(...),
podUsers_json_content: str = Form(...),
threads: int = Form(1),
tts_provider: str = Form("index-tts")
tts_provider: str = Form("index-tts"),
callback_url: Optional[str] = Form(None) # 新增回调地址参数
):
# 1. Validate tts_provider
# 1. 验证 tts_provider
if tts_provider not in tts_provider_map:
raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")
# 2. Check for existing running tasks for this auth_id
# 2. 检查此 auth_id 是否有正在运行的任务
if auth_id in task_results:
for existing_task_id, existing_task_info in task_results[auth_id].items():
if existing_task_info["status"] == TaskStatus.RUNNING or existing_task_info["status"] == TaskStatus.PENDING:
if existing_task_info["status"] in (TaskStatus.RUNNING, TaskStatus.PENDING):
raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {existing_task_id}) for this auth_id. Please wait for it to complete.")
task_id = uuid.uuid4()
@@ -179,7 +249,8 @@ async def generate_podcast_submission(
task_results[auth_id][task_id] = {
"status": TaskStatus.PENDING,
"result": None,
"timestamp": time.time()
"timestamp": time.time(),
"callback_url": callback_url # 存储回调地址
}
background_tasks.add_task(
@@ -193,7 +264,8 @@ async def generate_podcast_submission(
tts_providers_config_content,
podUsers_json_content,
threads,
tts_provider
tts_provider,
callback_url # 传递回调地址
)
return {"message": "Podcast generation started.", "task_id": task_id}
@@ -207,11 +279,14 @@ async def get_podcast_status(
return {"message": "No tasks found for this auth_id.", "tasks": []}
all_tasks_for_auth_id = []
for task_id, task_info in task_results[auth_id].items():
for task_id, task_info in task_results.get(auth_id, {}).items():
all_tasks_for_auth_id.append({
"task_id": task_id,
"status": task_info["status"],
"result": task_info["result"] if task_info["status"] == TaskStatus.COMPLETED else None,
"podUsers": task_info.get("podUsers"),
"output_audio_filepath": task_info.get("output_audio_filepath"),
"overview_content": task_info.get("overview_content"),
"podcast_script": task_info.get("podcast_script"),
"error": task_info["result"] if task_info["status"] == TaskStatus.FAILED else None,
"timestamp": task_info["timestamp"]
})
@@ -226,7 +301,6 @@ async def download_podcast(file_name: str):
@app.get("/get-voices")
async def get_voices(tts_provider: str = "tts"):
config_path = tts_provider_map.get(tts_provider)
if not config_path:
raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")
@@ -252,35 +326,10 @@ async def read_root():
return {"message": "FastAPI server is running!"}
def run_scheduler():
"""Runs the scheduler in a loop until the stop event is set."""
"""在循环中运行调度器,直到设置停止事件。"""
while not stop_scheduler_event.is_set():
schedule.run_pending()
time.sleep(1) # Check every second for new jobs or stop event
@app.on_event("startup")
async def startup_event():
global scheduler_thread
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)
# Schedule the cleaning task to run every 30 minutes
schedule.every(30).minutes.do(clean_output_directory)
# Start the scheduler in a separate thread
if scheduler_thread is None or not scheduler_thread.is_alive():
scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
scheduler_thread.start()
print("FastAPI app started. Scheduled output directory cleaning.")
@app.on_event("shutdown")
async def shutdown_event():
global scheduler_thread
# Signal the scheduler thread to stop
stop_scheduler_event.set()
# Wait for the scheduler thread to finish (optional, but good practice)
if scheduler_thread is not None and scheduler_thread.is_alive():
scheduler_thread.join(timeout=5) # Wait for max 5 seconds
if scheduler_thread.is_alive():
print("Scheduler thread did not terminate gracefully.")
print("FastAPI app shutting down.")
time.sleep(1) # 每秒检查一次新任务或停止事件
if __name__ == "__main__":
import uvicorn

View File

@@ -364,6 +364,7 @@ def _generate_all_audio_files(podcast_script, config_data, tts_adapter: TTSAdapt
for i, item in enumerate(transcripts)
}
exception_caught = None
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
@@ -371,8 +372,17 @@ def _generate_all_audio_files(podcast_script, config_data, tts_adapter: TTSAdapt
if result:
audio_files_dict[index] = result
except Exception as e:
# Re-raise the exception to propagate it to the main thread
raise RuntimeError(f"Error generating audio for item {index}: {e}")
exception_caught = RuntimeError(f"Error generating audio for item {index}: {e}")
# An error occurred, we should stop.
break
# If we broke out of the loop due to an exception, cancel other futures.
if exception_caught:
print(f"An error occurred: {exception_caught}. Cancelling outstanding tasks.")
for f in future_to_index:
if not f.done():
f.cancel()
raise exception_caught
audio_files = [audio_files_dict[i] for i in sorted(audio_files_dict.keys())]
@@ -459,7 +469,7 @@ def _initialize_tts_adapter(config_data: dict, tts_providers_config_content: Opt
else:
raise ValueError(f"Unsupported TTS provider: {tts_provider}")
def main():
def generate_podcast_audio():
args = _parse_arguments()
config_data = _load_configuration()
api_key, base_url, model = _prepare_openai_settings(args, config_data)
@@ -479,9 +489,16 @@ def main():
audio_files = _generate_all_audio_files(podcast_script, config_data, tts_adapter, args.threads)
_create_ffmpeg_file_list(audio_files)
output_audio_filepath = merge_audio_files()
return {
"output_audio_filepath": output_audio_filepath,
"overview_content": overview_content,
"podcast_script": podcast_script,
"podUsers": pod_users,
}
def generate_podcast_audio(args, config_path: str, input_txt_content: str, tts_providers_config_content: str, podUsers_json_content: str) -> str:
def generate_podcast_audio_api(args, config_path: str, input_txt_content: str, tts_providers_config_content: str, podUsers_json_content: str) -> dict:
"""
Generates a podcast audio file based on the provided parameters.
@@ -519,14 +536,19 @@ def generate_podcast_audio(args, config_path: str, input_txt_content: str, tts_p
_create_ffmpeg_file_list(audio_files)
output_audio_filepath = merge_audio_files()
return output_audio_filepath
task_results = {
"output_audio_filepath": output_audio_filepath,
"overview_content": overview_content,
"podcast_script": podcast_script,
"podUsers": podUsers,
}
return task_results
if __name__ == "__main__":
start_time = time.time()
try:
main()
merge_audio_files()
generate_podcast_audio()
except Exception as e:
print(f"\nError: An unexpected error occurred during podcast generation: {e}", file=sys.stderr)
sys.exit(1)