feat(podcast): 添加沉浸故事模式支持多语言播客生成
新增沉浸故事生成模式,支持原文朗读和智能分段: - 服务端新增generate_podcast_with_story_api函数和专用API端点 - 添加故事模式专用prompt模板(prompt-story-overview.txt和prompt-story-podscript.txt) - 前端新增模式切换UI,支持AI播客和沉浸故事两种模式 - 沉浸故事模式固定消耗30积分,不需要语言和时长参数 - 优化音频静音裁剪逻辑,保留首尾200ms空白提升自然度 - 修复session管理和错误处理,提升系统稳定性 - 新增多语言配置(中英日)支持模式切换文案
This commit is contained in:
121
server/main.py
121
server/main.py
@@ -22,7 +22,7 @@ import httpx # 导入 httpx 库
|
||||
from io import BytesIO # 导入 BytesIO
|
||||
import base64 # 导入 base64
|
||||
|
||||
from podcast_generator import generate_podcast_audio_api
|
||||
from podcast_generator import generate_podcast_audio_api, generate_podcast_with_story_api
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
@@ -214,6 +214,7 @@ async def _generate_podcast_task(
|
||||
output_language: Optional[str] = None,
|
||||
usetime: Optional[str] = None,
|
||||
lang: Optional[str] = None,
|
||||
use_story_mode: bool = False, # 新增参数,是否使用故事模式
|
||||
):
|
||||
task_results[auth_id][task_id]["status"] = TaskStatus.RUNNING
|
||||
try:
|
||||
@@ -230,14 +231,25 @@ async def _generate_podcast_task(
|
||||
if not actual_config_path:
|
||||
raise ValueError(f"Invalid tts_provider: {tts_provider}.")
|
||||
|
||||
podcast_generation_results = await asyncio.to_thread(
|
||||
generate_podcast_audio_api,
|
||||
args=args,
|
||||
config_path=actual_config_path,
|
||||
input_txt_content=input_txt_content.strip(),
|
||||
tts_providers_config_content=tts_providers_config_content.strip(),
|
||||
podUsers_json_content=podUsers_json_content.strip()
|
||||
)
|
||||
# 根据 use_story_mode 参数决定调用哪个函数
|
||||
if use_story_mode:
|
||||
podcast_generation_results = await asyncio.to_thread(
|
||||
generate_podcast_with_story_api,
|
||||
args=args,
|
||||
config_path=actual_config_path,
|
||||
input_txt_content=input_txt_content.strip(),
|
||||
tts_providers_config_content=tts_providers_config_content.strip(),
|
||||
podUsers_json_content=podUsers_json_content.strip()
|
||||
)
|
||||
else:
|
||||
podcast_generation_results = await asyncio.to_thread(
|
||||
generate_podcast_audio_api,
|
||||
args=args,
|
||||
config_path=actual_config_path,
|
||||
input_txt_content=input_txt_content.strip(),
|
||||
tts_providers_config_content=tts_providers_config_content.strip(),
|
||||
podUsers_json_content=podUsers_json_content.strip()
|
||||
)
|
||||
task_results[auth_id][task_id]["status"] = TaskStatus.COMPLETED
|
||||
task_results[auth_id][task_id].update(podcast_generation_results)
|
||||
print(f"\nPodcast generation completed for task {task_id}. Output file: {podcast_generation_results.get('output_audio_filepath')}")
|
||||
@@ -266,10 +278,11 @@ async def _generate_podcast_task(
|
||||
"task_id": str(task_id),
|
||||
"auth_id": auth_id,
|
||||
"task_results": task_results[auth_id][task_id],
|
||||
"timestamp": int(time.time()),
|
||||
"timestamp": int(time.time()),
|
||||
"status": task_results[auth_id][task_id]["status"],
|
||||
"usetime": usetime,
|
||||
"lang": lang,
|
||||
"mode": "ai-story" if use_story_mode else "normal",
|
||||
}
|
||||
|
||||
MAX_RETRIES = 3 # 定义最大重试次数
|
||||
@@ -293,7 +306,34 @@ async def _generate_podcast_task(
|
||||
print(f"Retrying callback for task {task_id} in {RETRY_DELAY} seconds...")
|
||||
await asyncio.sleep(RETRY_DELAY)
|
||||
else:
|
||||
print(f"Callback failed for task {task_id} after {MAX_RETRIES} attempts.")
|
||||
print(f"Callback failed for task {task_id} after {MAX_RETRIES} attempts.")
|
||||
|
||||
|
||||
async def _generate_podcast_with_story_task(
|
||||
task_id: UUID,
|
||||
auth_id: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
input_txt_content: str,
|
||||
tts_providers_config_content: str,
|
||||
podUsers_json_content: str,
|
||||
threads: int,
|
||||
tts_provider: str,
|
||||
callback_url: Optional[str] = None, # 新增回调地址参数
|
||||
output_language: Optional[str] = None,
|
||||
usetime: Optional[str] = None,
|
||||
lang: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
调用带优化流程的播客生成任务处理函数
|
||||
"""
|
||||
return await _generate_podcast_task(
|
||||
task_id, auth_id, api_key, base_url, model, input_txt_content,
|
||||
tts_providers_config_content, podUsers_json_content, threads,
|
||||
tts_provider, callback_url, output_language, usetime, lang,
|
||||
use_story_mode=True
|
||||
)
|
||||
|
||||
# @app.post("/generate-podcast", dependencies=[Depends(verify_signature)])
|
||||
@app.post("/generate-podcast")
|
||||
@@ -354,6 +394,65 @@ async def generate_podcast_submission(
|
||||
|
||||
return {"message": "Podcast generation started.", "task_id": task_id}
|
||||
|
||||
|
||||
@app.post("/generate-podcast-with-story")
|
||||
async def generate_podcast_with_story_submission(
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_id: str = Depends(get_auth_id),
|
||||
api_key: str = Form("OpenAI API key."),
|
||||
base_url: str = Form("https://api.openai.com/v1"),
|
||||
model: str = Form("gpt-3.5-turbo"),
|
||||
input_txt_content: str = Form(...),
|
||||
tts_providers_config_content: str = Form(...),
|
||||
podUsers_json_content: str = Form(...),
|
||||
threads: int = Form(1),
|
||||
tts_provider: str = Form("index-tts"),
|
||||
callback_url: Optional[str] = Form(None),
|
||||
output_language: Optional[str] = Form(None),
|
||||
usetime: Optional[str] = Form(None),
|
||||
lang: Optional[str] = Form(None),
|
||||
):
|
||||
# 1. 验证 tts_provider
|
||||
if tts_provider not in tts_provider_map:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tts_provider: {tts_provider}.")
|
||||
|
||||
# 2. 检查此 auth_id 是否有正在运行的任务
|
||||
if auth_id in task_results:
|
||||
for existing_task_id, existing_task_info in task_results[auth_id].items():
|
||||
if existing_task_info["status"] in (TaskStatus.RUNNING, TaskStatus.PENDING):
|
||||
raise HTTPException(status_code=409, detail=f"There is already a running task (ID: {existing_task_id}) for this auth_id. Please wait for it to complete.")
|
||||
|
||||
task_id = uuid.uuid4()
|
||||
if auth_id not in task_results:
|
||||
task_results[auth_id] = {}
|
||||
task_results[auth_id][task_id] = {
|
||||
"status": TaskStatus.PENDING,
|
||||
"result": None,
|
||||
"timestamp": time.time(),
|
||||
"callback_url": callback_url, # 存储回调地址
|
||||
"auth_id": auth_id, # 存储 auth_id
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
_generate_podcast_with_story_task,
|
||||
task_id,
|
||||
auth_id,
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
input_txt_content,
|
||||
tts_providers_config_content,
|
||||
podUsers_json_content,
|
||||
threads,
|
||||
tts_provider,
|
||||
callback_url,
|
||||
output_language,
|
||||
usetime,
|
||||
lang,
|
||||
)
|
||||
|
||||
return {"message": "Podcast generation with story started.", "task_id": task_id}
|
||||
|
||||
# @app.get("/podcast-status", dependencies=[Depends(verify_signature)])
|
||||
@app.get("/podcast-status")
|
||||
async def get_podcast_status(
|
||||
|
||||
Reference in New Issue
Block a user