feat(config): 添加webvoice配置支持多TTS提供商和优化播客生成流程
新增webvoice.json配置文件,包含大量语音选项,更新TTS适配器以支持多提供商配置,改进播客生成流程中的错误处理和重试机制,优化UI组件以支持新的语音选择功能
This commit is contained in:
@@ -86,6 +86,7 @@ audio_file_mapping: Dict[str, Dict] = {}
|
||||
SECRET_KEY = os.getenv("PODCAST_API_SECRET_KEY", "your-super-secret-key") # 在生产环境中请务必修改!
|
||||
# 定义从 tts_provider 名称到其配置文件路径的映射
|
||||
tts_provider_map = {
|
||||
"webvoice": "../config/webvoice.json",
|
||||
"index-tts": "../config/index-tts.json",
|
||||
"doubao-tts": "../config/doubao-tts.json",
|
||||
"edge-tts": "../config/edge-tts.json",
|
||||
|
||||
@@ -21,6 +21,19 @@ output_dir = "output"
|
||||
# file_list_path is now generated uniquely for each merge operation
|
||||
tts_providers_config_path = '../config/tts_providers.json'
|
||||
|
||||
# Global cache for TTS provider configurations
|
||||
tts_provider_configs_cache = {}
|
||||
|
||||
# Define the TTS provider map
|
||||
tts_provider_map = {
|
||||
"index-tts": "../config/index-tts.json",
|
||||
"doubao-tts": "../config/doubao-tts.json",
|
||||
"edge-tts": "../config/edge-tts.json",
|
||||
"fish-audio": "../config/fish-audio.json",
|
||||
"gemini-tts": "../config/gemini-tts.json",
|
||||
"minimax": "../config/minimax.json",
|
||||
}
|
||||
|
||||
def read_file_content(filepath):
|
||||
"""Reads content from a given file path."""
|
||||
try:
|
||||
@@ -360,18 +373,36 @@ def _load_configuration():
|
||||
print("\nLoaded Configuration: " + tts_provider)
|
||||
return config_data
|
||||
|
||||
def _load_configuration_path(config_path: str, pod_users: Optional[list] = None) -> dict:
    """Load a JSON configuration file and record which TTS provider(s) it needs.

    The provider list is derived from the ``owner`` field of each speaker
    entry. When no entry names an owner, the provider falls back to the fixed
    default ``"edge-tts"``; the configuration file name is not consulted.

    Args:
        config_path: Path of the JSON configuration file to load.
        pod_users: Speaker entries to inspect. When ``None``, the ``podUsers``
            list stored in the loaded configuration is used instead.

    Returns:
        The loaded configuration dict with an added ``"tts_provider"`` key
        holding a comma-separated list of provider names.
    """
    config_data = _load_json_config(config_path)

    # Use the caller-supplied speakers if given; otherwise read them from the config.
    if pod_users is None:
        pod_users = config_data.get("podUsers", [])

    # dict.fromkeys de-duplicates while keeping first-seen order, so the joined
    # provider string is deterministic (iteration order of a set is not).
    # Entries that are not dicts are skipped instead of raising AttributeError.
    owners = list(dict.fromkeys(
        user.get("owner")
        for user in (pod_users or [])
        if isinstance(user, dict) and user.get("owner")
    ))

    if owners:
        tts_provider = ",".join(owners)
        print(f"Found multiple owners in podUsers: {owners}. Using comma-separated tts_provider: {tts_provider}")
    else:
        # Fixed fallback; the log message states the value actually used.
        tts_provider = "edge-tts"
        print(f"No owners found in podUsers. Using default tts_provider: {tts_provider}")

    # Expose the resolved provider(s) to the rest of the pipeline.
    config_data["tts_provider"] = tts_provider

    print(f"\nLoaded Configuration: {tts_provider} from {config_path}")
    return config_data
|
||||
|
||||
def _prepare_openai_settings(args, config_data):
|
||||
"""Determines final OpenAI API key, base URL, and model based on priority."""
|
||||
@@ -428,86 +459,180 @@ def _prepare_podcast_prompts(config_data, original_podscript_prompt, custom_cont
|
||||
podscript_prompt = speaker_id_info + "\n\n" + custom_content + "\n\n" + original_podscript_prompt
|
||||
return podscript_prompt, pod_users, voices, turn_pattern # Return voices for potential future use or consistency
|
||||
|
||||
def _is_content_quality_acceptable(content: str, title: str, tags: str, content_type: str = "overview") -> bool:
    """Check whether generated content meets the minimum quality bar.

    Args:
        content: The overview text (``"overview"``) or the podcast script as a
            JSON string (``"script"``).
        title: Extracted title; only checked for ``"overview"``.
        tags: Extracted tags; only checked for ``"overview"``.
        content_type: ``"overview"`` or ``"script"``. Any other value is
            rejected.

    Returns:
        ``True`` when the content passes every check for its type, ``False``
        otherwise. Malformed input (wrong types, invalid JSON, unexpected JSON
        shape) yields ``False`` rather than an exception.
    """

    def _stripped_len(value) -> int:
        # Non-string values count as empty so they are rejected, not crashed on.
        return len(value.strip()) if isinstance(value, str) else 0

    if content_type == "overview":
        # Overview needs >= 20 chars of body, a title of >= 2 chars and non-empty tags.
        return (
            _stripped_len(content) >= 20
            and _stripped_len(title) >= 2
            and _stripped_len(tags) >= 1
        )

    if content_type == "script":
        try:
            podcast_script = json.loads(content)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # non-string/bytes ``content`` such as None.
            return False

        # The script must be a JSON object holding a non-empty transcript list.
        if not isinstance(podcast_script, dict):
            return False
        transcripts = podcast_script.get("podcast_transcripts")
        if not isinstance(transcripts, list) or not transcripts:
            return False

        # Every transcript entry needs a speaker_id and a non-blank string dialog.
        for transcript in transcripts:
            if not isinstance(transcript, dict):
                return False
            if "speaker_id" not in transcript:
                return False
            dialog = transcript.get("dialog")
            if not isinstance(dialog, str) or not dialog.strip():
                return False
        return True

    return False
|
||||
|
||||
|
||||
def _generate_overview_content(api_key, base_url, model, overview_prompt, input_prompt, output_language: Optional[str] = None) -> Tuple[str, str, str]:
    """Generate the overview text via OpenAICli and split off its title and tags.

    The request is tried up to three times. An attempt is repeated when the
    call raises or when the parsed result fails
    ``_is_content_quality_acceptable(..., "overview")``.

    Args:
        api_key, base_url, model: Settings passed through to ``OpenAICli``.
        overview_prompt: System prompt; its ``{{outlang}}`` placeholder is
            replaced with the output language instruction.
        input_prompt: User message sent to the model.
        output_language: Desired output language, or ``None`` to ask the model
            to keep the input language.

    Returns:
        ``(overview_content, title, tags)``. The title is the first line of
        the response, the tags are the first non-blank line among lines 2-4,
        and the overview is everything after the tags line (or after the
        title when no tags line was found).

    Raises:
        RuntimeError: When all attempts fail.
    """
    print(f"\nGenerating overview with OpenAI CLI (Output Language: {output_language})...")

    max_retries = 3
    attempt = 0

    while attempt < max_retries:
        try:
            # Fill in the output-language placeholder of the system prompt.
            language_hint = output_language if output_language is not None else "Make sure the input language is set as the output language"
            system_prompt = overview_prompt.replace("{{outlang}}", language_hint)

            client = OpenAICli(api_key=api_key, base_url=base_url, model=model, system_message=system_prompt)
            stream = client.chat_completion(messages=[{"role": "user", "content": input_prompt}])

            # Concatenate the streamed chunks, skipping chunks without text.
            pieces = []
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    pieces.append(chunk.choices[0].delta.content)
            overview_content = "".join(pieces)

            # First line is the title; tags are the first non-blank line
            # among the next three lines.
            lines = overview_content.strip().split('\n')
            title = lines[0].strip()
            tags = ""
            body_start = 1  # without a tags line the body starts right after the title
            for i in range(1, min(len(lines), 4)):
                candidate = lines[i].strip()
                if candidate:
                    tags = candidate
                    body_start = i + 1
                    break
            overview_content = "\n".join(lines[body_start:]).strip()

            if not _is_content_quality_acceptable(overview_content, title, tags, "overview"):
                print(f"Generated overview content did not meet quality standards, attempt {attempt + 1}/{max_retries}")
                attempt += 1
                if attempt >= max_retries:
                    # Raised inside the try, so the handler below re-wraps it.
                    raise RuntimeError(f"Failed to generate acceptable overview content after {max_retries} attempts. Content may be too short or missing required elements.")
                print(f"Retrying overview generation...")
                continue

            print(f"Generated overview content meets quality standards on attempt {attempt + 1}")
            print(f"Extracted Title: {title}")
            print(f"Extracted Tags: {tags}")
            print("Generated Overview:")
            print(overview_content[:100])

            return overview_content, title, tags
        except Exception as e:
            attempt += 1
            if attempt >= max_retries:
                raise RuntimeError(f"Error generating overview after {max_retries} attempts: {e}")
            print(f"Attempt {attempt}/{max_retries} failed: {e}. Retrying...")
            # Delay grows linearly with the attempt number (1s, 2s, ...).
            time.sleep(1 * attempt)
|
||||
|
||||
def _generate_podcast_script(api_key, base_url, model, podscript_prompt, overview_content):
    """Generate the podcast script via OpenAICli and return it as a dict.

    The response text is scanned for the first JSON object that contains a
    ``"podcast_transcripts"`` key. Up to three attempts are made; an attempt
    is repeated when the call raises, when no such object is found, when the
    transcript list is empty, or when the script fails
    ``_is_content_quality_acceptable(..., "script")``.

    Args:
        api_key, base_url, model: Settings passed through to ``OpenAICli``.
        podscript_prompt: System prompt for script generation.
        overview_content: Overview text sent as the user message.

    Returns:
        The parsed podcast script dict.

    Raises:
        ValueError: When the final attempt fails with a ``json.JSONDecodeError``
            escaping the request/parsing code.
        RuntimeError: When the final attempt fails for any other reason. This
            includes the ``ValueError``s raised inside the ``try`` below, which
            the generic handler catches and re-wraps.
    """
    print("\nGenerating podcast script with OpenAI CLI...")

    max_retries = 3
    attempt = 0

    while attempt < max_retries:
        # Reset per attempt so the error messages can always reference it.
        podscript_json_str = ""
        try:
            client = OpenAICli(api_key=api_key, base_url=base_url, model=model, system_message=podscript_prompt)
            stream = client.chat_completion(messages=[{"role": "user", "content": overview_content}])
            podscript_json_str = "".join(
                chunk.choices[0].delta.content
                for chunk in stream
                if chunk.choices and chunk.choices[0].delta.content
            )

            # Walk the response looking for a JSON object with the expected key.
            podcast_script = None
            valid_json_str = ""
            decoder = json.JSONDecoder()
            cursor = 0
            while cursor < len(podscript_json_str):
                try:
                    candidate, consumed = decoder.raw_decode(podscript_json_str[cursor:])
                except json.JSONDecodeError:
                    # Nothing decodable here: jump to the next '{' or stop.
                    cursor = podscript_json_str.find('{', cursor + 1)
                    if cursor == -1:
                        break
                    continue
                if isinstance(candidate, dict) and "podcast_transcripts" in candidate:
                    podcast_script = candidate
                    valid_json_str = podscript_json_str[cursor : cursor + consumed]
                    break
                # Decoded something else; continue after it.
                cursor += consumed

            if podcast_script is None:
                print(f"Could not find a valid podcast script JSON object with 'podcast_transcripts' key in response, attempt {attempt + 1}/{max_retries}")
                attempt += 1
                if attempt >= max_retries:
                    raise ValueError(f"Error: Could not find a valid podcast script JSON object with 'podcast_transcripts' key in response. Raw response: {podscript_json_str}")
                print(f"Retrying podcast script generation...")
                continue

            print("\nGenerated Podcast Script Length:"+ str(len(podcast_script.get("podcast_transcripts") or [])))
            print(valid_json_str[:100] + "...")

            if not podcast_script.get("podcast_transcripts"):
                print(f"'podcast_transcripts' array is empty or not found in the generated script, attempt {attempt + 1}/{max_retries}")
                attempt += 1
                if attempt >= max_retries:
                    raise ValueError("Error: 'podcast_transcripts' array is empty or not found in the generated script. Nothing to convert to audio.")
                print(f"Retrying podcast script generation...")
                continue

            # Final gate: every transcript entry must be well-formed.
            if not _is_content_quality_acceptable(valid_json_str, "", "", "script"):
                print(f"Generated podcast script did not meet quality standards, attempt {attempt + 1}/{max_retries}")
                attempt += 1
                if attempt >= max_retries:
                    raise ValueError(f"Failed to generate acceptable podcast script after {max_retries} attempts. Script may be missing required elements.")
                print(f"Retrying podcast script generation...")
                continue

            print(f"Generated podcast script meets quality standards on attempt {attempt + 1}")
            return podcast_script

        except json.JSONDecodeError as e:
            attempt += 1
            if attempt >= max_retries:
                raise ValueError(f"Error decoding JSON from podcast script response: {e}. Raw response: {podscript_json_str}")
            print(f"JSON decode error on attempt {attempt}: {e}. Retrying...")
            # Delay grows linearly with the attempt number (1s, 2s, ...).
            time.sleep(1 * attempt)
        except Exception as e:
            attempt += 1
            if attempt >= max_retries:
                raise RuntimeError(f"Error generating podcast script after {max_retries} attempts: {e}")
            print(f"Attempt {attempt}/{max_retries} failed: {e}. Retrying...")
            # Delay grows linearly with the attempt number (1s, 2s, ...).
            time.sleep(1 * attempt)
|
||||
|
||||
def generate_audio_for_item(item, config_data, tts_adapter, max_retries: int = 3):
|
||||
"""Generate audio for a single podcast transcript item using the provided TTS adapter."""
|
||||
speaker_id = item.get("speaker_id")
|
||||
dialog = item.get("dialog")
|
||||
@@ -515,11 +640,14 @@ def generate_audio_for_item(item, config_data, tts_adapter: TTSAdapter, max_retr
|
||||
voice_code = None
|
||||
volume_adjustment = 0.0 # 默认值为 0.0
|
||||
speed_adjustment = 0.0 # 默认值为 0.0
|
||||
voice_tts_provider = None # 默认使用主要的 TTS 提供商
|
||||
|
||||
|
||||
if config_data and "podUsers" in config_data and 0 <= speaker_id < len(config_data["podUsers"]):
|
||||
pod_user_entry = config_data["podUsers"][speaker_id]
|
||||
voice_code = pod_user_entry.get("code")
|
||||
voice_tts_provider = pod_user_entry.get("owner") # 获取特定于该说话者的 TTS 提供商
|
||||
|
||||
# 从 voices 列表中获取对应的 volume_adjustment
|
||||
voice_map = {voice.get("code"): voice for voice in config_data.get("voices", []) if voice.get("code")}
|
||||
volume_adjustment = voice_map.get(voice_code, {}).get("volume_adjustment", 0.0)
|
||||
@@ -527,15 +655,17 @@ def generate_audio_for_item(item, config_data, tts_adapter: TTSAdapter, max_retr
|
||||
|
||||
if not voice_code:
|
||||
raise ValueError(f"No voice code found for speaker_id {speaker_id}. Cannot generate audio for this dialog.")
|
||||
|
||||
|
||||
# 如果 tts_adapter 是映射对象,则根据 voice_tts_provider 选择对应的适配器
|
||||
selected_adapter = tts_adapter[voice_tts_provider]
|
||||
# print(f"dialog-before: {dialog}")
|
||||
dialog = re.sub(r'[^\w\s\-,,.。??!!\u4e00-\u9fa5]', '', dialog)
|
||||
print(f"dialog: {dialog}")
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"Calling TTS API for speaker {speaker_id} ({voice_code}) (Attempt {attempt + 1}/{max_retries})...")
|
||||
temp_audio_file = tts_adapter.generate_audio(
|
||||
print(f"Calling TTS API for speaker {speaker_id} ({voice_code}) with adapter (Attempt {attempt + 1}/{max_retries})...")
|
||||
temp_audio_file = selected_adapter.generate_audio(
|
||||
text=dialog,
|
||||
voice_code=voice_code,
|
||||
output_dir=output_dir,
|
||||
@@ -554,7 +684,7 @@ def generate_audio_for_item(item, config_data, tts_adapter: TTSAdapter, max_retr
|
||||
except Exception as e: # Catch other unexpected errors
|
||||
raise RuntimeError(f"An unexpected error occurred for speaker {speaker_id} ({voice_code}) on attempt {attempt + 1}: {e}")
|
||||
|
||||
def _generate_all_audio_files(podcast_script, config_data, tts_adapter: TTSAdapter, threads):
|
||||
def _generate_all_audio_files(podcast_script, config_data, tts_adapter, threads):
|
||||
"""Orchestrates the generation of individual audio files."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
print("\nGenerating audio files...")
|
||||
@@ -634,15 +764,41 @@ def _create_ffmpeg_file_list(audio_files, expected_count: int):
|
||||
|
||||
from typing import cast # Add import for cast
|
||||
|
||||
def _initialize_tts_adapter(config_data: dict, tts_providers_config_content: Optional[str] = None) -> TTSAdapter:
|
||||
|
||||
def initialize_tts_provider_configs():
    """Reload the configuration of every provider in ``tts_provider_map``.

    The module-level ``tts_provider_configs_cache`` is replaced with a fresh
    dict, then filled with one entry per provider whose configuration file
    loads successfully. A provider whose file is missing, is not valid JSON,
    or fails to load for any other reason is reported with a warning and left
    out of the cache.
    """
    global tts_provider_configs_cache

    # Start from an empty cache so stale entries never survive a reload.
    tts_provider_configs_cache = {}

    for provider_name in tts_provider_map:
        path = tts_provider_map[provider_name]
        try:
            tts_provider_configs_cache[provider_name] = _load_json_config(path)
        except FileNotFoundError:
            print(f"Warning: Configuration file not found for {provider_name}: {path}")
        except json.JSONDecodeError as e:
            print(f"Warning: Invalid JSON in configuration file for {provider_name}: {path}, Error: {e}")
        except Exception as e:
            # Best effort: one broken provider must not block the others.
            print(f"Warning: Could not load configuration for {provider_name}: {path}, Error: {e}")
|
||||
|
||||
def _initialize_tts_adapter(config_data: dict, tts_providers_config_content: Optional[str] = None) -> dict:
|
||||
"""
|
||||
根据配置数据初始化并返回相应的 TTS 适配器映射对象。
|
||||
支持逗号分隔的 tts_provider 值,返回每个 provider 对应的适配器映射对象
|
||||
"""
|
||||
tts_provider = config_data.get("tts_provider")
|
||||
if not tts_provider:
|
||||
raise ValueError("TTS provider is not specified in the configuration.")
|
||||
|
||||
# 如果缓存为空,则初始化缓存
|
||||
if not tts_provider_configs_cache:
|
||||
initialize_tts_provider_configs()
|
||||
|
||||
tts_providers_config = {}
|
||||
try:
|
||||
if tts_providers_config_content:
|
||||
@@ -653,50 +809,64 @@ def _initialize_tts_adapter(config_data: dict, tts_providers_config_content: Opt
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load tts_providers.json: {e}")
|
||||
|
||||
# 获取当前 tts_provider 的额外参数
|
||||
current_tts_extra_params = tts_providers_config.get(tts_provider.split('-')[0], {}) # 例如 'doubao-tts' -> 'doubao'
|
||||
# 支持逗号分隔的 tts_provider
|
||||
providers = [provider.strip() for provider in tts_provider.split(',')]
|
||||
|
||||
adapters_map = {}
|
||||
for provider in providers:
|
||||
# 从缓存中获取当前 tts_provider 的额外参数
|
||||
current_tts_config_params = tts_provider_configs_cache.get(provider, {})
|
||||
current_tts_extra_params = tts_providers_config.get(provider.split('-')[0], {}) # 例如 'doubao-tts' -> 'doubao'
|
||||
|
||||
if tts_provider == "index-tts":
|
||||
api_url = config_data.get("apiUrl")
|
||||
if not api_url:
|
||||
raise ValueError("IndexTTS apiUrl is not configured.")
|
||||
return IndexTTSAdapter(api_url_template=cast(str, api_url), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif tts_provider == "edge-tts":
|
||||
api_url = config_data.get("apiUrl")
|
||||
if not api_url:
|
||||
raise ValueError("EdgeTTS apiUrl is not configured.")
|
||||
return EdgeTTSAdapter(api_url_template=cast(str, api_url), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
if provider == "index-tts":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
if not api_url:
|
||||
raise ValueError("IndexTTS apiUrl is not configured.")
|
||||
adapters_map[provider] = IndexTTSAdapter(api_url_template=cast(str, api_url), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif provider == "edge-tts":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
if not api_url:
|
||||
raise ValueError("EdgeTTS apiUrl is not configured.")
|
||||
adapters_map[provider] = EdgeTTSAdapter(api_url_template=cast(str, api_url), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
|
||||
elif tts_provider == "fish-audio":
|
||||
api_url = config_data.get("apiUrl")
|
||||
headers = config_data.get("headers")
|
||||
request_payload = config_data.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("FishAudio requires apiUrl, headers, and request_payload configuration.")
|
||||
return FishAudioAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif tts_provider == "minimax":
|
||||
api_url = config_data.get("apiUrl")
|
||||
headers = config_data.get("headers")
|
||||
request_payload = config_data.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("Minimax requires apiUrl, headers, and request_payload configuration.")
|
||||
return MinimaxAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif tts_provider == "doubao-tts":
|
||||
api_url = config_data.get("apiUrl")
|
||||
headers = config_data.get("headers")
|
||||
request_payload = config_data.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("DoubaoTTS requires apiUrl, headers, and request_payload configuration.")
|
||||
return DoubaoTTSAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif tts_provider == "gemini-tts":
|
||||
api_url = config_data.get("apiUrl")
|
||||
headers = config_data.get("headers")
|
||||
request_payload = config_data.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("GeminiTTS requires apiUrl, headers, and request_payload configuration.")
|
||||
return GeminiTTSAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
else:
|
||||
raise ValueError(f"Unsupported TTS provider: {tts_provider}")
|
||||
elif provider == "fish-audio":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
headers = config_data.get("headers") or current_tts_config_params.get("headers")
|
||||
request_payload = config_data.get("request_payload") or current_tts_config_params.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("FishAudio requires apiUrl, headers, and request_payload configuration.")
|
||||
adapters_map[provider] = FishAudioAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif provider == "minimax":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
headers = config_data.get("headers") or current_tts_config_params.get("headers")
|
||||
request_payload = config_data.get("request_payload") or current_tts_config_params.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("Minimax requires apiUrl, headers, and request_payload configuration.")
|
||||
adapters_map[provider] = MinimaxAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif provider == "doubao-tts":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
headers = config_data.get("headers") or current_tts_config_params.get("headers")
|
||||
request_payload = config_data.get("request_payload") or current_tts_config_params.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("DoubaoTTS requires apiUrl, headers, and request_payload configuration.")
|
||||
adapters_map[provider] = DoubaoTTSAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
elif provider == "gemini-tts":
|
||||
# 优先从 config_data 获取,如果没有则从缓存中获取
|
||||
api_url = config_data.get("apiUrl") or current_tts_config_params.get("apiUrl")
|
||||
headers = config_data.get("headers") or current_tts_config_params.get("headers")
|
||||
request_payload = config_data.get("request_payload") or current_tts_config_params.get("request_payload")
|
||||
if not all([api_url, headers, request_payload]):
|
||||
raise ValueError("GeminiTTS requires apiUrl, headers, and request_payload configuration.")
|
||||
adapters_map[provider] = GeminiTTSAdapter(api_url=cast(str, api_url), headers=cast(dict, headers), request_payload_template=cast(dict, request_payload), tts_extra_params=cast(dict, current_tts_extra_params))
|
||||
else:
|
||||
raise ValueError(f"Unsupported TTS provider: {provider}")
|
||||
|
||||
return adapters_map
|
||||
|
||||
def generate_podcast_audio():
|
||||
args = _parse_arguments()
|
||||
@@ -714,7 +884,7 @@ def generate_podcast_audio():
|
||||
overview_content, title, tags = _generate_overview_content(api_key, base_url, model, overview_prompt, input_prompt, args.output_language)
|
||||
podcast_script = _generate_podcast_script(api_key, base_url, model, podscript_prompt, overview_content)
|
||||
|
||||
tts_adapter = _initialize_tts_adapter(config_data) # 初始化 TTS 适配器
|
||||
tts_adapter = _initialize_tts_adapter(config_data) # 初始化 TTS 适配器,现在返回适配器映射
|
||||
|
||||
audio_files = _generate_all_audio_files(podcast_script, config_data, tts_adapter, args.threads)
|
||||
file_list_path_created = _create_ffmpeg_file_list(audio_files, len(podcast_script.get("podcast_transcripts", [])))
|
||||
@@ -744,8 +914,8 @@ def generate_podcast_audio_api(args, config_path: str, input_txt_content: str, t
|
||||
str: The path to the generated audio file.
|
||||
"""
|
||||
print("Starting podcast audio generation...")
|
||||
config_data = _load_configuration_path(config_path)
|
||||
podUsers = json.loads(podUsers_json_content)
|
||||
config_data = _load_configuration_path(config_path, podUsers)
|
||||
config_data["podUsers"] = podUsers
|
||||
|
||||
final_api_key, final_base_url, final_model = _prepare_openai_settings(args, config_data)
|
||||
@@ -761,7 +931,7 @@ def generate_podcast_audio_api(args, config_path: str, input_txt_content: str, t
|
||||
overview_content, title, tags = _generate_overview_content(final_api_key, final_base_url, final_model, overview_prompt, input_prompt, args.output_language)
|
||||
podcast_script = _generate_podcast_script(final_api_key, final_base_url, final_model, podscript_prompt, overview_content)
|
||||
|
||||
tts_adapter = _initialize_tts_adapter(config_data, tts_providers_config_content) # 初始化 TTS 适配器
|
||||
tts_adapter = _initialize_tts_adapter(config_data, tts_providers_config_content) # 初始化 TTS 适配器,现在返回适配器映射
|
||||
|
||||
audio_files = _generate_all_audio_files(podcast_script, config_data, tts_adapter, args.threads)
|
||||
file_list_path_created = _create_ffmpeg_file_list(audio_files, len(podcast_script.get("podcast_transcripts", [])))
|
||||
@@ -787,6 +957,9 @@ def generate_podcast_audio_api(args, config_path: str, input_txt_content: str, t
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize TTS provider configs cache at startup
|
||||
initialize_tts_provider_configs()
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
generate_podcast_audio()
|
||||
|
||||
Reference in New Issue
Block a user