From 00c5dfadee854a2a9ac5f4265eea30175f5cceef Mon Sep 17 00:00:00 2001 From: glidea <740696441@qq.com> Date: Tue, 8 Jul 2025 18:13:26 +0800 Subject: [PATCH] add podcast --- .github/workflows/ci.yml | 14 +- .gitignore | 2 +- docs/config-zh.md | 58 +++++-- docs/config.md | 58 +++++-- docs/podcast.md | 104 +++++++++++ go.mod | 13 +- go.sum | 23 +++ main.go | 46 +++-- pkg/config/config.go | 31 +++- pkg/llm/gemini.go | 248 +++++++++++++++++++++++++++ pkg/llm/llm.go | 74 +++++--- pkg/llm/openai.go | 41 +++-- pkg/rewrite/rewrite.go | 267 ++++++++++++++++++++++++----- pkg/rewrite/rewrite_test.go | 176 +++++++++++++++++-- pkg/scrape/manager.go | 4 + pkg/scrape/scraper/scraper.go | 5 + pkg/scrape/scraper/scraper_test.go | 4 +- pkg/storage/object/object.go | 239 ++++++++++++++++++++++++++ pkg/util/binary/binary.go | 26 +++ pkg/util/wav/wav.go | 100 +++++++++++ pkg/util/wav/wav_test.go | 161 +++++++++++++++++ 21 files changed, 1549 insertions(+), 145 deletions(-) create mode 100644 docs/podcast.md create mode 100644 pkg/llm/gemini.go create mode 100644 pkg/storage/object/object.go create mode 100644 pkg/util/wav/wav.go create mode 100644 pkg/util/wav/wav_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a757f1..858fcff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,11 +2,9 @@ name: CI on: push: - branches: [ main ] + branches: [ main, dev ] pull_request: branches: [ main ] - release: - types: [ published ] jobs: test: @@ -27,7 +25,7 @@ jobs: build-and-push: runs-on: ubuntu-latest needs: test - if: github.event_name == 'release' + if: github.event_name == 'push' steps: - uses: actions/checkout@v4 - name: Set up Docker Buildx @@ -37,5 +35,9 @@ jobs: with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Build and push Docker images - run: make push \ No newline at end of file + - name: Build and push Docker image (main) + if: github.ref_name == 'main' + run: make push + - 
name: Build and push Docker image (dev) + if: github.ref_name == 'dev' + run: make dev-push \ No newline at end of file diff --git a/.gitignore b/.gitignore index 9558782..48d3e7a 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ local_docs/ .env .env.local __debug_bin -config.yaml +config.*yaml data/ *debug* .cursorrules diff --git a/docs/config-zh.md b/docs/config-zh.md index 4ede3d6..4b7387d 100644 --- a/docs/config-zh.md +++ b/docs/config-zh.md @@ -41,6 +41,7 @@ | `llms[].api_key` | `string` | LLM 的 API 密钥。 | | 是 | | `llms[].model` | `string` | LLM 的模型。例如 `gpt-4o-mini`。如果用于生成任务 (如总结),则不能为空。如果此 LLM 被使用,则不能与 `embedding_model` 同时为空。 | | 条件性必需 | | `llms[].embedding_model` | `string` | LLM 的 Embedding 模型。例如 `text-embedding-3-small`。如果用于 Embedding,则不能为空。如果此 LLM 被使用,则不能与 `model` 同时为空。**注意:** 初次使用后请勿直接修改,应添加新的 LLM 配置。 | | 条件性必需 | +| `llms[].tts_model` | `string` | LLM 的文本转语音 (TTS) 模型。 | | 否 | | `llms[].temperature` | `float32` | LLM 的温度 (0-2)。 | `0.0` | 否 | ### Jina AI 配置 (`jina`) @@ -80,10 +81,11 @@ ### 存储配置 (`storage`) -| 字段 | 类型 | 描述 | 默认值 | 是否必需 | -| :------------- | :------- | :-------------------------------------------- | :----------- | :------- | -| `storage.dir` | `string` | 所有存储的基础目录。应用运行后不可更改。 | `./data` | 否 | -| `storage.feed` | `object` | Feed 存储配置。详见下方的 **Feed 存储配置**。 | (见具体字段) | 否 | +| 字段 | 类型 | 描述 | 默认值 | 是否必需 | +| :--------------- | :------- | :-------------------------------------------------------------- | :----------- | :------- | +| `storage.dir` | `string` | 所有存储的基础目录。应用运行后不可更改。 | `./data` | 否 | +| `storage.feed` | `object` | Feed 存储配置。详见下方的 **Feed 存储配置**。 | (见具体字段) | 否 | +| `storage.object` | `object` | 对象存储配置,用于存储播客等文件。详见下方的 **对象存储配置**。 | (见具体字段) | 否 | ### Feed 存储配置 (`storage.feed`) @@ -95,6 +97,16 @@ | `storage.feed.retention` | `time.Duration` | Feed 的保留时长。 | `8d` | 否 | | `storage.feed.block_duration` | `time.Duration` | 每个基于时间的 Feed 存储块的保留时长 (类似于 Prometheus TSDB Block)。 | `25h` | 否 | +### 对象存储配置 (`storage.object`) + +| 字段 | 类型 | 描述 | 
默认值 | 是否必需 | +| :--------------------------------- | :------- | :----------------------------- | :----- | :-------------------- | +| `storage.object.endpoint` | `string` | 对象存储的端点。 | | 是 (如果使用播客功能) | +| `storage.object.access_key_id` | `string` | 对象存储的 Access Key ID。 | | 是 (如果使用播客功能) | +| `storage.object.secret_access_key` | `string` | 对象存储的 Secret Access Key。 | | 是 (如果使用播客功能) | +| `storage.object.bucket` | `string` | 对象存储的存储桶名称。 | | 是 (如果使用播客功能) | +| `storage.object.bucket_url` | `string` | 对象存储的桶访问 URL。 | | 否 | + ### 重写规则配置 (`storage.feed.rewrites[]`) 定义在存储前处理 Feed 的规则。规则按顺序应用。 @@ -109,12 +121,8 @@ | `...rewrites[].match_re` | `string` | 用于匹配 (转换后) 文本的正则表达式。 | `.*` (匹配所有) | 否 (使用 `match` 或 `match_re`) | | `...rewrites[].action` | `string` | 匹配时执行的操作: `create_or_update_label` (使用匹配/转换后的文本添加/更新标签), `drop_feed` (完全丢弃该 Feed)。 | `create_or_update_label` | 否 | | `...rewrites[].label` | `string` | 要创建或更新的 Feed 标签名称。 | | 是 (如果 `action` 是 `create_or_update_label`) | - -### 重写规则转换配置 (`storage.feed.rewrites[].transform`) - -| 字段 | 类型 | 描述 | 默认值 | 是否必需 | -| :--------------------- | :------- | :------------------------------------------------------------------- | :----- | :------- | -| `...transform.to_text` | `object` | 使用 LLM 将源文本转换为文本。详见下方的 **重写规则转换为文本配置**。 | `nil` | 否 | +| `...transform.to_text` | `object` | 使用 LLM 将源文本转换为文本。详见下方的 **重写规则转换为文本配置**。 | `nil` | 否 | +| `...transform.to_podcast` | `object` | 将源文本转换为播客。详见下方的 **重写规则转换为播客配置**。 | `nil` | 否 | ### 重写规则转换为文本配置 (`storage.feed.rewrites[].transform.to_text`) @@ -126,6 +134,25 @@ | `...to_text.llm` | `string` | **仅当 `type` 为 `prompt` 时有效。** 用于转换的 LLM 名称 (来自 `llms` 部分)。如果未指定,将使用在 `llms` 部分中标记为 `default: true` 的 LLM。 | `llms` 部分中的默认 LLM | 否 | | `...to_text.prompt` | `string` | **仅当 `type` 为 `prompt` 时有效。** 用于转换的 Prompt。源文本将被注入。可以使用 Go 模板语法引用内置 Prompt: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`, `{{ .summary_html_snippet_for_small_model }}`。 | 
| 是 (如果 `type` 是 `prompt`) | +### 重写规则转换为播客配置 (`storage.feed.rewrites[].transform.to_podcast`) + +此配置定义了如何将 `source_label` 的文本转换为播客。 + +| 字段 | 类型 | 描述 | 默认值 | 是否必需 | +| :------------------------------------------- | :--------- | :-------------------------------------------------------------------------------------------------------- | :---------------------- | :------- | +| `...to_podcast.llm` | `string` | 用于生成播客稿件的 LLM 名称 (来自 `llms` 部分)。 | `llms` 部分中的默认 LLM | 否 | +| `...to_podcast.transcript_additional_prompt` | `string` | 附加到播客稿件生成 Prompt 的额外指令。 | | 否 | +| `...to_podcast.tts_llm` | `string` | 用于文本转语音 (TTS) 的 LLM 名称 (来自 `llms` 部分)。**注意:目前仅支持 `provider` 为 `gemini` 的 LLM**。 | `llms` 部分中的默认 LLM | 否 | +| `...to_podcast.speakers` | `对象列表` | 播客的演讲者列表。详见下方的 **演讲者配置**。 | `[]` | 是 | + +#### 演讲者配置 (`...to_podcast.speakers[]`) + +| 字段 | 类型 | 描述 | 默认值 | 是否必需 | +| :-------------------- | :------- | :------------------------ | :----- | :------- | +| `...speakers[].name` | `string` | 演讲者的名字。 | | 是 | +| `...speakers[].role` | `string` | 演讲者的角色描述 (人设)。 | | 否 | +| `...speakers[].voice` | `string` | 演讲者的声音。 | | 是 | + ### 调度配置 (`scheduls`) 定义查询和监控 Feed 的规则。 @@ -173,10 +200,11 @@ 定义*谁*接收通知。 -| 字段 | 类型 | 描述 | 默认值 | 是否必需 | -| :------------------------- | :------- | :------------------------------- | :----- | :------------------ | -| `notify.receivers[].name` | `string` | 接收者的唯一名称。在路由中使用。 | | 是 | -| `notify.receivers[].email` | `string` | 接收者的电子邮件地址。 | | 是 (如果使用 Email) | +| 字段 | 类型 | 描述 | 默认值 | 是否必需 | +| :--------------------------- | :------- | :------------------------------------------------------- | :----- | :-------------------- | +| `notify.receivers[].name` | `string` | 接收者的唯一名称。在路由中使用。 | | 是 | +| `notify.receivers[].email` | `string` | 接收者的电子邮件地址。 | | 是 (如果使用 Email) | +| `notify.receivers[].webhook` | `object` | 接收者的 Webhook 配置。例如: `webhook: { "url": "xxx" }` | | 是 (如果使用 Webhook) | ### 通知渠道配置 (`notify.channels`) @@ -194,4 +222,4 @@ | `...email.from` | `string` | 发件人 Email 地址。 | | 
是 | | `...email.password` | `string` | 发件人 Email 的应用专用密码。(对于 Gmail, 参见 [Google 应用密码](https://support.google.com/mail/answer/185833))。 | | 是 | | `...email.feed_markdown_template` | `string` | 用于在 Email 正文中格式化每个 Feed 的 Markdown 模板。默认渲染 Feed 内容。不能与 `feed_html_snippet_template` 同时设置。可用的模板变量取决于 Feed 标签。 | `{{ .content }}` | 否 | -| `...email.feed_html_snippet_template` | `string` | 用于格式化每个 Feed 的 HTML 片段模板。不能与 `feed_markdown_template` 同时设置。可用的模板变量取决于 Feed 标签。 | | 否 | \ No newline at end of file +| `...email.feed_html_snippet_template` | `string` | 用于格式化每个 Feed 的 HTML 片段模板。不能与 `feed_markdown_template` 同时设置。可用的模板变量取决于 Feed 标签。 | | 否 | diff --git a/docs/config.md b/docs/config.md index 3a88ef7..78fd800 100644 --- a/docs/config.md +++ b/docs/config.md @@ -41,6 +41,7 @@ This section defines the list of available Large Language Models. At least one L | `llms[].api_key` | `string` | API key for the LLM. | | Yes | | `llms[].model` | `string` | Model of the LLM. E.g., `gpt-4o-mini`. Cannot be empty if used for generation tasks (e.g., summarization). If this LLM is used, cannot be empty along with `embedding_model`. | | Conditionally Required | | `llms[].embedding_model` | `string` | Embedding model of the LLM. E.g., `text-embedding-3-small`. Cannot be empty if used for embedding. If this LLM is used, cannot be empty along with `model`. **Note:** Do not modify directly after initial use; add a new LLM configuration instead. | | Conditionally Required | +| `llms[].tts_model` | `string` | The Text-to-Speech (TTS) model of the LLM. | | No | | `llms[].temperature` | `float32` | Temperature of the LLM (0-2). | `0.0` | No | ### Jina AI Configuration (`jina`) @@ -80,10 +81,11 @@ Describes each source to be scraped. 
### Storage Configuration (`storage`) -| Field | Type | Description | Default Value | Required | -| :------------- | :------- | :------------------------------------------------------------------------------ | :-------------------- | :------- | -| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No | -| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No | +| Field | Type | Description | Default Value | Required | +| :--------------- | :------- | :-------------------------------------------------------------------------------------------------------- | :-------------------- | :------- | +| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No | +| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No | +| `storage.object` | `object` | Object storage configuration for storing files like podcasts. See **Object Storage Configuration** below. | (See specific fields) | No | ### Feed Storage Configuration (`storage.feed`) @@ -95,6 +97,16 @@ Describes each source to be scraped. | `storage.feed.retention` | `time.Duration` | Retention duration for feeds. | `8d` | No | | `storage.feed.block_duration` | `time.Duration` | Retention duration for each time-based feed storage block (similar to Prometheus TSDB Block). | `25h` | No | +### Object Storage Configuration (`storage.object`) + +| Field | Type | Description | Default Value | Required | +| :--------------------------------- | :------- | :------------------------------------------- | :------------ | :----------------------------- | +| `storage.object.endpoint` | `string` | The endpoint of the object storage. | | Yes (if using podcast feature) | +| `storage.object.access_key_id` | `string` | The access key id of the object storage. 
| | Yes (if using podcast feature) | +| `storage.object.secret_access_key` | `string` | The secret access key of the object storage. | | Yes (if using podcast feature) | +| `storage.object.bucket` | `string` | The bucket of the object storage. | | Yes (if using podcast feature) | +| `storage.object.bucket_url` | `string` | The URL of the object storage bucket. | | No | + ### Rewrite Rule Configuration (`storage.feed.rewrites[]`) Defines rules to process feeds before storage. Rules are applied sequentially. @@ -109,12 +121,8 @@ Defines rules to process feeds before storage. Rules are applied sequentially. | `...rewrites[].match_re` | `string` | Regular expression to match against the (transformed) text. | `.*` (matches all) | No (use `match` or `match_re`) | | `...rewrites[].action` | `string` | Action to perform on match: `create_or_update_label` (adds/updates a label with the matched/transformed text), `drop_feed` (discards the feed entirely). | `create_or_update_label` | No | | `...rewrites[].label` | `string` | Name of the feed label to create or update. | | Yes (if `action` is `create_or_update_label`) | - -### Rewrite Rule Transform Configuration (`storage.feed.rewrites[].transform`) - -| Field | Type | Description | Default Value | Required | -| :--------------------- | :------- | :--------------------------------------------------------------------------------------------- | :------------ | :------- | -| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No | +| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No | +| `...transform.to_podcast` | `object` | Transforms source text to a podcast. See **Rewrite Rule To Podcast Configuration** below.
| `nil` | No | ### Rewrite Rule To Text Configuration (`storage.feed.rewrites[].transform.to_text`) @@ -124,7 +132,26 @@ This configuration defines how to transform the text from `source_label`. | :------------------ | :------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------- | :-------------------------- | | `...to_text.type` | `string` | Type of transformation. Options: | `prompt` | No | | `...to_text.llm` | `string` | **Only valid if `type` is `prompt`.** Name of the LLM used for transformation (from `llms` section). If not specified, the LLM marked as `default: true` in the `llms` section will be used. | Default LLM in `llms` section | No | -| `...to_text.prompt` | `string` | **Only valid if `type` is `prompt`.** Prompt used for transformation. The source text will be injected. You can use Go template syntax to reference built-in prompts: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`. | | Yes (if `type` is `prompt`) | +| `...to_text.prompt` | `string` | **Only valid if `type` is `prompt`.** Prompt used for transformation. The source text will be injected. 
You can use Go template syntax to reference built-in prompts: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`, `{{ .summary_html_snippet_for_small_model }}`. | | Yes (if `type` is `prompt`) | + +### Rewrite Rule To Podcast Configuration (`storage.feed.rewrites[].transform.to_podcast`) + +This configuration defines how to transform the text from `source_label` into a podcast. + +| Field | Type | Description | Default Value | Required | +| :------------------------------------------- | :---------------- | :--------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------- | :------- | +| `...to_podcast.llm` | `string` | The name of the LLM (from the `llms` section) to use for generating the podcast script. | Default LLM in `llms` section | No | +| `...to_podcast.transcript_additional_prompt` | `string` | Additional instructions to append to the prompt for generating the podcast script. | | No | +| `...to_podcast.tts_llm` | `string` | The name of the LLM (from the `llms` section) to use for Text-to-Speech (TTS). **Note: Currently only supports LLMs with `provider: gemini`**. | Default LLM in `llms` section | No | +| `...to_podcast.speakers` | `list of objects` | A list of speakers for the podcast. See **Speaker Configuration** below. | `[]` | Yes | + +#### Speaker Configuration (`...to_podcast.speakers[]`) + +| Field | Type | Description | Default Value | Required | +| :-------------------- | :------- | :----------------------------------- | :------------ | :------- | +| `...speakers[].name` | `string` | The name of the speaker. | | Yes | +| `...speakers[].role` | `string` | The role description of the speaker. | | No | +| `...speakers[].voice` | `string` | The voice of the speaker. 
| | Yes | ### Scheduling Configuration (`scheduls`) @@ -173,10 +200,11 @@ This structure can be nested using `sub_routes`. Feeds will first try to match s Defines *who* receives notifications. -| Field | Type | Description | Default Value | Required | -| :------------------------- | :------- | :------------------------------------------- | :------------ | :------------------- | -| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes | -| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) | +| Field | Type | Description | Default Value | Required | +| :--------------------------- | :------- | :----------------------------------------------------------------------- | :------------ | :--------------------- | +| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes | +| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) | +| `notify.receivers[].webhook` | `object` | Webhook configuration for the receiver. E.g. `webhook: { "url": "xxx" }` | | Yes (if using Webhook) | ### Notification Channel Configuration (`notify.channels`) diff --git a/docs/podcast.md b/docs/podcast.md new file mode 100644 index 0000000..84013e1 --- /dev/null +++ b/docs/podcast.md @@ -0,0 +1,104 @@ +# 使用 Zenfeed 将文章转换为播客 + +Zenfeed 的播客功能可以将任何文章源自动转换为一场引人入胜的多人对话式播客。该功能利用大语言模型(LLM)生成对话脚本和文本转语音(TTS),并将最终的音频文件托管在您自己的对象存储中。 + +## 工作原理 + +1. **提取内容**: Zenfeed 首先通过重写规则提取文章的全文内容。 +2. **生成脚本**: 使用一个指定的 LLM(如 GPT-4o-mini)将文章内容改编成一个由多位虚拟主播对话的脚本。您可以定义每个主播的角色(人设)来控制对话风格。 +3. **语音合成**: 调用另一个支持 TTS 的 LLM(目前仅支持 Google Gemini)将脚本中的每一句对话转换为语音。 +4. **音频合并**: 将所有语音片段合成为一个完整的 WAV 音频文件。 +5. **上传存储**: 将生成的播客文件上传到您配置的 S3 兼容对象存储中。 +6. **保存链接**: 最后,将播客文件的公开访问 URL 保存为一个新的 Feed 标签,方便您在通知、API 或其他地方使用。 + +## 配置步骤 + +要启用播客功能,您需要完成以下三项配置:LLM、对象存储和重写规则。 + +### 1. 
配置 LLM + +您需要至少配置两个 LLM:一个用于生成对话脚本,另一个用于文本转语音(TTS)。 + +- **脚本生成 LLM**: 可以是任何性能较好的聊天模型,例如 OpenAI 的 `gpt-4o-mini` 或 Google 的 `gemini-1.5-pro`。 +- **TTS LLM**: 用于将文本转换为语音。**注意:目前此功能仅支持 `provider` 为 `gemini` 的 LLM。** + +**示例 `config.yaml`:** + +```yaml +llms: + # 用于生成播客脚本的 LLM + - name: openai-chat + provider: openai + api_key: "sk-..." + model: gpt-4o-mini + default: true + + # 用于文本转语音 (TTS) 的 LLM + - name: gemini-tts + provider: gemini + api_key: "..." # 你的 Google AI Studio API Key + tts_model: "gemini-2.5-flash-preview-tts" # Gemini 的 TTS 模型 +``` + +### 2. 配置对象存储 + +生成的播客音频文件需要一个地方存放。Zenfeed 支持任何 S3 兼容的对象存储服务。这里我们以 [Cloudflare R2](https://www.cloudflare.com/zh-cn/products/r2/) 为例。 + +首先,您需要在 Cloudflare R2 中创建一个存储桶(Bucket)。然后获取以下信息: + +- `endpoint`: 您的 R2 API 端点。通常格式为 `<account_id>.r2.cloudflarestorage.com`。您可以在 R2 存储桶的主页找到它。 +- `access_key_id` 和 `secret_access_key`: R2 API 令牌。您可以在 "R2" -> "管理 R2 API 令牌" 页面创建。 +- `bucket`: 您创建的存储桶的名称。 +- `bucket_url`: 存储桶的公开访问 URL。要获取此 URL,您需要将存储桶连接到一个自定义域,或者使用 R2 提供的 `r2.dev` 公开访问地址。 + +**示例 `config.yaml`:** + +```yaml +storage: + object: + endpoint: "<account_id>.r2.cloudflarestorage.com" + access_key_id: "..." + secret_access_key: "..." + bucket: "zenfeed-podcasts" + bucket_url: "https://pub-xxxxxxxx.r2.dev" +``` + +### 3.
配置重写规则 + +最后一步是创建一个重写规则,告诉 Zenfeed 如何将文章转换为播客。这个规则定义了使用哪个标签作为源文本、由谁来对话、使用什么声音等。 + +**关键配置项:** + +- `source_label`: 包含文章全文的标签。 +- `label`: 用于存储最终播客 URL 的新标签名称。 +- `transform.to_podcast`: 播客转换的核心配置。 + - `llm`: 用于生成脚本的 LLM 名称(来自 `llms` 配置)。 + - `tts_llm`: 用于 TTS 的 LLM 名称(来自 `llms` 配置)。 + - `speakers`: 定义播客的演讲者。 + - `name`: 演讲者的名字。 + - `role`: 演讲者的角色和人设,将影响脚本内容。 + - `voice`: 演讲者的声音。请参考 [Google Cloud TTS 文档](https://cloud.google.com/text-to-speech/docs/voices) 获取可用的声音名称(例如 `en-US-Standard-C`,`en-US-News-N`)。 + +**示例 `config.yaml`:** + +```yaml +storage: + feed: + rewrites: + - source_label: "content" + label: "podcast_url" + transform: + to_podcast: + llm: "openai-chat" + tts_llm: "gemini-tts" + transcript_additional_prompt: "使用中文回复" + speakers: + - name: "主持人小雅" + role: "一位经验丰富、声音甜美、风格活泼的科技播客主持人。擅长联系实际生活场景。" + voice: "zh-CN-Standard-A" # 女声 + - name: "技术评论员老王" + role: "一位对技术有深入见解、观点犀利的评论员,说话直接,偶尔有些愤世嫉俗。" + voice: "zh-CN-Standard-B" # 男声 +``` + +配置完成后,Zenfeed 将在每次抓取到新文章时,自动执行上述流程。可以在通知模版中使用 podcast_url label,或 Web 中直接收听(Web 固定读取 podcast_url label,若使用别的名称则无法读取) \ No newline at end of file diff --git a/go.mod b/go.mod index 140f8d3..520bf66 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/edsrzf/mmap-go v1.2.0 github.com/gorilla/feeds v1.2.0 github.com/mark3labs/mcp-go v0.17.0 + github.com/minio/minio-go/v7 v7.0.94 github.com/mmcdole/gofeed v1.3.0 github.com/nutsdb/nutsdb v1.0.4 github.com/onsi/gomega v1.36.1 @@ -16,6 +17,7 @@ require ( github.com/prometheus/client_golang v1.21.1 github.com/sashabaranov/go-openai v1.40.1 github.com/stretchr/testify v1.10.0 + github.com/temoto/robotstxt v1.1.2 github.com/veqryn/slog-dedup v0.5.0 github.com/yuin/goldmark v1.7.8 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df @@ -32,25 +34,34 @@ require ( github.com/bwmarrin/snowflake v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + 
github.com/go-ini/ini v1.67.0 // indirect + github.com/goccy/go-json v0.10.5 // indirect github.com/gofrs/flock v0.8.1 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect + github.com/minio/crc64nvme v1.0.1 // indirect + github.com/minio/md5-simd v1.1.2 // indirect github.com/mmcdole/goxpp v1.1.1-0.20240225020742-a0c311522b23 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/rs/xid v1.6.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/temoto/robotstxt v1.1.2 github.com/tidwall/btree v1.6.0 // indirect + github.com/tinylib/msgp v1.3.0 // indirect github.com/xujiajun/mmap-go v1.0.1 // indirect github.com/xujiajun/utils v0.0.0-20220904132955-5f7c5b914235 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect diff --git a/go.sum b/go.sum index 6b5ab8a..837c5bd 100644 --- a/go.sum +++ b/go.sum @@ -21,12 +21,18 @@ github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUw github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84= github.com/edsrzf/mmap-go v1.2.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -42,6 +48,9 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod 
h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -53,6 +62,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930= github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= +github.com/minio/crc64nvme v1.0.1 h1:DHQPrYPdqK7jQG/Ls5CTBZWeex/2FMS3G5XGkycuFrY= +github.com/minio/crc64nvme v1.0.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.94 h1:1ZoksIKPyaSt64AVOyaQvhDOgVC3MfZsWM6mZXRUGtM= +github.com/minio/minio-go/v7 v7.0.94/go.mod h1:71t2CqDt3ThzESgZUlU1rBN54mksGGlkLcFgguDnnAc= github.com/mmcdole/gofeed v1.3.0 h1:5yn+HeqlcvjMeAI4gu6T+crm7d0anY85+M+v6fIFNG4= github.com/mmcdole/gofeed v1.3.0/go.mod h1:9TGv2LcJhdXePDzxiuMnukhV2/zb6VtnZt1mS+SjkLE= github.com/mmcdole/goxpp v1.1.1-0.20240225020742-a0c311522b23 h1:Zr92CAlFhy2gL+V1F+EyIuzbQNbSgP4xhTODZtrXUtk= @@ -70,6 +85,8 @@ github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI= github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw= github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog= +github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY= +github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -87,6 +104,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6O github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/sashabaranov/go-openai v1.40.1 h1:bJ08Iwct5mHBVkuvG6FEcb9MDTfsXdTYPGjYLRdeTEU= github.com/sashabaranov/go-openai v1.40.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= @@ -106,6 +125,8 @@ github.com/temoto/robotstxt v1.1.2 h1:W2pOjSJ6SWvldyEuiFXNxz3xZ8aiWX5LbfDiOFd7Fx github.com/temoto/robotstxt v1.1.2/go.mod h1:+1AmkuG3IYkh1kv0d2qEB9Le88ehNO0zwOr3ujewlOo= github.com/tidwall/btree v1.6.0 h1:LDZfKfQIBHGHWSwckhXI0RPSXzlo+KYdjK7FWSqOzzg= github.com/tidwall/btree v1.6.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= +github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= github.com/veqryn/slog-dedup v0.5.0 h1:2pc4va3q8p7Tor1SjVvi1ZbVK/oKNPgsqG15XFEt0iM= github.com/veqryn/slog-dedup v0.5.0/go.mod h1:/iQU008M3qFa5RovtfiHiODxJFvxZLjWRG/qf/zKFHw= github.com/xujiajun/mmap-go v1.0.1 h1:7Se7ss1fLPPRW+ePgqGpCkfGIZzJV6JPq9Wq9iv/WHc= @@ -123,6 +144,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.19.0/go.mod 
h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/main.go b/main.go index b6d5216..fc56ede 100644 --- a/main.go +++ b/main.go @@ -47,6 +47,7 @@ import ( "github.com/glidea/zenfeed/pkg/storage/feed/block/index/primary" "github.com/glidea/zenfeed/pkg/storage/feed/block/index/vector" "github.com/glidea/zenfeed/pkg/storage/kv" + "github.com/glidea/zenfeed/pkg/storage/object" "github.com/glidea/zenfeed/pkg/telemetry/log" telemetryserver "github.com/glidea/zenfeed/pkg/telemetry/server" timeutil "github.com/glidea/zenfeed/pkg/util/time" @@ -122,18 +123,19 @@ type App struct { conf *config.App telemetry telemetryserver.Server - kvStorage kv.Storage - llmFactory llm.Factory - rewriter rewrite.Rewriter - feedStorage feed.Storage - api api.API - http http.Server - mcp mcp.Server - rss rss.Server - scraperMgr scrape.Manager - scheduler schedule.Scheduler - notifier notify.Notifier - notifyChan chan *rule.Result + kvStorage kv.Storage + llmFactory llm.Factory + rewriter rewrite.Rewriter + feedStorage feed.Storage + objectStorage object.Storage + api api.API + http http.Server + mcp mcp.Server + rss rss.Server + scraperMgr scrape.Manager + scheduler schedule.Scheduler + notifier notify.Notifier + notifyChan chan *rule.Result } // newApp creates a new application instance. 
@@ -164,6 +166,9 @@ func (a *App) setup() error { if err := a.setupKVStorage(); err != nil { return errors.Wrap(err, "setup kv storage") } + if err := a.setupObjectStorage(); err != nil { + return errors.Wrap(err, "setup object storage") + } if err := a.setupLLMFactory(); err != nil { return errors.Wrap(err, "setup llm factory") } @@ -251,7 +256,8 @@ func (a *App) setupLLMFactory() (err error) { // setupRewriter initializes the Rewriter factory. func (a *App) setupRewriter() (err error) { a.rewriter, err = rewrite.NewFactory().New(component.Global, a.conf, rewrite.Dependencies{ - LLMFactory: a.llmFactory, + LLMFactory: a.llmFactory, + ObjectStorage: a.objectStorage, }) if err != nil { return err @@ -282,6 +288,18 @@ func (a *App) setupFeedStorage() (err error) { return nil } +// setupObjectStorage initializes the Object storage. +func (a *App) setupObjectStorage() (err error) { + a.objectStorage, err = object.NewFactory().New(component.Global, a.conf, object.Dependencies{}) + if err != nil { + return err + } + + a.configMgr.Subscribe(a.objectStorage) + + return nil +} + // setupTelemetryServer initializes the Telemetry server. func (a *App) setupTelemetryServer() (err error) { a.telemetry, err = telemetryserver.NewFactory().New(component.Global, a.conf, telemetryserver.Dependencies{}) @@ -419,7 +437,7 @@ func (a *App) run(ctx context.Context) error { log.Info(ctx, "starting application components...") if err := component.Run(ctx, component.Group{a.configMgr}, - component.Group{a.llmFactory, a.telemetry}, + component.Group{a.llmFactory, a.objectStorage, a.telemetry}, component.Group{a.rewriter}, component.Group{a.feedStorage}, component.Group{a.kvStorage}, diff --git a/pkg/config/config.go b/pkg/config/config.go index 0848600..668d910 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -90,6 +90,7 @@ type LLM struct { APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty" desc:"The API key of the LLM. 
It is required when api.llm is set."` Model string `yaml:"model,omitempty" json:"model,omitempty" desc:"The model of the LLM. e.g. gpt-4o-mini. Can not be empty with embedding_model at same time when api.llm is set."` EmbeddingModel string `yaml:"embedding_model,omitempty" json:"embedding_model,omitempty" desc:"The embedding model of the LLM. e.g. text-embedding-3-small. Can not be empty with model at same time when api.llm is set. NOTE: Once used, do not modify it directly, instead, add a new LLM configuration."` + TTSModel string `yaml:"tts_model,omitempty" json:"tts_model,omitempty" desc:"The TTS model of the LLM."` Temperature float32 `yaml:"temperature,omitempty" json:"temperature,omitempty" desc:"The temperature (0-2) of the LLM. Default: 0.0"` } @@ -101,8 +102,9 @@ type Scrape struct { } type Storage struct { - Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."` - Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."` + Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."` + Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."` + Object ObjectStorage `yaml:"object,omitempty" json:"object,omitempty" desc:"The object storage config."` } type FeedStorage struct { @@ -113,6 +115,14 @@ type FeedStorage struct { BlockDuration timeutil.Duration `yaml:"block_duration,omitempty" json:"block_duration,omitempty" desc:"How long to keep the feed storage block. Block is time-based, like Prometheus TSDB Block. 
Default: 25h"` } +type ObjectStorage struct { + Endpoint string `yaml:"endpoint,omitempty" json:"endpoint,omitempty" desc:"The endpoint of the object storage."` + AccessKeyID string `yaml:"access_key_id,omitempty" json:"access_key_id,omitempty" desc:"The access key id of the object storage."` + SecretAccessKey string `yaml:"secret_access_key,omitempty" json:"secret_access_key,omitempty" desc:"The secret access key of the object storage."` + Bucket string `yaml:"bucket,omitempty" json:"bucket,omitempty" desc:"The bucket of the object storage."` + BucketURL string `yaml:"bucket_url,omitempty" json:"bucket_url,omitempty" desc:"The public URL of the object storage bucket."` +} + type ScrapeSource struct { Interval timeutil.Duration `yaml:"interval,omitempty" json:"interval,omitempty" desc:"How often to scrape this source. Default: global interval"` Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the source. It is required."` @@ -137,7 +147,8 @@ type RewriteRule struct { } type RewriteRuleTransform struct { - ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."` + ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."` + ToPodcast *RewriteRuleTransformToPodcast `yaml:"to_podcast,omitempty" json:"to_podcast,omitempty" desc:"The transform config to transform the source text to podcast."` } type RewriteRuleTransformToText struct { @@ -146,6 +157,20 @@ type RewriteRuleTransformToText struct { Prompt string `yaml:"prompt,omitempty" json:"prompt,omitempty" desc:"The prompt to transform the source text. The source text will be injected into the prompt above. And you can use go template syntax to refer some built-in prompts, like {{ .summary }}. 
Available built-in prompts: category, tags, score, comment_confucius, summary, summary_html_snippet."` } +type RewriteRuleTransformToPodcast struct { + LLM string `yaml:"llm,omitempty" json:"llm,omitempty" desc:"The LLM name to use. Default is the default LLM in llms section."` + EstimateMaximumDuration timeutil.Duration `yaml:"estimate_maximum_duration,omitempty" json:"estimate_maximum_duration,omitempty" desc:"The estimated maximum duration of the podcast. It will affect the length of the generated transcript. e.g. 5m. Default is 5m."` + TranscriptAdditionalPrompt string `yaml:"transcript_additional_prompt,omitempty" json:"transcript_additional_prompt,omitempty" desc:"The additional prompt to add to the transcript. It is optional."` + TTSLLM string `yaml:"tts_llm,omitempty" json:"tts_llm,omitempty" desc:"The LLM name to use for TTS. Only supports gemini now. Default is the default LLM in llms section."` + Speakers []RewriteRuleTransformToPodcastSpeaker `yaml:"speakers,omitempty" json:"speakers,omitempty" desc:"The speakers to use. It is required, at least one speaker is needed."` +} + +type RewriteRuleTransformToPodcastSpeaker struct { + Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the speaker. It is required."` + Role string `yaml:"role,omitempty" json:"role,omitempty" desc:"The role description of the speaker. You can think of it as a character setting."` + Voice string `yaml:"voice,omitempty" json:"voice,omitempty" desc:"The voice of the speaker. It is required."` +} + type SchedulsRule struct { Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the rule. It is required."` Query string `yaml:"query,omitempty" json:"query,omitempty" desc:"The semantic query to get the feeds. 
NOTE it is optional"` diff --git a/pkg/llm/gemini.go b/pkg/llm/gemini.go new file mode 100644 index 0000000..a219860 --- /dev/null +++ b/pkg/llm/gemini.go @@ -0,0 +1,248 @@ +// Copyright (C) 2025 wangyusong +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package llm + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "path/filepath" + + "github.com/pkg/errors" + oai "github.com/sashabaranov/go-openai" + + "github.com/glidea/zenfeed/pkg/component" + "github.com/glidea/zenfeed/pkg/telemetry" + telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model" + "github.com/glidea/zenfeed/pkg/util/wav" +) + +type gemini struct { + *component.Base[Config, struct{}] + text + hc *http.Client + + embeddingSpliter embeddingSpliter +} + +func newGemini(c *Config) LLM { + config := oai.DefaultConfig(c.APIKey) + config.BaseURL = filepath.Join(c.Endpoint, "openai") // OpenAI compatible endpoint. 
+ client := oai.NewClientWithConfig(config) + embeddingSpliter := newEmbeddingSpliter(1536, 64) + + base := component.New(&component.BaseConfig[Config, struct{}]{ + Name: "LLM/gemini", + Instance: c.Name, + Config: c, + }) + + return &gemini{ + Base: base, + text: &openaiText{ + Base: base, + client: client, + }, + hc: &http.Client{}, + embeddingSpliter: embeddingSpliter, + } +} + +func (g *gemini) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) { + ctx = telemetry.StartWith(ctx, append(g.TelemetryLabels(), telemetrymodel.KeyOperation, "WAV")...) + defer func() { telemetry.End(ctx, err) }() + + if g.Config().TTSModel == "" { + return nil, errors.New("tts model is not set") + } + + reqPayload, err := buildWAVRequestPayload(text, speakers) + if err != nil { + return nil, errors.Wrap(err, "build wav request payload") + } + + pcmData, err := g.doWAVRequest(ctx, reqPayload) + if err != nil { + return nil, errors.Wrap(err, "do wav request") + } + + return streamWAV(pcmData), nil +} + +func (g *gemini) doWAVRequest(ctx context.Context, reqPayload *geminiRequest) ([]byte, error) { + config := g.Config() + body, err := json.Marshal(reqPayload) + if err != nil { + return nil, errors.Wrap(err, "marshal tts request") + } + + url := config.Endpoint + "/models/" + config.TTSModel + ":generateContent" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrap(err, "new tts request") + } + req.Header.Set("x-goog-api-key", config.APIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := g.hc.Do(req) + if err != nil { + return nil, errors.Wrap(err, "do tts request") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + errMsg, _ := io.ReadAll(resp.Body) + + return nil, errors.Errorf("tts request failed with status %d: %s", resp.StatusCode, string(errMsg)) + } + + var ttsResp geminiResponse + if err := 
json.NewDecoder(resp.Body).Decode(&ttsResp); err != nil { + return nil, errors.Wrap(err, "decode tts response") + } + if len(ttsResp.Candidates) == 0 || len(ttsResp.Candidates[0].Content.Parts) == 0 || ttsResp.Candidates[0].Content.Parts[0].InlineData == nil { + return nil, errors.New("no audio data in tts response") + } + + audioDataB64 := ttsResp.Candidates[0].Content.Parts[0].InlineData.Data + pcmData, err := base64.StdEncoding.DecodeString(audioDataB64) + if err != nil { + return nil, errors.Wrap(err, "decode base64") + } + + return pcmData, nil +} + +func buildWAVRequestPayload(text string, speakers []Speaker) (*geminiRequest, error) { + reqPayload := geminiRequest{ + Contents: []*geminiRequestContent{{Parts: []*geminiRequestPart{{Text: text}}}}, + Config: &geminiRequestConfig{ + ResponseModalities: []string{"AUDIO"}, + SpeechConfig: &geminiRequestSpeechConfig{}, + }, + } + + switch len(speakers) { + case 0: + return nil, errors.New("no speakers") + case 1: + reqPayload.Config.SpeechConfig.VoiceConfig = &geminiRequestVoiceConfig{ + PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: speakers[0].Voice}, + } + default: + multiSpeakerConfig := &geminiRequestMultiSpeakerVoiceConfig{} + for _, s := range speakers { + multiSpeakerConfig.SpeakerVoiceConfigs = append(multiSpeakerConfig.SpeakerVoiceConfigs, &geminiRequestSpeakerVoiceConfig{ + Speaker: s.Name, + VoiceConfig: &geminiRequestVoiceConfig{ + PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: s.Voice}, + }, + }) + } + reqPayload.Config.SpeechConfig.MultiSpeakerVoiceConfig = multiSpeakerConfig + } + + return &reqPayload, nil +} + +func streamWAV(pcmData []byte) io.ReadCloser { + pipeReader, pipeWriter := io.Pipe() + go func() { + defer func() { _ = pipeWriter.Close() }() + if err := wav.WriteHeader(pipeWriter, geminiWavHeader, uint32(len(pcmData))); err != nil { + pipeWriter.CloseWithError(errors.Wrap(err, "write wav header")) + + return + } + if _, err := io.Copy(pipeWriter, 
bytes.NewReader(pcmData)); err != nil { + pipeWriter.CloseWithError(errors.Wrap(err, "write pcm data")) + + return + } + }() + + return pipeReader +} + +var geminiWavHeader = &wav.Header{ + SampleRate: 24000, + BitDepth: 16, + NumChannels: 1, +} + +type geminiRequest struct { + Contents []*geminiRequestContent `json:"contents"` + Config *geminiRequestConfig `json:"generationConfig"` +} + +type geminiRequestContent struct { + Parts []*geminiRequestPart `json:"parts"` +} + +type geminiRequestPart struct { + Text string `json:"text"` +} + +type geminiRequestConfig struct { + ResponseModalities []string `json:"responseModalities"` + SpeechConfig *geminiRequestSpeechConfig `json:"speechConfig"` +} + +type geminiRequestSpeechConfig struct { + VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"` + MultiSpeakerVoiceConfig *geminiRequestMultiSpeakerVoiceConfig `json:"multiSpeakerVoiceConfig,omitempty"` +} + +type geminiRequestVoiceConfig struct { + PrebuiltVoiceConfig *geminiRequestPrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"` +} + +type geminiRequestPrebuiltVoiceConfig struct { + VoiceName string `json:"voiceName,omitempty"` +} + +type geminiRequestMultiSpeakerVoiceConfig struct { + SpeakerVoiceConfigs []*geminiRequestSpeakerVoiceConfig `json:"speakerVoiceConfigs,omitempty"` +} + +type geminiRequestSpeakerVoiceConfig struct { + Speaker string `json:"speaker,omitempty"` + VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"` +} + +type geminiResponse struct { + Candidates []*geminiResponseCandidate `json:"candidates"` +} + +type geminiResponseCandidate struct { + Content *geminiResponseContent `json:"content"` +} + +type geminiResponseContent struct { + Parts []*geminiResponsePart `json:"parts"` +} + +type geminiResponsePart struct { + InlineData *geminiResponseInlineData `json:"inlineData"` +} + +type geminiResponseInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` // Base64 encoded. 
+} diff --git a/pkg/llm/llm.go b/pkg/llm/llm.go index 4b8028f..e248bba 100644 --- a/pkg/llm/llm.go +++ b/pkg/llm/llm.go @@ -18,6 +18,7 @@ package llm import ( "bytes" "context" + "io" "reflect" "strconv" "sync" @@ -42,19 +43,33 @@ import ( // --- Interface code block --- type LLM interface { component.Component + text + audio +} + +type text interface { String(ctx context.Context, messages []string) (string, error) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) Embedding(ctx context.Context, text string) ([]float32, error) } +type audio interface { + WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error) +} + +type Speaker struct { + Name string + Voice string +} + type Config struct { - Name string - Default bool - Provider ProviderType - Endpoint string - APIKey string - Model, EmbeddingModel string - Temperature float32 + Name string + Default bool + Provider ProviderType + Endpoint string + APIKey string + Model, EmbeddingModel, TTSModel string + Temperature float32 } type ProviderType string @@ -72,7 +87,7 @@ var defaultEndpoints = map[ProviderType]string{ ProviderTypeOpenAI: "https://api.openai.com/v1", ProviderTypeOpenRouter: "https://openrouter.ai/api/v1", ProviderTypeDeepSeek: "https://api.deepseek.com/v1", - ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta/openai", + ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta", ProviderTypeVolc: "https://ark.cn-beijing.volces.com/api/v3", ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1", } @@ -97,8 +112,8 @@ func (c *Config) Validate() error { //nolint:cyclop if c.APIKey == "" { return errors.New("api key is required") } - if c.Model == "" && c.EmbeddingModel == "" { - return errors.New("model or embedding model is required") + if c.Model == "" && c.EmbeddingModel == "" && c.TTSModel == "" { + return errors.New("model or embedding model or tts model is required") } if c.Temperature < 0 || c.Temperature > 2 { 
return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature) @@ -182,6 +197,7 @@ func (c *FactoryConfig) From(app *config.App) { APIKey: llm.APIKey, Model: llm.Model, EmbeddingModel: llm.EmbeddingModel, + TTSModel: llm.TTSModel, Temperature: llm.Temperature, }) } @@ -207,12 +223,9 @@ func NewFactory( ) (Factory, error) { if len(mockOn) > 0 { mf := &mockFactory{} - getCall := mf.On("Get", mock.Anything) - getCall.Run(func(args mock.Arguments) { - m := &mockLLM{} - component.MockOptions(mockOn).Apply(&m.Mock) - getCall.Return(m, nil) - }) + m := &mockLLM{} + component.MockOptions(mockOn).Apply(&m.Mock) + mf.On("Get", mock.Anything).Return(m) mf.On("Reload", mock.Anything).Return(nil) return mf, nil @@ -307,11 +320,6 @@ func (f *factory) Get(name string) LLM { continue } - if f.llms[name] == nil { - llm := f.new(&llmC) - f.llms[name] = llm - } - return f.llms[name] } @@ -320,8 +328,12 @@ func (f *factory) Get(name string) LLM { func (f *factory) new(c *Config) LLM { switch c.Provider { - case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll + case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll return newCached(newOpenAI(c), f.Dependencies().KVStorage) + + case ProviderTypeGemini: + return newCached(newGemini(c), f.Dependencies().KVStorage) + default: return newCached(newOpenAI(c), f.Dependencies().KVStorage) } @@ -333,14 +345,17 @@ func (f *factory) initLLMs() { llms = make(map[string]LLM, len(config.LLMs)) defaultLLM LLM ) + for _, llmC := range config.LLMs { llm := f.new(&llmC) + llms[llmC.Name] = llm if llmC.Name == config.defaultLLM { defaultLLM = llm } } + f.llms = llms f.defaultLLM = defaultLLM } @@ -482,12 +497,27 @@ func (m *mockLLM) String(ctx context.Context, messages []string) (string, error) func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels 
model.Labels) ([][]float32, error) { args := m.Called(ctx, labels) + if args.Error(1) != nil { + return nil, args.Error(1) + } return args.Get(0).([][]float32), args.Error(1) } func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) { args := m.Called(ctx, text) + if args.Error(1) != nil { + return nil, args.Error(1) + } return args.Get(0).([]float32), args.Error(1) } + +func (m *mockLLM) WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error) { + args := m.Called(ctx, text, speakers) + if args.Error(1) != nil { + return nil, args.Error(1) + } + + return args.Get(0).(io.ReadCloser), args.Error(1) +} diff --git a/pkg/llm/openai.go b/pkg/llm/openai.go index 3b23804..4ae3537 100644 --- a/pkg/llm/openai.go +++ b/pkg/llm/openai.go @@ -18,6 +18,7 @@ package llm import ( "context" "encoding/json" + "io" "github.com/pkg/errors" oai "github.com/sashabaranov/go-openai" @@ -31,9 +32,7 @@ import ( type openai struct { *component.Base[Config, struct{}] - - client *oai.Client - embeddingSpliter embeddingSpliter + text } func newOpenAI(c *Config) LLM { @@ -42,18 +41,34 @@ func newOpenAI(c *Config) LLM { client := oai.NewClientWithConfig(config) embeddingSpliter := newEmbeddingSpliter(1536, 64) + base := component.New(&component.BaseConfig[Config, struct{}]{ + Name: "LLM/openai", + Instance: c.Name, + Config: c, + }) + return &openai{ - Base: component.New(&component.BaseConfig[Config, struct{}]{ - Name: "LLM/openai", - Instance: c.Name, - Config: c, - }), - client: client, - embeddingSpliter: embeddingSpliter, + Base: base, + text: &openaiText{ + Base: base, + client: client, + embeddingSpliter: embeddingSpliter, + }, } } -func (o *openai) String(ctx context.Context, messages []string) (value string, err error) { +func (o *openai) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) { + return nil, errors.New("not supported") +} + +type openaiText struct { + *component.Base[Config, struct{}] + + 
client *oai.Client + embeddingSpliter embeddingSpliter +} + +func (o *openaiText) String(ctx context.Context, messages []string) (value string, err error) { ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...) defer func() { telemetry.End(ctx, err) }() @@ -91,7 +106,7 @@ func (o *openai) String(ctx context.Context, messages []string) (value string, e return resp.Choices[0].Message.Content, nil } -func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) { +func (o *openaiText) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) { ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...) defer func() { telemetry.End(ctx, err) }() @@ -117,7 +132,7 @@ func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (valu return vecs, nil } -func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) { +func (o *openaiText) Embedding(ctx context.Context, s string) (value []float32, err error) { ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...) 
defer func() { telemetry.End(ctx, err) }() diff --git a/pkg/rewrite/rewrite.go b/pkg/rewrite/rewrite.go index 770a3c2..0312c90 100644 --- a/pkg/rewrite/rewrite.go +++ b/pkg/rewrite/rewrite.go @@ -17,9 +17,13 @@ package rewrite import ( "context" + "fmt" "html/template" + "io" "regexp" + "strconv" "strings" + "time" "unicode/utf8" "github.com/pkg/errors" @@ -30,10 +34,12 @@ import ( "github.com/glidea/zenfeed/pkg/llm" "github.com/glidea/zenfeed/pkg/llm/prompt" "github.com/glidea/zenfeed/pkg/model" + "github.com/glidea/zenfeed/pkg/storage/object" "github.com/glidea/zenfeed/pkg/telemetry" telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model" "github.com/glidea/zenfeed/pkg/util/buffer" "github.com/glidea/zenfeed/pkg/util/crawl" + hashutil "github.com/glidea/zenfeed/pkg/util/hash" ) // --- Interface code block --- @@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) { } type Dependencies struct { - LLMFactory llm.Factory + LLMFactory llm.Factory // NOTE: String() with cache. + ObjectStorage object.Storage } type Rule struct { @@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen } // Transform. 
- if r.Transform != nil { - if r.Transform.ToText == nil { - return errors.New("to_text is required when transform is set") + if r.Transform != nil { //nolint:nestif + if r.Transform.ToText != nil && r.Transform.ToPodcast != nil { + return errors.New("to_text and to_podcast can not be set at same time") + } + if r.Transform.ToText == nil && r.Transform.ToPodcast == nil { + return errors.New("either to_text or to_podcast must be set when transform is set") } - switch r.Transform.ToText.Type { - case ToTextTypePrompt: - if r.Transform.ToText.Prompt == "" { - return errors.New("to text prompt is required for prompt type") + if r.Transform.ToText != nil { + switch r.Transform.ToText.Type { + case ToTextTypePrompt: + if r.Transform.ToText.Prompt == "" { + return errors.New("to text prompt is required for prompt type") + } + tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt) + if err != nil { + return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt) + } + + buf := buffer.Get() + defer buffer.Put(buf) + if err := tmpl.Execute(buf, prompt.Builtin); err != nil { + return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt) + } + r.Transform.ToText.promptRendered = buf.String() + + case ToTextTypeCrawl, ToTextTypeCrawlByJina: + // No specific validation for crawl type here, as the source text itself is the URL. 
+ default: + return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type) } - tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt) - if err != nil { - return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt) + } + + if r.Transform.ToPodcast != nil { + if len(r.Transform.ToPodcast.Speakers) == 0 { + return errors.New("at least one speaker is required for to_podcast") } - buf := buffer.Get() - defer buffer.Put(buf) - if err := tmpl.Execute(buf, prompt.Builtin); err != nil { - return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt) - } - r.Transform.ToText.promptRendered = buf.String() + r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers)) + var speakerDescs []string + var speakerNames []string + for i, s := range r.Transform.ToPodcast.Speakers { + if s.Name == "" { + return errors.New("speaker name is required") + } + if s.Voice == "" { + return errors.New("speaker voice is required") + } + r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice} - case ToTextTypeCrawl, ToTextTypeCrawlByJina: - // No specific validation for crawl type here, as the source text itself is the URL. 
- default: - return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type) + desc := s.Name + if s.Role != "" { + desc += " (" + s.Role + ")" + } + speakerDescs = append(speakerDescs, desc) + speakerNames = append(speakerNames, s.Name) + } + + speakersDesc := "- " + strings.Join(speakerDescs, "\n- ") + exampleSpeaker1 := speakerNames[0] + exampleSpeaker2 := exampleSpeaker1 + if len(speakerNames) > 1 { + exampleSpeaker2 = speakerNames[1] + } + + promptSegments := []string{ + "Please convert the following article into a podcast dialogue script.", + "The speakers are:\n" + speakersDesc, + } + + if r.Transform.ToPodcast.EstimateMaximumDuration > 0 { + wordsPerMinute := 200 + totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes()) + estimatedWords := totalMinutes * wordsPerMinute + promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes)) + } + + if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" { + promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt) + } + + promptSegments = append(promptSegments, + "The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.", + "Do NOT include any other text, explanations, or formatting before or after the script.", + "Do NOT use background music in the script.", + "Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').", + fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2), + "Now, convert the article.", + ) + + r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n") + 
r.Transform.ToPodcast.speakersDesc = speakersDesc } } @@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen r.matchRE = re // Action. - switch r.Action { - case "": + if r.Action == "" { r.Action = ActionCreateOrUpdateLabel + } + switch r.Action { case ActionCreateOrUpdateLabel: if r.Label == "" { return errors.New("label is required for create or update label action") @@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) { r.If = c.If r.SourceLabel = c.SourceLabel r.SkipTooShortThreshold = c.SkipTooShortThreshold - if c.Transform != nil { + if c.Transform != nil { //nolint:nestif t := &Transform{} if c.Transform.ToText != nil { toText := &ToText{ @@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) { } t.ToText = toText } + if c.Transform.ToPodcast != nil { + toPodcast := &ToPodcast{ + LLM: c.Transform.ToPodcast.LLM, + EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration), + TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt, + TTSLLM: c.Transform.ToPodcast.TTSLLM, + } + if toPodcast.EstimateMaximumDuration == 0 { + toPodcast.EstimateMaximumDuration = 3 * time.Minute + } + for _, s := range c.Transform.ToPodcast.Speakers { + toPodcast.Speakers = append(toPodcast.Speakers, Speaker{ + Name: s.Name, + Role: s.Role, + Voice: s.Voice, + }) + } + t.ToPodcast = toPodcast + } r.Transform = t } r.Match = c.Match @@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) { } type Transform struct { - ToText *ToText + ToText *ToText + ToPodcast *ToPodcast } type ToText struct { @@ -220,6 +314,24 @@ type ToText struct { promptRendered string } +type ToPodcast struct { + LLM string + EstimateMaximumDuration time.Duration + TranscriptAdditionalPrompt string + TTSLLM string + Speakers []Speaker + + transcriptPrompt string + speakersDesc string + speakers []llm.Speaker +} + +type Speaker struct { + Name string + Role string + Voice string +} + type ToTextType string 
const ( @@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m } // Transform text if configured. - text := sourceText - if rule.Transform != nil && rule.Transform.ToText != nil { - transformed, err := r.transformText(ctx, rule.Transform, sourceText) - if err != nil { - return nil, errors.Wrap(err, "transform text") - } - text = transformed + text, err := r.transform(ctx, rule.Transform, sourceText) + if err != nil { + return nil, errors.Wrap(err, "transform") } // Check if text matches the rule. @@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m return labels, nil } +func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) { + if transform == nil { + return sourceText, nil + } + + if transform.ToText != nil { + return r.transformText(ctx, transform.ToText, sourceText) + } + + if transform.ToPodcast != nil { + return r.transformPodcast(ctx, transform.ToPodcast, sourceText) + } + + return sourceText, nil +} + // transformText transforms text using configured LLM or by crawling a URL. -func (r *rewriter) transformText(ctx context.Context, transform *Transform, text string) (string, error) { - switch transform.ToText.Type { +func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) { + switch toText.Type { case ToTextTypeCrawl: return r.transformTextCrawl(ctx, r.crawler, text) case ToTextTypeCrawlByJina: return r.transformTextCrawl(ctx, r.jinaCrawler, text) case ToTextTypePrompt: - return r.transformTextPrompt(ctx, transform, text) + return r.transformTextPrompt(ctx, toText, text) default: - return r.transformTextPrompt(ctx, transform, text) + return r.transformTextPrompt(ctx, toText, text) } } @@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler } // transformTextPrompt transforms text using configured LLM. 
-func (r *rewriter) transformTextPrompt(ctx context.Context, transform *Transform, text string) (string, error) { +func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) { // Get LLM instance. - llm := r.Dependencies().LLMFactory.Get(transform.ToText.LLM) + llm := r.Dependencies().LLMFactory.Get(toText.LLM) // Call completion. result, err := llm.String(ctx, []string{ - transform.ToText.promptRendered, + toText.promptRendered, text, // TODO: may place to first line to hit the model cache in different rewrite rules. }) if err != nil { @@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string { return text } +var audioKey = func(transcript, ext string) string { + hash := hashutil.Sum64(transcript) + file := strconv.FormatUint(hash, 10) + "." + ext + + return "podcasts/" + file +} + +func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) { + transcript, err := r.generateTranscript(ctx, toPodcast, sourceText) + if err != nil { + return "", errors.Wrap(err, "generate podcast transcript") + } + + audioKey := audioKey(transcript, "wav") + url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey) + switch { + case err == nil: + // May canceled at last time by reload, fast return. + return url, nil + case errors.Is(err, object.ErrNotFound): + // Not found, generate new audio. 
+ default: + return "", errors.Wrap(err, "get audio") + } + + audioStream, err := r.generateAudio(ctx, toPodcast, transcript) + if err != nil { + return "", errors.Wrap(err, "generate podcast audio") + } + defer func() { + if closeErr := audioStream.Close(); closeErr != nil { + err = errors.Wrap(err, "close audio stream") + } + }() + + url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav") + if err != nil { + return "", errors.Wrap(err, "store podcast audio") + } + + return url, nil +} + +func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) { + transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM). + String(ctx, []string{toPodcast.transcriptPrompt, sourceText}) + if err != nil { + return "", errors.Wrap(err, "llm completion") + } + + return toPodcast.speakersDesc + + "\n\nFollowed by the actual dialogue script:\n" + + transcript, nil +} + +func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) { + audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM). + WAV(ctx, transcript, toPodcast.speakers) + if err != nil { + return nil, errors.Wrap(err, "calling tts llm") + } + + return audioStream, nil +} + type mockRewriter struct { component.Mock } diff --git a/pkg/rewrite/rewrite_test.go b/pkg/rewrite/rewrite_test.go index 636592c..9189f35 100644 --- a/pkg/rewrite/rewrite_test.go +++ b/pkg/rewrite/rewrite_test.go @@ -2,6 +2,8 @@ package rewrite import ( "context" + "io" + "strings" "testing" . 
"github.com/onsi/gomega" @@ -12,6 +14,7 @@ import ( "github.com/glidea/zenfeed/pkg/component" "github.com/glidea/zenfeed/pkg/llm" "github.com/glidea/zenfeed/pkg/model" + "github.com/glidea/zenfeed/pkg/storage/object" "github.com/glidea/zenfeed/pkg/test" ) @@ -19,8 +22,9 @@ func TestLabels(t *testing.T) { RegisterTestingT(t) type givenDetail struct { - config *Config - llmMock func(m *mock.Mock) + config *Config + llmMock func(m *mock.Mock) + objectStorageMock func(m *mock.Mock) } type whenDetail struct { inputLabels model.Labels @@ -173,7 +177,7 @@ func TestLabels(t *testing.T) { }, ThenExpected: thenExpected{ outputLabels: nil, - err: errors.New("transform text: llm completion: LLM failed"), + err: errors.New("transform: llm completion: LLM failed"), isErr: true, }, }, @@ -220,22 +224,163 @@ func TestLabels(t *testing.T) { isErr: false, }, }, + { + Scenario: "Successfully generate podcast from content", + Given: "a rule to convert content to a podcast with all dependencies mocked to succeed", + When: "processing labels with content to be converted to a podcast", + Then: "should return labels with a new podcast_url label", + GivenDetail: givenDetail{ + config: &Config{ + { + SourceLabel: model.LabelContent, + Transform: &Transform{ + ToPodcast: &ToPodcast{ + LLM: "mock-llm-transcript", + TTSLLM: "mock-llm-tts", + Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}, + }, + }, + Action: ActionCreateOrUpdateLabel, + Label: "podcast_url", + }, + }, + llmMock: func(m *mock.Mock) { + m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once() + m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")). + Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once() + }, + objectStorageMock: func(m *mock.Mock) { + m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav"). 
+ Return("http://storage.example.com/podcast.wav", nil).Once() + m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once() + }, + }, + WhenDetail: whenDetail{ + inputLabels: model.Labels{ + {Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."}, + }, + }, + ThenExpected: thenExpected{ + outputLabels: model.Labels{ + {Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."}, + {Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"}, + }, + isErr: false, + }, + }, + { + Scenario: "Fail podcast generation due to transcription LLM error", + Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail", + When: "processing labels", + Then: "should return an error related to transcription failure", + GivenDetail: givenDetail{ + config: &Config{ + { + SourceLabel: model.LabelContent, + Transform: &Transform{ + ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}}, + }, + Action: ActionCreateOrUpdateLabel, Label: "podcast_url", + }, + }, + llmMock: func(m *mock.Mock) { + m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once() + }, + }, + WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}}, + ThenExpected: thenExpected{ + outputLabels: nil, + err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"), + isErr: true, + }, + }, + { + Scenario: "Fail podcast generation due to TTS LLM error", + Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail", + When: "processing labels", + Then: "should return an error related to TTS failure", + GivenDetail: givenDetail{ + config: &Config{ + { + SourceLabel: model.LabelContent, + Transform: &Transform{ + ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: 
[]Speaker{{Name: "narrator", Voice: "alloy"}}}, + }, + Action: ActionCreateOrUpdateLabel, Label: "podcast_url", + }, + }, + llmMock: func(m *mock.Mock) { + m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once() + m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once() + }, + objectStorageMock: func(m *mock.Mock) { + m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once() + }, + }, + WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}}, + ThenExpected: thenExpected{ + outputLabels: nil, + err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"), + isErr: true, + }, + }, + { + Scenario: "Fail podcast generation due to object storage error", + Given: "a rule to convert content to a podcast, but object storage is mocked to fail", + When: "processing labels", + Then: "should return an error related to storage failure", + GivenDetail: givenDetail{ + config: &Config{ + { + SourceLabel: model.LabelContent, + Transform: &Transform{ + ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}}, + }, + Action: ActionCreateOrUpdateLabel, Label: "podcast_url", + }, + }, + llmMock: func(m *mock.Mock) { + m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once() + m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once() + }, + objectStorageMock: func(m *mock.Mock) { + m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once() + m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once() + }, + }, + WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}}, + ThenExpected: thenExpected{ + outputLabels: nil, + err: 
errors.New("transform: store podcast audio: storage failed"), + isErr: true, + }, + }, } for _, tt := range tests { t.Run(tt.Scenario, func(t *testing.T) { // Given. var mockLLMFactory llm.Factory - var mockInstance *mock.Mock // Store the mock instance for assertion - - // Create mock factory and capture the mock.Mock instance. - mockOption := component.MockOption(func(m *mock.Mock) { - mockInstance = m // Capture the mock instance. + var mockLLMInstance *mock.Mock + llmMockOption := component.MockOption(func(m *mock.Mock) { + mockLLMInstance = m if tt.GivenDetail.llmMock != nil { tt.GivenDetail.llmMock(m) } }) - mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, mockOption) // Use the factory directly with the option + mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption) + Expect(err).NotTo(HaveOccurred()) + + var mockObjectStorage object.Storage + var mockObjectStorageInstance *mock.Mock + objectStorageMockOption := component.MockOption(func(m *mock.Mock) { + mockObjectStorageInstance = m + if tt.GivenDetail.objectStorageMock != nil { + tt.GivenDetail.objectStorageMock(m) + } + }) + mockObjectStorageFactory := object.NewFactory(objectStorageMockOption) + mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{}) Expect(err).NotTo(HaveOccurred()) // Manually validate config to compile regex and render templates. @@ -252,7 +397,8 @@ func TestLabels(t *testing.T) { Instance: "test", Config: tt.GivenDetail.config, Dependencies: Dependencies{ - LLMFactory: mockLLMFactory, // Pass the mock factory + LLMFactory: mockLLMFactory, // Pass the mock factory + ObjectStorage: mockObjectStorage, }, }), } @@ -280,10 +426,12 @@ func TestLabels(t *testing.T) { Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels)) } - // Verify LLM calls if stubs were provided. - if tt.GivenDetail.llmMock != nil && mockInstance != nil { - // Assert expectations on the captured mock instance. 
- mockInstance.AssertExpectations(t) + // Verify mock calls if stubs were provided. + if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil { + mockLLMInstance.AssertExpectations(t) + } + if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil { + mockObjectStorageInstance.AssertExpectations(t) } }) } diff --git a/pkg/scrape/manager.go b/pkg/scrape/manager.go index ae93379..0c55454 100644 --- a/pkg/scrape/manager.go +++ b/pkg/scrape/manager.go @@ -216,6 +216,10 @@ func (m *manager) reload(config *Config) (err error) { func (m *manager) runOrRestartScrapers(config *Config, newScrapers map[string]scraper.Scraper) error { for i := range config.Scrapers { c := &config.Scrapers[i] + if err := c.Validate(); err != nil { + return errors.Wrapf(err, "validate scraper %s", c.Name) + } + if err := m.runOrRestartScraper(c, newScrapers); err != nil { return errors.Wrapf(err, "run or restart scraper %s", c.Name) } diff --git a/pkg/scrape/scraper/scraper.go b/pkg/scrape/scraper/scraper.go index cc37f80..9583638 100644 --- a/pkg/scrape/scraper/scraper.go +++ b/pkg/scrape/scraper/scraper.go @@ -69,6 +69,11 @@ func (c *Config) Validate() error { if c.Name == "" { return errors.New("name cannot be empty") } + if c.RSS != nil { + if err := c.RSS.Validate(); err != nil { + return errors.Wrap(err, "invalid RSS config") + } + } return nil } diff --git a/pkg/scrape/scraper/scraper_test.go b/pkg/scrape/scraper/scraper_test.go index 45dc0ac..7000529 100644 --- a/pkg/scrape/scraper/scraper_test.go +++ b/pkg/scrape/scraper/scraper_test.go @@ -244,7 +244,7 @@ func TestNew(t *testing.T) { WhenDetail: whenDetail{}, ThenExpected: thenExpected{ isErr: true, - wantErrMsg: "creating source: invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader + wantErrMsg: "invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader }, }, { @@ -264,7 +264,7 @@ func TestNew(t *testing.T) { WhenDetail: 
whenDetail{}, ThenExpected: thenExpected{ isErr: true, - wantErrMsg: "creating source: source not supported", // Error from newReader + wantErrMsg: "source not supported", // Error from newReader }, }, } diff --git a/pkg/storage/object/object.go b/pkg/storage/object/object.go new file mode 100644 index 0000000..4a26b52 --- /dev/null +++ b/pkg/storage/object/object.go @@ -0,0 +1,239 @@ +// Copyright (C) 2025 wangyusong +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package object + +import ( + "context" + "io" + "net/url" + "reflect" + "strings" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "github.com/pkg/errors" + + "github.com/glidea/zenfeed/pkg/component" + "github.com/glidea/zenfeed/pkg/config" + "github.com/glidea/zenfeed/pkg/telemetry" + "github.com/glidea/zenfeed/pkg/telemetry/log" + telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model" +) + +// --- Interface code block --- +type Storage interface { + component.Component + config.Watcher + Put(ctx context.Context, key string, body io.Reader, contentType string) (url string, err error) + Get(ctx context.Context, key string) (url string, err error) +} + +var ErrNotFound = errors.New("not found") + +type Config struct { + Endpoint string + AccessKeyID string + SecretAccessKey string + Bucket string + BucketURL string +} + +func (c *Config) Validate() error { + if c.Endpoint == "" { + return errors.New("endpoint is required") + } + c.Endpoint = strings.TrimPrefix(c.Endpoint, "https://") // S3 endpoint should not have https:// prefix. 
+ c.Endpoint = strings.TrimPrefix(c.Endpoint, "http://") + + if c.AccessKeyID == "" { + return errors.New("access key id is required") + } + if c.SecretAccessKey == "" { + return errors.New("secret access key is required") + } + if c.Bucket == "" { + return errors.New("bucket is required") + } + if c.BucketURL == "" { + return errors.New("bucket url is required") + } + + return nil +} + +func (c *Config) From(app *config.App) *Config { + *c = Config{ + Endpoint: app.Storage.Object.Endpoint, + AccessKeyID: app.Storage.Object.AccessKeyID, + SecretAccessKey: app.Storage.Object.SecretAccessKey, + Bucket: app.Storage.Object.Bucket, + BucketURL: app.Storage.Object.BucketURL, + } + + return c +} + +type Dependencies struct{} + +// --- Factory code block --- +type Factory component.Factory[Storage, config.App, Dependencies] + +func NewFactory(mockOn ...component.MockOption) Factory { + if len(mockOn) > 0 { + return component.FactoryFunc[Storage, config.App, Dependencies]( + func(instance string, config *config.App, dependencies Dependencies) (Storage, error) { + m := &mockStorage{} + component.MockOptions(mockOn).Apply(&m.Mock) + + return m, nil + }, + ) + } + + return component.FactoryFunc[Storage, config.App, Dependencies](new) +} + +func new(instance string, app *config.App, dependencies Dependencies) (Storage, error) { + config := &Config{} + config.From(app) + if err := config.Validate(); err != nil { + return nil, errors.Wrap(err, "validate config") + } + + client, err := minio.New(config.Endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(config.AccessKeyID, config.SecretAccessKey, ""), + Secure: true, + }) + if err != nil { + return nil, errors.Wrap(err, "new minio client") + } + + u, err := url.Parse(config.BucketURL) + if err != nil { + return nil, errors.Wrap(err, "parse public url") + } + + return &s3{ + Base: component.New(&component.BaseConfig[Config, Dependencies]{ + Name: "ObjectStorage", + Instance: instance, + Config: config, + Dependencies: 
dependencies, + }), + client: client, + bucketURL: u, + }, nil +} + +// --- Implementation code block --- +type s3 struct { + *component.Base[Config, Dependencies] + + client *minio.Client + bucketURL *url.URL +} + +func (s *s3) Put(ctx context.Context, key string, body io.Reader, contentType string) (publicURL string, err error) { + ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Put")...) + defer func() { telemetry.End(ctx, err) }() + bucket := s.Config().Bucket + + if _, err := s.client.PutObject(ctx, bucket, key, body, -1, minio.PutObjectOptions{ + ContentType: contentType, + }); err != nil { + return "", errors.Wrap(err, "put object") + } + + return s.bucketURL.JoinPath(key).String(), nil +} + +func (s *s3) Get(ctx context.Context, key string) (publicURL string, err error) { + ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Get")...) + defer func() { telemetry.End(ctx, err) }() + bucket := s.Config().Bucket + + if _, err := s.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}); err != nil { + errResponse := minio.ToErrorResponse(err) + if errResponse.Code == minio.NoSuchKey { + return "", ErrNotFound + } + + return "", errors.Wrap(err, "stat object") + } + + return s.bucketURL.JoinPath(key).String(), nil +} + +func (s *s3) Reload(app *config.App) (err error) { + ctx := telemetry.StartWith(s.Context(), append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Reload")...) 
+ defer func() { telemetry.End(ctx, err) }() + + newConfig := &Config{} + newConfig.From(app) + if err := newConfig.Validate(); err != nil { + return errors.Wrap(err, "validate config") + } + + if reflect.DeepEqual(s.Config(), newConfig) { + log.Debug(ctx, "object storage config not changed") + + return nil + } + + client, err := minio.New(newConfig.Endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(newConfig.AccessKeyID, newConfig.SecretAccessKey, ""), + Secure: true, + }) + if err != nil { + return errors.Wrap(err, "new minio client") + } + + u, err := url.Parse(newConfig.BucketURL) + if err != nil { + return errors.Wrap(err, "parse public url") + } + + s.client = client + s.bucketURL = u + s.SetConfig(newConfig) + + log.Info(ctx, "object storage reloaded") + + return nil +} + +// --- Mock code block --- +type mockStorage struct { + component.Mock +} + +func (m *mockStorage) Put(ctx context.Context, key string, body io.Reader, contentType string) (string, error) { + args := m.Called(ctx, key, body, contentType) + + return args.String(0), args.Error(1) +} + +func (m *mockStorage) Get(ctx context.Context, key string) (string, error) { + args := m.Called(ctx, key) + + return args.String(0), args.Error(1) +} + +func (m *mockStorage) Reload(app *config.App) error { + args := m.Called(app) + + return args.Error(0) +} diff --git a/pkg/util/binary/binary.go b/pkg/util/binary/binary.go index d9a3016..a54dd27 100644 --- a/pkg/util/binary/binary.go +++ b/pkg/util/binary/binary.go @@ -122,6 +122,32 @@ func ReadUint32(r io.Reader) (uint32, error) { return binary.LittleEndian.Uint32(b), nil } +// WriteUint16 writes a uint16 using a pooled buffer. +func WriteUint16(w io.Writer, v uint16) error { + bp := smallBufPool.Get().(*[]byte) + defer smallBufPool.Put(bp) + b := *bp + + binary.LittleEndian.PutUint16(b, v) + _, err := w.Write(b[:2]) + + return err +} + +// ReadUint16 reads a uint16 using a pooled buffer. 
+func ReadUint16(r io.Reader) (uint16, error) { + bp := smallBufPool.Get().(*[]byte) + defer smallBufPool.Put(bp) + b := (*bp)[:2] + + // Read exactly 2 bytes into the slice. + if _, err := io.ReadFull(r, b); err != nil { + return 0, errors.Wrap(err, "read uint16") + } + + return binary.LittleEndian.Uint16(b), nil +} + // WriteFloat32 writes a float32 using a pooled buffer. func WriteFloat32(w io.Writer, v float32) error { return WriteUint32(w, math.Float32bits(v)) diff --git a/pkg/util/wav/wav.go b/pkg/util/wav/wav.go new file mode 100644 index 0000000..cb3e7db --- /dev/null +++ b/pkg/util/wav/wav.go @@ -0,0 +1,100 @@ +// Copyright (C) 2025 wangyusong +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package wav + +import ( + "io" + + "github.com/pkg/errors" + + binaryutil "github.com/glidea/zenfeed/pkg/util/binary" +) + +// Header contains the WAV header information. +type Header struct { + SampleRate uint32 + BitDepth uint16 + NumChannels uint16 +} + +// WriteHeader writes the WAV header to a writer. +// pcmDataSize is the size of the raw PCM data. +func WriteHeader(w io.Writer, h *Header, pcmDataSize uint32) error { + // RIFF Header. + if err := writeRIFFHeader(w, pcmDataSize); err != nil { + return errors.Wrap(err, "write RIFF header") + } + + // fmt chunk. 
+ if err := writeFMTChunk(w, h); err != nil { + return errors.Wrap(err, "write fmt chunk") + } + + // data chunk. + if _, err := w.Write([]byte("data")); err != nil { + return errors.Wrap(err, "write data chunk marker") + } + if err := binaryutil.WriteUint32(w, pcmDataSize); err != nil { + return errors.Wrap(err, "write pcm data size") + } + + return nil +} + +func writeRIFFHeader(w io.Writer, pcmDataSize uint32) error { + if _, err := w.Write([]byte("RIFF")); err != nil { + return errors.Wrap(err, "write RIFF") + } + if err := binaryutil.WriteUint32(w, uint32(36+pcmDataSize)); err != nil { + return errors.Wrap(err, "write file size") + } + if _, err := w.Write([]byte("WAVE")); err != nil { + return errors.Wrap(err, "write WAVE") + } + + return nil +} + +func writeFMTChunk(w io.Writer, h *Header) error { + if _, err := w.Write([]byte("fmt ")); err != nil { + return errors.Wrap(err, "write fmt") + } + if err := binaryutil.WriteUint32(w, uint32(16)); err != nil { // PCM chunk size. + return errors.Wrap(err, "write pcm chunk size") + } + if err := binaryutil.WriteUint16(w, uint16(1)); err != nil { // PCM format. 
+ return errors.Wrap(err, "write pcm format") + } + if err := binaryutil.WriteUint16(w, h.NumChannels); err != nil { + return errors.Wrap(err, "write num channels") + } + if err := binaryutil.WriteUint32(w, h.SampleRate); err != nil { + return errors.Wrap(err, "write sample rate") + } + byteRate := h.SampleRate * uint32(h.NumChannels) * uint32(h.BitDepth) / 8 + if err := binaryutil.WriteUint32(w, byteRate); err != nil { + return errors.Wrap(err, "write byte rate") + } + blockAlign := h.NumChannels * h.BitDepth / 8 + if err := binaryutil.WriteUint16(w, blockAlign); err != nil { + return errors.Wrap(err, "write block align") + } + if err := binaryutil.WriteUint16(w, h.BitDepth); err != nil { + return errors.Wrap(err, "write bit depth") + } + + return nil +} diff --git a/pkg/util/wav/wav_test.go b/pkg/util/wav/wav_test.go new file mode 100644 index 0000000..393f270 --- /dev/null +++ b/pkg/util/wav/wav_test.go @@ -0,0 +1,161 @@ +// Copyright (C) 2025 wangyusong +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package wav + +import ( + "bytes" + "testing" + + . 
"github.com/onsi/gomega" + + "github.com/glidea/zenfeed/pkg/test" +) + +func TestWriteHeader(t *testing.T) { + RegisterTestingT(t) + + type givenDetail struct{} + + type whenDetail struct { + header *Header + pcmDataSize uint32 + } + + type thenExpected struct { + expectedBytes []byte + expectError bool + } + + tests := []test.Case[givenDetail, whenDetail, thenExpected]{ + { + Scenario: "Standard CD quality audio", + Given: "a header for CD quality audio and a non-zero data size", + When: "writing the header", + Then: "should produce a valid 44-byte WAV header and no error", + GivenDetail: givenDetail{}, + WhenDetail: whenDetail{ + header: &Header{ + SampleRate: 44100, + BitDepth: 16, + NumChannels: 2, + }, + pcmDataSize: 176400, + }, + ThenExpected: thenExpected{ + expectedBytes: []byte{ + 'R', 'I', 'F', 'F', + 0x34, 0xB1, 0x02, 0x00, // ChunkSize = 36 + 176400 = 176436 + 'W', 'A', 'V', 'E', + 'f', 'm', 't', ' ', + 0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16 + 0x01, 0x00, // AudioFormat = 1 (PCM) + 0x02, 0x00, // NumChannels = 2 + 0x44, 0xAC, 0x00, 0x00, // SampleRate = 44100 + 0x10, 0xB1, 0x02, 0x00, // ByteRate = 176400 + 0x04, 0x00, // BlockAlign = 4 + 0x10, 0x00, // BitsPerSample = 16 + 'd', 'a', 't', 'a', + 0x10, 0xB1, 0x02, 0x00, // Subchunk2Size = 176400 + }, + expectError: false, + }, + }, + { + Scenario: "Mono audio for speech", + Given: "a header for mono speech audio and a non-zero data size", + When: "writing the header", + Then: "should produce a valid 44-byte WAV header and no error", + GivenDetail: givenDetail{}, + WhenDetail: whenDetail{ + header: &Header{ + SampleRate: 16000, + BitDepth: 16, + NumChannels: 1, + }, + pcmDataSize: 32000, + }, + ThenExpected: thenExpected{ + expectedBytes: []byte{ + 'R', 'I', 'F', 'F', + 0x24, 0x7D, 0x00, 0x00, // ChunkSize = 36 + 32000 = 32036 + 'W', 'A', 'V', 'E', + 'f', 'm', 't', ' ', + 0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16 + 0x01, 0x00, // AudioFormat = 1 + 0x01, 0x00, // NumChannels = 1 + 0x80, 0x3E, 
0x00, 0x00, // SampleRate = 16000 + 0x00, 0x7D, 0x00, 0x00, // ByteRate = 32000 + 0x02, 0x00, // BlockAlign = 2 + 0x10, 0x00, // BitsPerSample = 16 + 'd', 'a', 't', 'a', + 0x00, 0x7D, 0x00, 0x00, // Subchunk2Size = 32000 + }, + expectError: false, + }, + }, + { + Scenario: "8-bit mono audio with zero data size", + Given: "a header for 8-bit mono audio and a zero data size", + When: "writing the header for an empty file", + Then: "should produce a valid 44-byte WAV header with data size 0", + GivenDetail: givenDetail{}, + WhenDetail: whenDetail{ + header: &Header{ + SampleRate: 8000, + BitDepth: 8, + NumChannels: 1, + }, + pcmDataSize: 0, + }, + ThenExpected: thenExpected{ + expectedBytes: []byte{ + 'R', 'I', 'F', 'F', + 0x24, 0x00, 0x00, 0x00, // ChunkSize = 36 + 0 = 36 + 'W', 'A', 'V', 'E', + 'f', 'm', 't', ' ', + 0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16 + 0x01, 0x00, // AudioFormat = 1 + 0x01, 0x00, // NumChannels = 1 + 0x40, 0x1F, 0x00, 0x00, // SampleRate = 8000 + 0x40, 0x1F, 0x00, 0x00, // ByteRate = 8000 + 0x01, 0x00, // BlockAlign = 1 + 0x08, 0x00, // BitsPerSample = 8 + 'd', 'a', 't', 'a', + 0x00, 0x00, 0x00, 0x00, // Subchunk2Size = 0 + }, + expectError: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.Scenario, func(t *testing.T) { + // Given. + var buf bytes.Buffer + + // When. + err := WriteHeader(&buf, tt.WhenDetail.header, tt.WhenDetail.pcmDataSize) + + // Then. + if tt.ThenExpected.expectError { + Expect(err).To(HaveOccurred()) + } else { + Expect(err).NotTo(HaveOccurred()) + Expect(buf.Bytes()).To(Equal(tt.ThenExpected.expectedBytes)) + } + }) + } +}