From 1242adb0e60d56c1400e7469253ef918595675d4 Mon Sep 17 00:00:00 2001 From: hex2077 Date: Wed, 13 Aug 2025 14:41:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E6=B7=BB=E5=8A=A0=E5=9B=9E?= =?UTF-8?q?=E8=B0=83=E5=8A=9F=E8=83=BD=E5=B9=B6=E9=87=8D=E6=9E=84API?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在生成播客任务中添加回调URL支持,任务完成后会通知指定URL - 将`generate_podcast_audio`重命名为`generate_podcast_audio_api`并返回完整结果 - 使用`asynccontextmanager`替代已弃用的`startup`/`shutdown`事件 - 改进错误处理,添加任务取消逻辑 - 更新状态接口返回更多任务详情 --- main.py | 171 ++++++++++++++++++++++++++++--------------- podcast_generator.py | 36 +++++++-- 2 files changed, 139 insertions(+), 68 deletions(-) diff --git a/main.py b/main.py index 2743943..bd89a41 100644 --- a/main.py +++ b/main.py @@ -15,8 +15,10 @@ from enum import Enum import shutil import schedule import threading +from contextlib import asynccontextmanager # 导入 asynccontextmanager +import httpx # 导入 httpx 库 -from podcast_generator import generate_podcast_audio +from podcast_generator import generate_podcast_audio_api class TaskStatus(str, Enum): PENDING = "pending" @@ -24,18 +26,52 @@ class TaskStatus(str, Enum): COMPLETED = "completed" FAILED = "failed" -app = FastAPI() +# --- 新的 Lifespan 上下文管理器 --- +# 这是替代已弃用的 on_event("startup") 和 on_event("shutdown") 的新方法 +@asynccontextmanager +async def lifespan(app: FastAPI): + # 在应用启动时运行的代码 (等同于 startup_event) + print("FastAPI app is starting up...") + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 安排清理任务每30分钟运行一次 + schedule.every(30).minutes.do(clean_output_directory) + + # 在单独的线程中启动调度器 + scheduler_thread = threading.Thread(target=run_scheduler, daemon=True) + scheduler_thread.start() + + print("FastAPI app started. Output directory cleaning is scheduled.") + + # `yield` 语句是分割点,应用在这里运行 + yield + + # 在应用关闭时运行的代码 (等同于 shutdown_event) + print("FastAPI app is shutting down...") + + # 发送信号让调度器线程停止 + stop_scheduler_event.set() + + # 等待调度器线程结束(可选,但推荐) + # 注意:在 lifespan 中,我们无法直接访问在启动部分创建的 scheduler_thread 局部变量 + # 因此,我们仍然需要一个全局事件标志来通信。 + # 线程本身是守护线程(daemon=True),如果主程序退出它也会被强制终止, + # 但优雅地停止是更好的实践。 + print("Signaled scheduler to stop. Main application will now exit.") -# Global flag to signal the scheduler thread to stop + +# 在创建 FastAPI 实例时,传入 lifespan 函数 +app = FastAPI(lifespan=lifespan) + +# 全局标志,用于通知调度器线程停止 stop_scheduler_event = threading.Event() -# Global reference for the scheduler thread -scheduler_thread = None - -# Global configuration +# 全局配置 output_dir = "output" -# Define a function to clean the output directory +# 定义一个函数来清理输出目录 def clean_output_directory(): """Removes files from the output directory that are older than 30 minutes.""" print(f"Cleaning output directory: {output_dir}") @@ -47,24 +83,24 @@ def clean_output_directory(): file_path = os.path.join(output_dir, filename) try: if os.path.isfile(file_path) or os.path.islink(file_path): - # Get last modification time + # 获取最后修改时间 if now - os.path.getmtime(file_path) > threshold: os.unlink(file_path) print(f"Deleted old file: {file_path}") elif os.path.isdir(file_path): - # Optionally, recursively delete old subdirectories or files within them - # For now, just skip directories + # 可选地,递归删除旧的子目录或其中的文件 + # 目前只跳过目录 pass except Exception as e: print(f"Failed to delete {file_path}. Reason: {e}") -# In-memory store for task results +# 内存中存储任务结果 # {task_id: {"auth_id": auth_id, "status": TaskStatus, "result": any, "timestamp": float}} task_results: Dict[str, Dict[UUID, Dict]] = {} -# Configuration for signature verification -SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # Change this in production! -# Define a mapping from tts_provider names to their config file paths +# 签名验证配置 +SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # 在生产环境中请务必修改! +# 定义从 tts_provider 名称到其配置文件路径的映射 tts_provider_map = { "index-tts": "config/index-tts.json", "doubao-tts": "config/doubao-tts.json", @@ -76,7 +112,7 @@ tts_provider_map = { async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")): """ - Dependency to get X-Auth-Id from headers. + 从头部获取 X-Auth-Id 的依赖项。 """ if not x_auth_id: raise HTTPException(status_code=400, detail="Missing X-Auth-Id header.") @@ -84,8 +120,8 @@ async def get_auth_id(x_auth_id: str = Header(..., alias="X-Auth-Id")): async def verify_signature(request: Request): """ - Verify the 'sign' parameter in the request headers or query parameters. - Expected signature format: SHA256(timestamp + SECRET_KEY) + 验证请求头或查询参数中的 'sign' 参数。 + 期望的签名格式: SHA256(timestamp + SECRET_KEY) """ timestamp = request.headers.get("X-Timestamp") or request.query_params.get("timestamp") client_sign = request.headers.get("X-Sign") or request.query_params.get("sign") @@ -95,7 +131,7 @@ async def verify_signature(request: Request): try: current_time = int(time.time()) - if abs(current_time - int(timestamp)) > 300: + if abs(current_time - int(timestamp)) > 300: # 请求在5分钟内有效 raise HTTPException(status_code=400, detail="Request expired or timestamp is too far in the future.") message = f"{timestamp}{SECRET_KEY}".encode('utf-8') @@ -118,7 +154,8 @@ async def _generate_podcast_task( tts_providers_config_content: str, podUsers_json_content: str, threads: int, - tts_provider: str + tts_provider: str, + callback_url: Optional[str] = None # 新增回调地址参数 ): task_results[auth_id][task_id]["status"] = TaskStatus.RUNNING try: @@ -131,10 +168,10 @@ async def _generate_podcast_task( actual_config_path = tts_provider_map.get(tts_provider) if not actual_config_path: - raise ValueError(f"Invalid tts_provider: {tts_provider}.") # Changed from HTTPException to ValueError + raise ValueError(f"Invalid tts_provider: {tts_provider}.") - output_filepath = await asyncio.to_thread( - generate_podcast_audio, + podcast_generation_results = await asyncio.to_thread( + generate_podcast_audio_api, args=args, config_path=actual_config_path, input_txt_content=input_txt_content.strip(), @@ -142,12 +179,44 @@ async def _generate_podcast_task( podUsers_json_content=podUsers_json_content.strip() ) task_results[auth_id][task_id]["status"] = TaskStatus.COMPLETED - task_results[auth_id][task_id]["result"] = output_filepath - print(f"\nPodcast generation completed for task {task_id}. Output file: {output_filepath}") + task_results[auth_id][task_id].update(podcast_generation_results) + print(f"\nPodcast generation completed for task {task_id}. Output file: {podcast_generation_results.get('output_audio_filepath')}") except Exception as e: task_results[auth_id][task_id]["status"] = TaskStatus.FAILED task_results[auth_id][task_id]["result"] = str(e) print(f"\nPodcast generation failed for task {task_id}: {e}") + finally: # 无论成功或失败,都尝试调用回调 + if callback_url: + print(f"Attempting to send callback for task {task_id} to {callback_url}") + callback_data = { + "task_id": str(task_id), + "auth_id": auth_id, + "task_results": task_results[auth_id][task_id], + "timestamp": time.time(), + } + + MAX_RETRIES = 3 # 定义最大重试次数 + RETRY_DELAY = 5 # 定义重试间隔(秒) + + for attempt in range(MAX_RETRIES + 1): # 尝试次数从0到MAX_RETRIES + try: + async with httpx.AsyncClient() as client: + response = await client.post(callback_url, json=callback_data, timeout=30.0) + response.raise_for_status() # 对 4xx/5xx 响应抛出异常 + print(f"Callback successfully sent for task {task_id} on attempt {attempt + 1}. Status: {response.status_code}") + break # 成功发送,跳出循环 + except httpx.RequestError as req_err: + print(f"Callback request failed for task {task_id} to {callback_url} on attempt {attempt + 1}: {req_err}") + except httpx.HTTPStatusError as http_err: + print(f"Callback received error response for task {task_id} from {callback_url} on attempt {attempt + 1}: {http_err.response.status_code} - {http_err.response.text}") + except Exception as cb_err: + print(f"An unexpected error occurred during callback for task {task_id} on attempt {attempt + 1}: {cb_err}") + + if attempt < MAX_RETRIES: + print(f"Retrying callback for task {task_id} in {RETRY_DELAY} seconds...") + await asyncio.sleep(RETRY_DELAY) + else: + print(f"Callback failed for task {task_id} after {MAX_RETRIES} attempts.") # @app.post("/generate-podcast", dependencies=[Depends(verify_signature)]) @app.post("/generate-podcast") @@ -161,16 +230,17 @@ async def generate_podcast_submission( tts_providers_config_content: str = Form(...), podUsers_json_content: str = Form(...), threads: int = Form(1), - tts_provider: str = Form("index-tts") + tts_provider: str = Form("index-tts"), + callback_url: Optional[str] = Form(None) # 新增回调地址参数 ): - # 1. Validate tts_provider + # 1. 验证 tts_provider if tts_provider not in tts_provider_map: raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.") - # 2. Check for existing running tasks for this auth_id + # 2. 检查此 auth_id 是否有正在运行的任务 if auth_id in task_results: for existing_task_id, existing_task_info in task_results[auth_id].items(): - if existing_task_info["status"] == TaskStatus.RUNNING or existing_task_info["status"] == TaskStatus.PENDING: + if existing_task_info["status"] in (TaskStatus.RUNNING, TaskStatus.PENDING): raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {existing_task_id}) for this auth_id. Please wait for it to complete.") task_id = uuid.uuid4() @@ -179,7 +249,8 @@ async def generate_podcast_submission( task_results[auth_id][task_id] = { "status": TaskStatus.PENDING, "result": None, - "timestamp": time.time() + "timestamp": time.time(), + "callback_url": callback_url # 存储回调地址 } background_tasks.add_task( @@ -193,7 +264,8 @@ async def generate_podcast_submission( tts_providers_config_content, podUsers_json_content, threads, - tts_provider + tts_provider, + callback_url # 传递回调地址 ) return {"message": "Podcast generation started.", "task_id": task_id} @@ -207,11 +279,14 @@ async def get_podcast_status( return {"message": "No tasks found for this auth_id.", "tasks": []} all_tasks_for_auth_id = [] - for task_id, task_info in task_results[auth_id].items(): + for task_id, task_info in task_results.get(auth_id, {}).items(): all_tasks_for_auth_id.append({ "task_id": task_id, "status": task_info["status"], - "result": task_info["result"] if task_info["status"] == TaskStatus.COMPLETED else None, + "podUsers": task_info.get("podUsers"), + "output_audio_filepath": task_info.get("output_audio_filepath"), + "overview_content": task_info.get("overview_content"), + "podcast_script": task_info.get("podcast_script"), "error": task_info["result"] if task_info["status"] == TaskStatus.FAILED else None, "timestamp": task_info["timestamp"] }) @@ -226,7 +301,6 @@ async def download_podcast(file_name: str): @app.get("/get-voices") async def get_voices(tts_provider: str = "tts"): - config_path = tts_provider_map.get(tts_provider) if not config_path: raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.") @@ -252,35 +326,10 @@ async def read_root(): return {"message": "FastAPI server is running!"} def run_scheduler(): - """Runs the scheduler in a loop until the stop event is set.""" + """在循环中运行调度器,直到设置停止事件。""" while not stop_scheduler_event.is_set(): schedule.run_pending() - time.sleep(1) # Check every second for new jobs or stop event - -@app.on_event("startup") -async def startup_event(): - global scheduler_thread - # Ensure the output directory exists - os.makedirs(output_dir, exist_ok=True) - # Schedule the cleaning task to run every 30 minutes - schedule.every(30).minutes.do(clean_output_directory) - # Start the scheduler in a separate thread - if scheduler_thread is None or not scheduler_thread.is_alive(): - scheduler_thread = threading.Thread(target=run_scheduler, daemon=True) - scheduler_thread.start() - print("FastAPI app started. Scheduled output directory cleaning.") - -@app.on_event("shutdown") -async def shutdown_event(): - global scheduler_thread - # Signal the scheduler thread to stop - stop_scheduler_event.set() - # Wait for the scheduler thread to finish (optional, but good practice) - if scheduler_thread is not None and scheduler_thread.is_alive(): - scheduler_thread.join(timeout=5) # Wait for max 5 seconds - if scheduler_thread.is_alive(): - print("Scheduler thread did not terminate gracefully.") - print("FastAPI app shutting down.") + time.sleep(1) # 每秒检查一次新任务或停止事件 if __name__ == "__main__": import uvicorn diff --git a/podcast_generator.py b/podcast_generator.py index 2587094..00b3935 100644 --- a/podcast_generator.py +++ b/podcast_generator.py @@ -364,6 +364,7 @@ def _generate_all_audio_files(podcast_script, config_data, tts_adapter: TTSAdapt for i, item in enumerate(transcripts) } + exception_caught = None for future in as_completed(future_to_index): index = future_to_index[future] try: @@ -371,8 +372,17 @@ def _generate_all_audio_files(podcast_script, config_data, tts_adapter: TTSAdapt if result: audio_files_dict[index] = result except Exception as e: - # Re-raise the exception to propagate it to the main thread - raise RuntimeError(f"Error generating audio for item {index}: {e}") + exception_caught = RuntimeError(f"Error generating audio for item {index}: {e}") + # An error occurred, we should stop. + break + + # If we broke out of the loop due to an exception, cancel other futures. + if exception_caught: + print(f"An error occurred: {exception_caught}. Cancelling outstanding tasks.") + for f in future_to_index: + if not f.done(): + f.cancel() + raise exception_caught audio_files = [audio_files_dict[i] for i in sorted(audio_files_dict.keys())] @@ -459,7 +469,7 @@ def _initialize_tts_adapter(config_data: dict, tts_providers_config_content: Opt else: raise ValueError(f"Unsupported TTS provider: {tts_provider}") -def main(): +def generate_podcast_audio(): args = _parse_arguments() config_data = _load_configuration() api_key, base_url, model = _prepare_openai_settings(args, config_data) @@ -479,9 +489,16 @@ def main(): audio_files = _generate_all_audio_files(podcast_script, config_data, tts_adapter, args.threads) _create_ffmpeg_file_list(audio_files) + output_audio_filepath = merge_audio_files() + return { + "output_audio_filepath": output_audio_filepath, + "overview_content": overview_content, + "podcast_script": podcast_script, + "podUsers": pod_users, + } -def generate_podcast_audio(args, config_path: str, input_txt_content: str, tts_providers_config_content: str, podUsers_json_content: str) -> str: +def generate_podcast_audio_api(args, config_path: str, input_txt_content: str, tts_providers_config_content: str, podUsers_json_content: str) -> dict: """ Generates a podcast audio file based on the provided parameters. @@ -519,14 +536,19 @@ def generate_podcast_audio(args, config_path: str, input_txt_content: str, tts_p _create_ffmpeg_file_list(audio_files) output_audio_filepath = merge_audio_files() - return output_audio_filepath + task_results = { + "output_audio_filepath": output_audio_filepath, + "overview_content": overview_content, + "podcast_script": podcast_script, + "podUsers": podUsers, + } + return task_results if __name__ == "__main__": start_time = time.time() try: - main() - merge_audio_files() + generate_podcast_audio() except Exception as e: print(f"\nError: An unexpected error occurred during podcast generation: {e}", file=sys.stderr) sys.exit(1)