feat: add IndexTTS voice check script and API service

- Add check_indextts_voices.py to verify the IndexTTS voice configuration
- Implement index-tts-api.py as a FastAPI wrapper service for IndexTTS
- Add example audio files and update the README
README.md (new file, +219 lines)
@@ -0,0 +1,219 @@
# 🎙️ Simple Podcast Generator (简易播客生成器)

> Turn your ideas into a lively multi-speaker podcast with a single command!

This is a powerful script that uses the **OpenAI API** to generate insightful podcast scripts and then brings the text to life as audio through a **TTS (Text-to-Speech)** API service. Just provide a topic and it handles the rest!

✨ The podcast script generation logic is heavily inspired by the [SurfSense](https://github.com/MODSetter/SurfSense) project; many thanks for their open-source contribution!

---

## ✨ Core Features

* **🤖 AI-driven scripts**: Automatically writes high-quality, in-depth podcast dialogue scripts using powerful OpenAI models.
* **👥 Multiple speakers**: Freely define several podcast roles (e.g. host, guest) and assign each one its own TTS voice.
* **🔌 Flexible TTS integration**: Connect your self-hosted or third-party TTS service with a single API URL setting.
* **🔊 Smart audio merging**: Automatically stitches each speaker's audio segments into one complete, seamless podcast file (`.wav`).
* **⌨️ Convenient command line**: Clear CLI options give you control over every step of the generation process.

---

## 🛠️ Installation

### 📝 Prerequisites

1. **Python 3.x**
    * Make sure Python 3 is installed on your system.

2. **FFmpeg**
    * Audio merging relies on FFmpeg. Download and install it from the [FFmpeg website](https://ffmpeg.org/download.html); a sketch of the merge step follows this list.
    * **Important**: After installation, make sure the `ffmpeg` command is on your system's PATH so the script can invoke it.
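
For reference, this is roughly what the merge step amounts to. The snippet below is a minimal sketch using FFmpeg's concat demuxer via `subprocess`; the segment file names are purely illustrative, and `podcast_generator.py` performs this step for you automatically.

```python
import subprocess
import tempfile

def merge_segments(segment_paths, output_path):
    """Concatenate per-speaker WAV segments into one file with FFmpeg (illustrative sketch)."""
    # The concat demuxer reads a small text file listing the inputs in order.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as list_file:
        for path in segment_paths:
            list_file.write(f"file '{path}'\n")
        list_path = list_file.name
    # Requires `ffmpeg` to be on PATH, which is why it is listed as a prerequisite.
    subprocess.run(
        ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_path, "-c", "copy", output_path],
        check=True,
    )

merge_segments(["seg_001.wav", "seg_002.wav"], "output/podcast_demo.wav")
```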

### 🐍 Python Dependencies

Open a terminal or command prompt and install the required Python libraries with pip:

```bash
pip install requests openai
```

---

## 🚀 Quick Start

### 1. Prepare the input files

Before running, make sure the following files are in place:

* `input.txt`: Enter the **podcast topic** or key points you want discussed.
* `prompt/prompt-overview.txt`: The system prompt that guides the AI in generating the podcast's **overall outline**.
* `prompt/prompt-podscript.txt`: The system prompt that guides the AI in generating the **detailed dialogue script**. It contains dynamic placeholders (such as `{{numSpeakers}}` and `{{turnPattern}}`) that the script fills in automatically.

### 2. Configure the TTS service and roles

* The `config/` directory holds your TTS configuration files (for example `edge-tts.json`). Each file defines the TTS service's API endpoint, the podcast roles (`podUsers`), and their corresponding voices (`voices`).

### 3. Run the script

Execute the following command from the project root:

```bash
python podcast_generator.py [optional arguments]
```

#### **Optional arguments**

* `--api-key <YOUR_OPENAI_API_KEY>`: Your OpenAI API key. If omitted, it is read from the config file or the `OPENAI_API_KEY` environment variable.
* `--base-url <YOUR_OPENAI_BASE_URL>`: A proxy/base URL for the OpenAI API. If omitted, it is read from the config file or the `OPENAI_BASE_URL` environment variable.
* `--model <OPENAI_MODEL_NAME>`: The OpenAI model to use (e.g. `gpt-4o`, `gpt-4-turbo`). Defaults to `gpt-3.5-turbo`.
* `--threads <NUMBER_OF_THREADS>`: Number of parallel threads for audio generation (default `1`); raise it to speed up processing.

#### **Example run**

```bash
# Generate a podcast with the gpt-4o model and 4 threads
python podcast_generator.py --api-key sk-xxxxxx --model gpt-4o --threads 4
```

---

### 4. Custom AI instructions (the `custom` block)

To give the AI more fine-grained instructions or extra context, you can embed a `custom` block in `input.txt`. Its content is injected as additional instructions into the core prompt used for podcast script generation (`prompt-podscript.txt`), and so influences how the AI writes.

**Usage**:
Anywhere in `input.txt`, define your custom content in the following format:

````
```custom-begin
Extra instructions or context you want to give the AI, for example:
- "Make sure the discussion includes an in-depth analysis of [a specific concept]."
- "Add some humor to the conversation, especially jokes about [a certain topic]."
- "Every speaker's lines must be short, with no sentence longer than two lines."
```custom-end
````

**Effect**:
All text inside the `custom` block (excluding the `custom-begin` and `custom-end` markers themselves) is extracted and appended to the rendered [`prompt/prompt-podscript.txt`](prompt/prompt-podscript.txt) template. These custom instructions therefore directly shape the AI's decisions and style when writing the dialogue script, giving you more precise control over the output.
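
As a rough illustration of this behavior, the sketch below extracts a `custom` block with a regular expression and appends it to the rendered prompt. It is a minimal sketch that assumes the markers appear exactly as shown above, and the wording used to join the two pieces is an assumption; the actual logic lives in `podcast_generator.py` and may differ in detail.

```python
import re

def apply_custom_block(input_text: str, rendered_prompt: str) -> str:
    """Append the content of a custom-begin/custom-end block (if any) to the prompt."""
    match = re.search(r"```custom-begin\s*(.*?)\s*```custom-end", input_text, re.DOTALL)
    if not match:
        return rendered_prompt  # No custom block: use the rendered prompt as-is
    extra_instructions = match.group(1).strip()
    return f"{rendered_prompt}\n\nAdditional instructions:\n{extra_instructions}"

with open("input.txt", encoding="utf-8") as f:
    topic_text = f.read()
with open("prompt/prompt-podscript.txt", encoding="utf-8") as f:
    prompt_template = f.read()

final_prompt = apply_custom_block(topic_text, prompt_template)
```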

**Example scenario**:
If you want the AI to put special emphasis on where a technology trend is heading while discussing a technical topic, add the following to `input.txt`:

````
```custom-begin
In the discussion, give a forward-looking analysis of the disruptive changes AI may bring over the next five years, and mention the potential impact of quantum computing on current encryption techniques.
```custom-end
````

---

## ⚙️ Configuration File Details (`config/*.json`)

The configuration file is the "brain" of the project: it tells the script how to work with the AI and the TTS service.

```json
{
  "podUsers": [
    {
      "code": "zh-CN-XiaoxiaoNeural",
      "role": "主持人"
    },
    {
      "code": "zh-CN-YunxiNeural",
      "role": "技术专家"
    }
  ],
  "voices": [
    {
      "name": "XiaoMin",
      "alias": "晓敏",
      "code": "yue-CN-XiaoMinNeural",
      "locale": "yue-CN",
      "gender": "Female",
      "usedname": "晓敏"
    },
    {
      "name": "YunSong",
      "alias": "云松",
      "code": "yue-CN-YunSongNeural",
      "locale": "yue-CN",
      "gender": "Male",
      "usedname": "云松"
    }
  ],
  "apiUrl": "http://localhost:5000/api/tts?text={{text}}&voiceCode={{voiceCode}}",
  "turnPattern": "random"
}
```

* `podUsers`: Defines the **roles** in the podcast. Each role's `code` must match a valid voice in the `voices` list.
* `voices`: Defines all available TTS **voices**.
* `apiUrl`: Your TTS service's API endpoint. `{{text}}` is replaced with the dialogue text and `{{voiceCode}}` with the role's voice code (see the request sketch below).
* `turnPattern`: Defines how speakers take **turns**, e.g. `random` or `sequential`.
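
For reference, a single TTS call essentially comes down to substituting the placeholders, issuing a GET request, and saving the returned audio. The substitution below mirrors what `check/check_indextts_voices.py` does with `apiUrl`; the function name and output file name are only for illustration.

```python
import requests

def synthesize_line(api_url_template: str, text: str, voice_code: str, out_path: str) -> None:
    """Fill the apiUrl placeholders, call the TTS service, and write the audio segment to disk."""
    # Note: in practice the text may need URL encoding before substitution.
    url = api_url_template.replace("{{text}}", text).replace("{{voiceCode}}", voice_code)
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(out_path, "wb") as f:
        f.write(response.content)

api_url = "http://localhost:5000/api/tts?text={{text}}&voiceCode={{voiceCode}}"
synthesize_line(api_url, "大家好，欢迎收听本期播客。", "zh-CN-XiaoxiaoNeural", "segment_001.wav")
```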

---

## 🔌 TTS (Text-to-Speech) Service Integration

The project is designed to be highly flexible: both locally deployed and cloud-based TTS services can be integrated through a simple configuration.

### 💻 Local TTS Services

You can deploy the following open-source projects as a local TTS service and hook them up through the `apiUrl` setting:

* **index-tts**: [https://github.com/index-tts/index-tts](https://github.com/index-tts/index-tts)
    * **How to use**: Run it together with `ext/index-tts-api.py`, which wraps `index-tts` in a simple API that this project can call (see the example request after this list).

* **edge-tts**: [https://github.com/zuoban/tts](https://github.com/zuoban/tts)
    * A general-purpose TTS service that can be integrated through a custom adapter.
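
To sanity-check the local service once `ext/index-tts-api.py` is running (it listens on port 7899 by default), you can request a single synthesis as shown below. The prompt file name `johnny-v.wav` comes from the usage comment in that script and is only an example; place your own reference audio under the service's `prompts/` directory.

```python
import requests

# Call the /synthesize/ endpoint exposed by ext/index-tts-api.py.
# server_audio_prompt_path is resolved relative to the service's prompts/ directory.
response = requests.get(
    "http://localhost:7899/synthesize/",
    params={
        "text": "你好，这是一条测试语音。",
        "server_audio_prompt_path": "johnny-v.wav",
    },
    timeout=120,
)
response.raise_for_status()
with open("index_tts_test.wav", "wb") as f:
    f.write(response.content)
```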

### 🌐 Online TTS Services (work in progress)

The project can also be configured to work with various online TTS services, as long as your `apiUrl` matches the provider's requirements (a hypothetical adapter sketch follows this list). Commonly targeted services include:

* **OpenAI TTS**
* **Azure TTS**
* **Google Cloud Text-to-Speech (Vertex AI)**
* **Minimax TTS**
* **Gemini TTS** (may require a custom API adapter)
* **Fish Audio TTS**
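
Most cloud providers do not expose the simple `GET ?text=...&voiceCode=...` contract this project expects, so one approach is a small adapter service that translates between the two. The sketch below is a hypothetical FastAPI skeleton of such an adapter; `call_cloud_tts` is a placeholder for the actual provider SDK call and is not part of this repository.

```python
from fastapi import FastAPI, Query
from fastapi.responses import Response

app = FastAPI(title="Cloud TTS adapter (sketch)")

def call_cloud_tts(text: str, voice_code: str) -> bytes:
    """Placeholder: call your provider's SDK here and return WAV bytes."""
    raise NotImplementedError

@app.get("/api/tts")
async def tts(text: str = Query(...), voiceCode: str = Query(...)):
    # Matches the apiUrl contract: {{text}} -> text, {{voiceCode}} -> voiceCode.
    audio = call_cloud_tts(text, voiceCode)
    return Response(content=audio, media_type="audio/wav")

# Point apiUrl at: http://localhost:5000/api/tts?text={{text}}&voiceCode={{voiceCode}}
# Run with, for example: uvicorn adapter:app --port 5000
```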

---

## 🎉 Output

All successfully generated podcast audio files are saved in the `output/` directory. The file name is `podcast_` followed by the generation timestamp, for example `podcast_1678886400.wav`.

---

## 🎧 Example Audio

The `example/` folder contains sample podcasts generated with different TTS services:

* **Edge TTS sample**: [edgeTTS_podcast_1754467217.aac](example/edgeTTS_podcast_1754467217.aac)
* **Index TTS sample**: [indexTTS_podcast_1754467749.aac](example/indexTTS_podcast_1754467749.aac)

These files show what the tool produces in practice.

---

## 📂 File Structure

```
.
├── config/                  # ⚙️ Configuration files
│   ├── edge-tts.json
│   └── index-tts.json
├── prompt/                  # 🧠 AI prompt templates
│   ├── prompt-overview.txt
│   └── prompt-podscript.txt
├── output/                  # 🎉 Generated audio
├── input.txt                # 🎙️ Podcast topic input
├── openai_cli.py            # OpenAI command-line tool
├── podcast_generator.py     # 🚀 Main script
└── README.md                # 📄 Project documentation
```
check/check_indextts_voices.py (new file, +56 lines)
@@ -0,0 +1,56 @@
import json
import time

import requests


def check_indextts_voices():
    config_file_path = "config/index-tts.json"
    test_text = "你好"  # Short test utterance sent to every voice

    try:
        with open(config_file_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except FileNotFoundError:
        print(f"错误: 配置文件未找到,请检查路径: {config_file_path}")
        return
    except json.JSONDecodeError:
        print(f"错误: 无法解析 JSON 文件: {config_file_path}")
        return

    voices = config_data.get('voices', [])
    api_url_template = config_data.get('apiUrl')

    if not voices:
        print("未在配置文件中找到任何声音(voices)。")
        return

    if not api_url_template:
        print("未在配置文件中找到 'apiUrl' 字段。")
        return

    print(f"开始验证 {len(voices)} 个 IndexTTS 语音...")
    for voice in voices:
        voice_code = voice.get('code')
        voice_name = voice.get('alias', voice.get('name', '未知'))  # Prefer alias, fall back to name

        if voice_code:
            # Fill in the placeholders of the URL template
            url = api_url_template.replace("{{text}}", test_text).replace("{{voiceCode}}", voice_code)

            print(f"正在测试语音: {voice_name} (Code: {voice_code}) - URL: {url}")
            try:
                response = requests.get(url, timeout=10)  # 10-second timeout
                if response.status_code == 200:
                    print(f"  ✅ {voice_name} (Code: {voice_code}): 可用")
                else:
                    print(f"  ❌ {voice_name} (Code: {voice_code}): 不可用, 状态码: {response.status_code}")
            except requests.exceptions.RequestException as e:
                print(f"  ❌ {voice_name} (Code: {voice_code}): 请求失败, 错误: {e}")
            time.sleep(0.1)  # Brief delay to avoid hammering the service
        else:
            print(f"跳过一个缺少 'code' 字段的语音条目: {voice}")

    print("IndexTTS 语音验证完成。")


if __name__ == "__main__":
    check_indextts_voices()
example/edgeTTS_podcast_1754467217.aac (new binary file, not shown)
example/indexTTS_podcast_1754467749.aac (new binary file, not shown)
ext/index-tts-api.py (new file, +284 lines)
@@ -0,0 +1,284 @@
# Usage:
#   python ./indextts/index-tts-api.py
#   http://localhost:7899/synthesize?text=Hello world, this is a test using FastAPI&verbose=true&max_text_tokens_per_sentence=80&server_audio_prompt_path=johnny-v.wav

import os
import re  # For sanitizing filenames/paths
import shutil
import tempfile
import time
from typing import Optional

import uvicorn
from fastapi import FastAPI, Query, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
# Removed File and UploadFile as we are not uploading anymore

# Assuming infer.py is in the same directory or in PYTHONPATH
from infer import IndexTTS

# --- Configuration ---
MODEL_CFG_PATH = "checkpoints/config.yaml"
MODEL_DIR = "checkpoints"
DEFAULT_IS_FP16 = True
DEFAULT_USE_CUDA_KERNEL = None
DEFAULT_DEVICE = None

# Default local audio prompt, can be overridden by a query parameter
DEFAULT_SERVER_AUDIO_PROMPT_PATH = "prompts/fdt-v.wav"  # <-- CHANGE THIS TO YOUR ACTUAL DEFAULT PROMPT
# Define a base directory from which user-specified prompts can be loaded.
# THIS IS A SECURITY MEASURE. Prompts outside this directory (and its subdirs) won't be allowed.
ALLOWED_PROMPT_BASE_DIR = os.path.abspath("prompts")  # Example: /app/prompts

# --- Global TTS instance ---
tts_model: Optional[IndexTTS] = None

app = FastAPI(title="IndexTTS FastAPI Service")


@app.on_event("startup")
async def startup_event():
    global tts_model
    print("Loading IndexTTS model...")
    start_load_time = time.time()
    try:
        tts_model = IndexTTS(
            cfg_path=MODEL_CFG_PATH,
            model_dir=MODEL_DIR,
            is_fp16=DEFAULT_IS_FP16,
            device=DEFAULT_DEVICE,
            use_cuda_kernel=DEFAULT_USE_CUDA_KERNEL,
        )
        # Verify the default prompt exists
        if not os.path.isfile(DEFAULT_SERVER_AUDIO_PROMPT_PATH):
            print(f"WARNING: Default server audio prompt file not found at: {DEFAULT_SERVER_AUDIO_PROMPT_PATH}")

        # Create the allowed prompts directory if it doesn't exist (optional, for convenience)
        if not os.path.isdir(ALLOWED_PROMPT_BASE_DIR):
            try:
                os.makedirs(ALLOWED_PROMPT_BASE_DIR, exist_ok=True)
                print(f"Created ALLOWED_PROMPT_BASE_DIR: {ALLOWED_PROMPT_BASE_DIR}")
            except Exception as e:
                print(f"WARNING: Could not create ALLOWED_PROMPT_BASE_DIR at {ALLOWED_PROMPT_BASE_DIR}: {e}")
        else:
            print(f"User-specified prompts will be loaded from: {ALLOWED_PROMPT_BASE_DIR}")

    except Exception as e:
        print(f"Error loading IndexTTS model: {e}")
        tts_model = None
    load_time = time.time() - start_load_time
    print(f"IndexTTS model loaded in {load_time:.2f} seconds.")
    if tts_model:
        print(f"Model ready on device: {tts_model.device}")
    else:
        print("Model FAILED to load.")


async def cleanup_temp_dir(temp_dir_path: str):
    try:
        if os.path.exists(temp_dir_path):
            shutil.rmtree(temp_dir_path)
            print(f"Successfully cleaned up temporary directory: {temp_dir_path}")
    except Exception as e:
        print(f"Error cleaning up temporary directory {temp_dir_path}: {e}")


def sanitize_path_component(component: str) -> str:
    """Basic sanitization for a path component."""
    # Remove leading/trailing whitespace and leading dots
    component = component.strip().lstrip('.')
    # Replace potentially harmful characters or sequences
    component = re.sub(r'\.\.[/\\]', '', component)  # Remove ".." sequences
    component = re.sub(r'[<>:"|?*]', '_', component)  # Replace illegal filename chars
    return component


def get_safe_prompt_path(base_dir: str, user_path: Optional[str]) -> str:
    """
    Constructs a safe path within base_dir from a user-provided path.
    Prevents directory traversal.
    """
    if not user_path:
        return ""  # Or raise an error if user_path is mandatory when called

    # Normalize user_path (e.g., handle mixed slashes, remove redundant ones)
    normalized_user_path = os.path.normpath(user_path)

    # Split the path into components and sanitize each one
    path_components = []
    head = normalized_user_path
    while True:
        head, tail = os.path.split(head)
        if tail:
            path_components.insert(0, sanitize_path_component(tail))
        elif head:  # Handle cases like "/path" or "path/" that leave an empty tail
            path_components.insert(0, sanitize_path_component(head))
            break
        else:  # Both head and tail are empty
            break
        if not head or head == os.sep or head == '.':  # Stop at root or current dir
            break

    if not path_components:
        raise ValueError("Invalid or empty prompt path provided after sanitization.")

    # Join the sanitized components relative to base_dir. Sanitization plus
    # os.path.normpath keeps user_path from being treated as an absolute path,
    # and the abspath check below rejects anything that still escapes base_dir.
    prospective_path = os.path.join(base_dir, *path_components)

    # Final check: ensure the resolved path is still within base_dir
    # (os.path.abspath resolves any '..' left in prospective_path).
    resolved_path = os.path.abspath(prospective_path)
    if not resolved_path.startswith(os.path.abspath(base_dir)):
        raise ValueError("Prompt path traversal attempt detected or path is outside allowed directory.")

    return resolved_path


@app.api_route("/synthesize/", methods=["POST", "GET"], response_class=FileResponse)
async def synthesize_speech(
    background_tasks: BackgroundTasks,
    text: str = Query(..., description="Text to synthesize."),
    # Parameter for specifying a server-side audio prompt path
    server_audio_prompt_path: Optional[str] = Query(None, description=f"Relative path to an audio prompt file on the server (within {ALLOWED_PROMPT_BASE_DIR}). If not provided, uses the default."),
    verbose: bool = Query(False, description="Enable verbose logging."),
    max_text_tokens_per_sentence: int = Query(100, description="Max text tokens per sentence."),
    sentences_bucket_max_size: int = Query(4, description="Sentences bucket max size."),
    do_sample: bool = Query(True, description="Enable sampling."),
    top_p: float = Query(0.8, description="Top P for sampling."),
    top_k: int = Query(30, description="Top K for sampling."),
    temperature: float = Query(1.0, description="Temperature for sampling."),
    length_penalty: float = Query(0.0, description="Length penalty."),
    num_beams: int = Query(3, description="Number of beams for beam search."),
    repetition_penalty: float = Query(10.0, description="Repetition penalty."),
    max_mel_tokens: int = Query(600, description="Max mel tokens to generate.")
):
    if tts_model is None:
        raise HTTPException(status_code=503, detail="TTS model is not loaded or failed to load.")

    temp_dir = tempfile.mkdtemp()
    actual_audio_prompt_to_use = ""  # Path on the server filesystem

    try:
        if server_audio_prompt_path:
            print(f"Client specified server_audio_prompt_path: {server_audio_prompt_path}")
            # Auto-complete the .wav extension if missing
            if not server_audio_prompt_path.lower().endswith(".wav"):
                print(f"server_audio_prompt_path '{server_audio_prompt_path}' does not end with .wav, appending it.")
                server_audio_prompt_path += ".wav"
            try:
                # Sanitize and resolve the user-provided path against the allowed base directory
                safe_path = get_safe_prompt_path(ALLOWED_PROMPT_BASE_DIR, server_audio_prompt_path)
                if os.path.isfile(safe_path):
                    actual_audio_prompt_to_use = safe_path
                    print(f"Using user-specified server prompt: {actual_audio_prompt_to_use}")
                else:
                    await cleanup_temp_dir(temp_dir)
                    raise HTTPException(status_code=404, detail=f"Specified server audio prompt not found or not a file: {safe_path} (original: {server_audio_prompt_path})")
            except ValueError as ve:  # From get_safe_prompt_path for security violations
                await cleanup_temp_dir(temp_dir)
                raise HTTPException(status_code=400, detail=f"Invalid server_audio_prompt_path: {str(ve)}")
        else:
            print(f"Using default server audio prompt: {DEFAULT_SERVER_AUDIO_PROMPT_PATH}")
            if not os.path.isfile(DEFAULT_SERVER_AUDIO_PROMPT_PATH):
                await cleanup_temp_dir(temp_dir)
                raise HTTPException(status_code=500, detail=f"Default server audio prompt file not found: {DEFAULT_SERVER_AUDIO_PROMPT_PATH}. Please configure the server.")
            actual_audio_prompt_to_use = DEFAULT_SERVER_AUDIO_PROMPT_PATH

        # Copy the chosen prompt (user-specified or default) into temp_dir.
        # This keeps the subsequent logic (model interaction, cleanup) consistent
        # and avoids modifying or locking the original prompt files.
        prompt_filename_for_temp = os.path.basename(actual_audio_prompt_to_use)
        temp_audio_prompt_path_in_job_dir = os.path.join(temp_dir, prompt_filename_for_temp)
        try:
            shutil.copy2(actual_audio_prompt_to_use, temp_audio_prompt_path_in_job_dir)
        except Exception as e:
            await cleanup_temp_dir(temp_dir)
            raise HTTPException(status_code=500, detail=f"Failed to copy audio prompt to temporary workspace: {str(e)}")

        output_filename = f"generated_speech_{int(time.time())}.wav"
        temp_output_path = os.path.join(temp_dir, output_filename)

        print(f"Synthesizing for text: '{text[:50]}...' with prompt (in temp): {temp_audio_prompt_path_in_job_dir}")
        print(f"Output will be saved to: {temp_output_path}")

        generation_kwargs = {
            "do_sample": do_sample,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "length_penalty": length_penalty,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "max_mel_tokens": max_mel_tokens,
        }

        start_infer_time = time.time()
        returned_output_path = tts_model.infer_fast(
            audio_prompt=temp_audio_prompt_path_in_job_dir,  # Use the copied prompt in the temp dir
            text=text,
            output_path=temp_output_path,
            verbose=verbose,
            max_text_tokens_per_sentence=max_text_tokens_per_sentence,
            sentences_bucket_max_size=sentences_bucket_max_size,
            **generation_kwargs
        )
        infer_time = time.time() - start_infer_time
        print(f"Inference completed in {infer_time:.2f} seconds. Expected output: {temp_output_path}, Returned path: {returned_output_path}")

        if not os.path.exists(temp_output_path):
            print(f"ERROR: Output file {temp_output_path} was NOT found after inference call.")
            background_tasks.add_task(cleanup_temp_dir, temp_dir)
            raise HTTPException(status_code=500, detail="TTS synthesis failed to produce an output file.")

        print(f"Output file {temp_output_path} confirmed to exist.")
        background_tasks.add_task(cleanup_temp_dir, temp_dir)

        return FileResponse(
            path=temp_output_path,
            media_type="audio/wav",
            filename="synthesized_audio.wav"
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"An unexpected error occurred during synthesis: {e}")
        import traceback
        traceback.print_exc()
        if 'temp_dir' in locals() and os.path.exists(temp_dir):
            background_tasks.add_task(cleanup_temp_dir, temp_dir)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")


@app.get("/")
async def read_root():
    return {"message": "IndexTTS FastAPI service is running. Use the /synthesize/ endpoint (GET or POST) to generate audio."}


if __name__ == "__main__":
    if not os.path.isfile(DEFAULT_SERVER_AUDIO_PROMPT_PATH):
        print(f"CRITICAL WARNING: Default server audio prompt at '{DEFAULT_SERVER_AUDIO_PROMPT_PATH}' not found!")
    else:
        print(f"Default server audio prompt: {os.path.abspath(DEFAULT_SERVER_AUDIO_PROMPT_PATH)}")

    if not os.path.isdir(ALLOWED_PROMPT_BASE_DIR):
        print(f"WARNING: ALLOWED_PROMPT_BASE_DIR '{ALLOWED_PROMPT_BASE_DIR}' does not exist. Consider creating it, or prompts specified by 'server_audio_prompt_path' may not be found.")
    else:
        print(f"User-specified prompts should be relative to: {os.path.abspath(ALLOWED_PROMPT_BASE_DIR)}")

    print(f"Attempting to use MODEL_DIR: {os.path.abspath(MODEL_DIR)}")
    print(f"Attempting to use MODEL_CFG_PATH: {os.path.abspath(MODEL_CFG_PATH)}")

    if not os.path.isdir(MODEL_DIR):
        print(f"ERROR: MODEL_DIR '{MODEL_DIR}' not found. Please check the path.")
    if not os.path.isfile(MODEL_CFG_PATH):
        print(f"ERROR: MODEL_CFG_PATH '{MODEL_CFG_PATH}' not found. Please check the path.")

    uvicorn.run(app, host="0.0.0.0", port=7899)