add podcast

glidea committed 2025-07-08 18:13:26 +08:00
parent 263fcbbfaf, commit 00c5dfadee
21 changed files with 1549 additions and 145 deletions


@@ -2,11 +2,9 @@ name: CI
on:
push:
branches: [ main ]
branches: [ main, dev ]
pull_request:
branches: [ main ]
release:
types: [ published ]
jobs:
test:
@@ -27,7 +25,7 @@ jobs:
build-and-push:
runs-on: ubuntu-latest
needs: test
if: github.event_name == 'release'
if: github.event_name == 'push'
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
@@ -37,5 +35,9 @@ jobs:
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push Docker images
run: make push
- name: Build and push Docker image (main)
if: github.ref_name == 'main'
run: make push
- name: Build and push Docker image (dev)
if: github.ref_name == 'dev'
run: make dev-push

.gitignore (vendored): 2 lines changed

@@ -18,7 +18,7 @@ local_docs/
.env
.env.local
__debug_bin
config.yaml
config.*yaml
data/
*debug*
.cursorrules


@@ -41,6 +41,7 @@
| `llms[].api_key` | `string` | API key for the LLM. | | Yes |
| `llms[].model` | `string` | Model of the LLM. E.g., `gpt-4o-mini`. Cannot be empty if used for generation tasks (e.g., summarization). If this LLM is used, cannot be empty along with `embedding_model`. | | Conditionally Required |
| `llms[].embedding_model` | `string` | Embedding model of the LLM. E.g., `text-embedding-3-small`. Cannot be empty if used for embedding. If this LLM is used, cannot be empty along with `model`. **Note:** Do not modify directly after initial use; add a new LLM configuration instead. | | Conditionally Required |
| `llms[].tts_model` | `string` | The Text-to-Speech (TTS) model of the LLM. | | No |
| `llms[].temperature` | `float32` | Temperature of the LLM (0-2). | `0.0` | No |
### Jina AI Configuration (`jina`)
@@ -80,10 +81,11 @@
### Storage Configuration (`storage`)
| Field | Type | Description | Default Value | Required |
| :------------- | :------- | :-------------------------------------------- | :----------- | :------- |
| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No |
| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No |
| Field | Type | Description | Default Value | Required |
| :--------------- | :------- | :-------------------------------------------------------------- | :----------- | :------- |
| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No |
| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No |
| `storage.object` | `object` | Object storage configuration, used for storing files such as podcasts. See **Object Storage Configuration** below. | (See specific fields) | No |
### Feed Storage Configuration (`storage.feed`)
@@ -95,6 +97,16 @@
| `storage.feed.retention` | `time.Duration` | Retention duration for feeds. | `8d` | No |
| `storage.feed.block_duration` | `time.Duration` | Retention duration for each time-based feed storage block (similar to a Prometheus TSDB Block). | `25h` | No |
### Object Storage Configuration (`storage.object`)
| Field | Type | Description | Default Value | Required |
| :--------------------------------- | :------- | :----------------------------- | :----- | :-------------------- |
| `storage.object.endpoint` | `string` | The endpoint of the object storage. | | Yes (if using the podcast feature) |
| `storage.object.access_key_id` | `string` | The access key ID of the object storage. | | Yes (if using the podcast feature) |
| `storage.object.secret_access_key` | `string` | The secret access key of the object storage. | | Yes (if using the podcast feature) |
| `storage.object.bucket` | `string` | The bucket name of the object storage. | | Yes (if using the podcast feature) |
| `storage.object.bucket_url` | `string` | The public access URL of the bucket. | | No |
### Rewrite Rule Configuration (`storage.feed.rewrites[]`)
Defines rules to process feeds before storage. Rules are applied sequentially.
@@ -109,12 +121,8 @@
| `...rewrites[].match_re` | `string` | Regular expression to match against the (transformed) text. | `.*` (matches all) | No (use `match` or `match_re`) |
| `...rewrites[].action` | `string` | Action to perform on match: `create_or_update_label` (adds/updates a label with the matched/transformed text), `drop_feed` (discards the feed entirely). | `create_or_update_label` | No |
| `...rewrites[].label` | `string` | Name of the feed label to create or update. | | Yes (if `action` is `create_or_update_label`) |
### Rewrite Rule Transform Configuration (`storage.feed.rewrites[].transform`)
| Field | Type | Description | Default Value | Required |
| :--------------------- | :------- | :------------------------------------------------------------------- | :----- | :------- |
| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No |
| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No |
| `...transform.to_podcast` | `object` | Transforms source text to a podcast. See **Rewrite Rule To Podcast Configuration** below. | `nil` | No |
### Rewrite Rule To Text Configuration (`storage.feed.rewrites[].transform.to_text`)
@@ -126,6 +134,25 @@
| `...to_text.llm` | `string` | **Only valid if `type` is `prompt`.** Name of the LLM used for transformation (from the `llms` section). If not specified, the LLM marked `default: true` in the `llms` section will be used. | Default LLM in `llms` section | No |
| `...to_text.prompt` | `string` | **Only valid if `type` is `prompt`.** Prompt used for transformation. The source text will be injected. You can use Go template syntax to reference built-in prompts: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`, `{{ .summary_html_snippet_for_small_model }}`. | | Yes (if `type` is `prompt`) |
### Rewrite Rule To Podcast Configuration (`storage.feed.rewrites[].transform.to_podcast`)
This configuration defines how to transform the text from `source_label` into a podcast.
| Field | Type | Description | Default Value | Required |
| :------------------------------------------- | :--------- | :-------------------------------------------------------------------------------------------------------- | :---------------------- | :------- |
| `...to_podcast.llm` | `string` | Name of the LLM (from the `llms` section) used to generate the podcast script. | Default LLM in `llms` section | No |
| `...to_podcast.transcript_additional_prompt` | `string` | Additional instructions appended to the podcast-script generation prompt. | | No |
| `...to_podcast.tts_llm` | `string` | Name of the LLM (from the `llms` section) used for Text-to-Speech (TTS). **Note: currently only LLMs with `provider: gemini` are supported**. | Default LLM in `llms` section | No |
| `...to_podcast.speakers` | `list of objects` | List of speakers for the podcast. See **Speaker Configuration** below. | `[]` | Yes |
#### Speaker Configuration (`...to_podcast.speakers[]`)
| Field | Type | Description | Default Value | Required |
| :-------------------- | :------- | :------------------------ | :----- | :------- |
| `...speakers[].name` | `string` | The name of the speaker. | | Yes |
| `...speakers[].role` | `string` | The role description (persona) of the speaker. | | No |
| `...speakers[].voice` | `string` | The voice of the speaker. | | Yes |
### Scheduling Configuration (`scheduls`)
Defines rules for querying and monitoring feeds.
@@ -173,10 +200,11 @@
Defines *who* receives notifications.
| Field | Type | Description | Default Value | Required |
| :------------------------- | :------- | :------------------------------- | :----- | :------------------ |
| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes |
| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) |
| Field | Type | Description | Default Value | Required |
| :--------------------------- | :------- | :------------------------------------------------------- | :----- | :-------------------- |
| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes |
| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) |
| `notify.receivers[].webhook` | `object` | Webhook configuration for the receiver. E.g., `webhook: { "url": "xxx" }` | | Yes (if using Webhook) |
### Notification Channel Configuration (`notify.channels`)
@@ -194,4 +222,4 @@
| `...email.from` | `string` | Sender email address. | | Yes |
| `...email.password` | `string` | App-specific password for the sender email. (For Gmail, see [Google App Passwords](https://support.google.com/mail/answer/185833).) | | Yes |
| `...email.feed_markdown_template` | `string` | Markdown template used to format each feed in the email body. Renders the feed content by default. Cannot be set together with `feed_html_snippet_template`. Available template variables depend on the feed labels. | `{{ .content }}` | No |
| `...email.feed_html_snippet_template` | `string` | HTML snippet template used to format each feed. Cannot be set together with `feed_markdown_template`. Available template variables depend on the feed labels. | | No |
| `...email.feed_html_snippet_template` | `string` | HTML snippet template used to format each feed. Cannot be set together with `feed_markdown_template`. Available template variables depend on the feed labels. | | No |


@@ -41,6 +41,7 @@ This section defines the list of available Large Language Models. At least one L
| `llms[].api_key` | `string` | API key for the LLM. | | Yes |
| `llms[].model` | `string` | Model of the LLM. E.g., `gpt-4o-mini`. Cannot be empty if used for generation tasks (e.g., summarization). If this LLM is used, cannot be empty along with `embedding_model`. | | Conditionally Required |
| `llms[].embedding_model` | `string` | Embedding model of the LLM. E.g., `text-embedding-3-small`. Cannot be empty if used for embedding. If this LLM is used, cannot be empty along with `model`. **Note:** Do not modify directly after initial use; add a new LLM configuration instead. | | Conditionally Required |
| `llms[].tts_model` | `string` | The Text-to-Speech (TTS) model of the LLM. | | No |
| `llms[].temperature` | `float32` | Temperature of the LLM (0-2). | `0.0` | No |
### Jina AI Configuration (`jina`)
@@ -80,10 +81,11 @@ Describes each source to be scraped.
### Storage Configuration (`storage`)
| Field | Type | Description | Default Value | Required |
| :------------- | :------- | :------------------------------------------------------------------------------ | :-------------------- | :------- |
| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No |
| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No |
| Field | Type | Description | Default Value | Required |
| :--------------- | :------- | :-------------------------------------------------------------------------------------------------------- | :-------------------- | :------- |
| `storage.dir` | `string` | Base directory for all storage. Cannot be changed after the application starts. | `./data` | No |
| `storage.feed` | `object` | Feed storage configuration. See **Feed Storage Configuration** below. | (See specific fields) | No |
| `storage.object` | `object` | Object storage configuration for storing files like podcasts. See **Object Storage Configuration** below. | (See specific fields) | No |
### Feed Storage Configuration (`storage.feed`)
@@ -95,6 +97,16 @@ Describes each source to be scraped.
| `storage.feed.retention` | `time.Duration` | Retention duration for feeds. | `8d` | No |
| `storage.feed.block_duration` | `time.Duration` | Retention duration for each time-based feed storage block (similar to Prometheus TSDB Block). | `25h` | No |
### Object Storage Configuration (`storage.object`)
| Field | Type | Description | Default Value | Required |
| :--------------------------------- | :------- | :------------------------------------------- | :------------ | :----------------------------- |
| `storage.object.endpoint` | `string` | The endpoint of the object storage. | | Yes (if using podcast feature) |
| `storage.object.access_key_id` | `string` | The access key id of the object storage. | | Yes (if using podcast feature) |
| `storage.object.secret_access_key` | `string` | The secret access key of the object storage. | | Yes (if using podcast feature) |
| `storage.object.bucket` | `string` | The bucket of the object storage. | | Yes (if using podcast feature) |
| `storage.object.bucket_url` | `string` | The public URL of the object storage bucket. | | No |
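As a rough orientation, the sketch below shows how these fields typically map onto an S3-compatible client. It assumes the `minio-go` dependency added in this commit is what backs `storage.object`; the endpoint, credentials, bucket, and object key are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
)

func main() {
	// Placeholder values mirroring the storage.object fields above.
	endpoint := "<account_id>.r2.cloudflarestorage.com"
	accessKeyID := "..."
	secretAccessKey := "..."
	bucket := "zenfeed-podcasts"
	bucketURL := "https://pub-xxxxxxxx.r2.dev"

	client, err := minio.New(endpoint, &minio.Options{
		Creds:  credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
		Secure: true,
	})
	if err != nil {
		panic(err)
	}

	// Upload one object, then derive its public URL from bucket_url.
	const key = "podcasts/example.wav"
	payload := "fake audio bytes"
	if _, err := client.PutObject(context.Background(), bucket, key,
		strings.NewReader(payload), int64(len(payload)),
		minio.PutObjectOptions{ContentType: "audio/wav"}); err != nil {
		panic(err)
	}
	fmt.Println(bucketURL + "/" + key)
}
```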
### Rewrite Rule Configuration (`storage.feed.rewrites[]`)
Defines rules to process feeds before storage. Rules are applied sequentially.
@@ -109,12 +121,8 @@ Defines rules to process feeds before storage. Rules are applied sequentially.
| `...rewrites[].match_re` | `string` | Regular expression to match against the (transformed) text. | `.*` (matches all) | No (use `match` or `match_re`) |
| `...rewrites[].action` | `string` | Action to perform on match: `create_or_update_label` (adds/updates a label with the matched/transformed text), `drop_feed` (discards the feed entirely). | `create_or_update_label` | No |
| `...rewrites[].label` | `string` | Name of the feed label to create or update. | | Yes (if `action` is `create_or_update_label`) |
### Rewrite Rule Transform Configuration (`storage.feed.rewrites[].transform`)
| Field | Type | Description | Default Value | Required |
| :--------------------- | :------- | :--------------------------------------------------------------------------------------------- | :------------ | :------- |
| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No |
| `...transform.to_text` | `object` | Transforms source text to text using an LLM. See **Rewrite Rule To Text Configuration** below. | `nil` | No |
| `...transform.to_podcast` | `object` | Transforms source text to a podcast. See **Rewrite Rule To Podcast Configuration** below. | `nil` | No |
### Rewrite Rule To Text Configuration (`storage.feed.rewrites[].transform.to_text`)
@@ -124,7 +132,26 @@ This configuration defines how to transform the text from `source_label`.
| :------------------ | :------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------- | :-------------------------- |
| `...to_text.type` | `string` | Type of transformation. Options: <ul><li>`prompt` (default): Uses an LLM and a specified prompt to transform the source text.</li><li>`crawl`: Treats the source text as a URL, directly crawls the web page content pointed to by the URL, and converts it to Markdown format. This method performs local crawling and attempts to follow `robots.txt`.</li><li>`crawl_by_jina`: Treats the source text as a URL, crawls and processes web page content via the [Jina AI Reader API](https://jina.ai/reader/), and returns Markdown. Potentially more powerful, e.g., for handling dynamic pages, but relies on the Jina AI service.</li></ul> | `prompt` | No |
| `...to_text.llm` | `string` | **Only valid if `type` is `prompt`.** Name of the LLM used for transformation (from `llms` section). If not specified, the LLM marked as `default: true` in the `llms` section will be used. | Default LLM in `llms` section | No |
| `...to_text.prompt` | `string` | **Only valid if `type` is `prompt`.** Prompt used for transformation. The source text will be injected. You can use Go template syntax to reference built-in prompts: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`. | | Yes (if `type` is `prompt`) |
| `...to_text.prompt` | `string` | **Only valid if `type` is `prompt`.** Prompt used for transformation. The source text will be injected. You can use Go template syntax to reference built-in prompts: `{{ .summary }}`, `{{ .category }}`, `{{ .tags }}`, `{{ .score }}`, `{{ .comment_confucius }}`, `{{ .summary_html_snippet }}`, `{{ .summary_html_snippet_for_small_model }}`. | | Yes (if `type` is `prompt`) |
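To make the template mechanics concrete, here is a minimal sketch of rendering such a prompt with Go's `text/template`; the built-in prompt text below is a placeholder, not the real content of zenfeed's `prompt` package.

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// A to_text.prompt that reuses a built-in prompt via Go template syntax.
	const userPrompt = "{{ .summary }}\n\nAlso list three follow-up questions."

	// Placeholder stand-in for the built-in prompts.
	builtin := map[string]string{
		"summary": "Summarize the following article in five bullet points.",
	}

	t := template.Must(template.New("to_text").Parse(userPrompt))
	_ = t.Execute(os.Stdout, builtin) // Rendered once when the rule is validated.
}
```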
### Rewrite Rule To Podcast Configuration (`storage.feed.rewrites[].transform.to_podcast`)
This configuration defines how to transform the text from `source_label` into a podcast.
| Field | Type | Description | Default Value | Required |
| :------------------------------------------- | :---------------- | :--------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------- | :------- |
| `...to_podcast.llm` | `string` | The name of the LLM (from the `llms` section) to use for generating the podcast script. | Default LLM in `llms` section | No |
| `...to_podcast.transcript_additional_prompt` | `string` | Additional instructions to append to the prompt for generating the podcast script. | | No |
| `...to_podcast.tts_llm` | `string` | The name of the LLM (from the `llms` section) to use for Text-to-Speech (TTS). **Note: Currently only supports LLMs with `provider: gemini`**. | Default LLM in `llms` section | No |
| `...to_podcast.speakers` | `list of objects` | A list of speakers for the podcast. See **Speaker Configuration** below. | `[]` | Yes |
#### Speaker Configuration (`...to_podcast.speakers[]`)
| Field | Type | Description | Default Value | Required |
| :-------------------- | :------- | :----------------------------------- | :------------ | :------- |
| `...speakers[].name` | `string` | The name of the speaker. | | Yes |
| `...speakers[].role` | `string` | The role description of the speaker. | | No |
| `...speakers[].voice` | `string` | The voice of the speaker. | | Yes |
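The generated transcript uses a `Name: line` format (each line starts with the speaker's name, a colon, and a space), and the same names key the multi-speaker TTS voice configuration, so each `speakers[].name` should match what appears in the script. A minimal illustrative sketch of that relationship, with hypothetical names rather than zenfeed's internal code:

```go
package main

import (
	"fmt"
	"strings"
)

// speakerOf extracts the speaker name from a transcript line in "Name: text" form.
func speakerOf(line string) (string, bool) {
	name, _, ok := strings.Cut(line, ": ")
	return name, ok
}

func main() {
	configured := map[string]bool{"Host": true, "Analyst": true} // from speakers[].name
	transcript := []string{
		"Host: Today we are discussing the article's main points.",
		"Analyst: Let's start with the first one.",
	}
	for _, line := range transcript {
		name, ok := speakerOf(line)
		fmt.Printf("%q -> speaker %q, known: %v\n", line, name, ok && configured[name])
	}
}
```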
### Scheduling Configuration (`scheduls`)
@@ -173,10 +200,11 @@ This structure can be nested using `sub_routes`. Feeds will first try to match s
Defines *who* receives notifications.
| Field | Type | Description | Default Value | Required |
| :------------------------- | :------- | :------------------------------------------- | :------------ | :------------------- |
| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes |
| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) |
| Field | Type | Description | Default Value | Required |
| :--------------------------- | :------- | :----------------------------------------------------------------------- | :------------ | :--------------------- |
| `notify.receivers[].name` | `string` | Unique name of the receiver. Used in routes. | | Yes |
| `notify.receivers[].email` | `string` | Email address of the receiver. | | Yes (if using Email) |
| `notify.receivers[].webhook` | `object` | Webhook configuration for the receiver. E.g. `webhook: { "url": "xxx" }` | | Yes (if using Webhook) |
### Notification Channel Configuration (`notify.channels`)

docs/podcast.md (new file): 104 lines

@@ -0,0 +1,104 @@
# Converting Articles into Podcasts with Zenfeed
Zenfeed's podcast feature can automatically turn any article source into an engaging, multi-speaker conversational podcast. It uses a large language model (LLM) to generate the dialogue script and text-to-speech (TTS) to produce the audio, and hosts the final audio file in your own object storage.
## How It Works
1. **Extract content**: Zenfeed first extracts the article's full text via rewrite rules.
2. **Generate the script**: A designated LLM (such as GPT-4o-mini) adapts the article into a dialogue script spoken by several virtual hosts. You can define each host's role (persona) to control the conversational style.
3. **Synthesize speech**: A TTS-capable LLM (currently only Google Gemini is supported) converts each line of dialogue in the script into speech.
4. **Merge audio**: All speech segments are combined into a single WAV audio file.
5. **Upload to storage**: The generated podcast file is uploaded to the S3-compatible object storage you configured.
6. **Save the link**: Finally, the podcast file's public URL is saved as a new feed label, so it can be used in notifications, the API, or anywhere else.
## Configuration Steps
To enable the podcast feature, you need to complete three pieces of configuration: LLMs, object storage, and a rewrite rule.
### 1. Configure the LLMs
You need to configure at least two LLMs: one for generating the dialogue script, and one for text-to-speech (TTS):
- **Script-generation LLM**: Any reasonably capable chat model, such as OpenAI's `gpt-4o-mini` or Google's `gemini-1.5-pro`.
- **TTS LLM**: Used to convert text to speech. **Note: this currently only supports LLMs with `provider: gemini`.**
**Example `config.yaml`:**
```yaml
llms:
  # LLM used to generate the podcast script
  - name: openai-chat
    provider: openai
    api_key: "sk-..."
    model: gpt-4o-mini
    default: true
  # LLM used for text-to-speech (TTS)
  - name: gemini-tts
    provider: gemini
    api_key: "..." # Your Google AI Studio API key
    tts_model: "gemini-2.5-flash-preview-tts" # Gemini's TTS model
```
### 2. Configure Object Storage
The generated podcast audio files need somewhere to live. Zenfeed supports any S3-compatible object storage service; here we use [Cloudflare R2](https://www.cloudflare.com/zh-cn/products/r2/) as an example.
First, create a bucket in Cloudflare R2, then gather the following information:
- `endpoint`: Your R2 API endpoint, usually of the form `<account_id>.r2.cloudflarestorage.com`. You can find it on the R2 bucket's overview page.
- `access_key_id` and `secret_access_key`: An R2 API token, which you can create on the "R2" -> "Manage R2 API Tokens" page.
- `bucket`: The name of the bucket you created.
- `bucket_url`: The bucket's public access URL. To get one, connect the bucket to a custom domain, or use the public `r2.dev` address provided by R2.
**Example `config.yaml`:**
```yaml
storage:
  object:
    endpoint: "<your_account_id>.r2.cloudflarestorage.com"
    access_key_id: "..."
    secret_access_key: "..."
    bucket: "zenfeed-podcasts"
    bucket_url: "https://pub-xxxxxxxx.r2.dev"
```
### 3. Configure the Rewrite Rule
The final step is to create a rewrite rule that tells Zenfeed how to turn an article into a podcast. The rule defines which label to use as the source text, who the speakers are, which voices to use, and so on.
**Key settings:**
- `source_label`: The label containing the article's full text.
- `label`: The name of the new label that will store the final podcast URL.
- `transform.to_podcast`: The core podcast-conversion configuration.
  - `llm`: The name of the LLM used to generate the script (from the `llms` section).
  - `tts_llm`: The name of the LLM used for TTS (from the `llms` section).
  - `speakers`: Defines the podcast's speakers.
    - `name`: The speaker's name.
    - `role`: The speaker's role and persona, which shapes the script content.
    - `voice`: The speaker's voice. See the [Google Cloud TTS documentation](https://cloud.google.com/text-to-speech/docs/voices) for available voice names (e.g., `en-US-Standard-C` or `en-US-News-N`).
**Example `config.yaml`:**
```yaml
storage:
  feed:
    rewrites:
      - source_label: "content"
        label: "podcast_url"
        transform:
          to_podcast:
            llm: "openai-chat"
            tts_llm: "gemini-tts"
            transcript_additional_prompt: "Respond in Chinese"
            speakers:
              - name: "Host Xiaoya"
                role: "An experienced tech podcast host with a sweet voice and a lively style, good at relating topics to everyday life."
                voice: "zh-CN-Standard-A" # female voice
              - name: "Commentator Lao Wang"
                role: "A commentator with deep technical insight and sharp opinions; speaks bluntly and is occasionally a bit cynical."
                voice: "zh-CN-Standard-B" # male voice
```
Once configured, Zenfeed runs this pipeline automatically every time it scrapes a new article. You can use the podcast_url label in notification templates, or listen directly in the Web UI (the Web UI always reads the podcast_url label; if you use a different label name, it will not be picked up).
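For example, here is a minimal sketch of a Markdown notification template that links to the generated podcast; the template variables and the URL are assumptions for illustration, since the available variables depend on your feed labels.

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// A hypothetical feed_markdown_template that embeds the podcast link.
	const tmpl = "## {{ .title }}\n\n{{ .content }}\n\n[Listen to the podcast]({{ .podcast_url }})\n"

	// Stand-in for one stored feed's labels.
	feed := map[string]string{
		"title":       "Example article",
		"content":     "Summary of the article ...",
		"podcast_url": "https://pub-xxxxxxxx.r2.dev/podcasts/example.wav",
	}

	t := template.Must(template.New("feed").Parse(tmpl))
	_ = t.Execute(os.Stdout, feed)
}
```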

go.mod: 13 lines changed

@@ -9,6 +9,7 @@ require (
github.com/edsrzf/mmap-go v1.2.0
github.com/gorilla/feeds v1.2.0
github.com/mark3labs/mcp-go v0.17.0
github.com/minio/minio-go/v7 v7.0.94
github.com/mmcdole/gofeed v1.3.0
github.com/nutsdb/nutsdb v1.0.4
github.com/onsi/gomega v1.36.1
@@ -16,6 +17,7 @@ require (
github.com/prometheus/client_golang v1.21.1
github.com/sashabaranov/go-openai v1.40.1
github.com/stretchr/testify v1.10.0
github.com/temoto/robotstxt v1.1.2
github.com/veqryn/slog-dedup v0.5.0
github.com/yuin/goldmark v1.7.8
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
@@ -32,25 +34,34 @@ require (
github.com/bwmarrin/snowflake v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/gofrs/flock v0.8.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/minio/crc64nvme v1.0.1 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/mmcdole/goxpp v1.1.1-0.20240225020742-a0c311522b23 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/temoto/robotstxt v1.1.2
github.com/tidwall/btree v1.6.0 // indirect
github.com/tinylib/msgp v1.3.0 // indirect
github.com/xujiajun/mmap-go v1.0.1 // indirect
github.com/xujiajun/utils v0.0.0-20220904132955-5f7c5b914235 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect

go.sum: 23 lines changed

@@ -21,12 +21,18 @@ github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUw
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
github.com/edsrzf/mmap-go v1.2.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -42,6 +48,9 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@@ -53,6 +62,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930=
github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
github.com/minio/crc64nvme v1.0.1 h1:DHQPrYPdqK7jQG/Ls5CTBZWeex/2FMS3G5XGkycuFrY=
github.com/minio/crc64nvme v1.0.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.94 h1:1ZoksIKPyaSt64AVOyaQvhDOgVC3MfZsWM6mZXRUGtM=
github.com/minio/minio-go/v7 v7.0.94/go.mod h1:71t2CqDt3ThzESgZUlU1rBN54mksGGlkLcFgguDnnAc=
github.com/mmcdole/gofeed v1.3.0 h1:5yn+HeqlcvjMeAI4gu6T+crm7d0anY85+M+v6fIFNG4=
github.com/mmcdole/gofeed v1.3.0/go.mod h1:9TGv2LcJhdXePDzxiuMnukhV2/zb6VtnZt1mS+SjkLE=
github.com/mmcdole/goxpp v1.1.1-0.20240225020742-a0c311522b23 h1:Zr92CAlFhy2gL+V1F+EyIuzbQNbSgP4xhTODZtrXUtk=
@@ -70,6 +85,8 @@ github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo
github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw=
github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -87,6 +104,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6O
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/sashabaranov/go-openai v1.40.1 h1:bJ08Iwct5mHBVkuvG6FEcb9MDTfsXdTYPGjYLRdeTEU=
github.com/sashabaranov/go-openai v1.40.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y=
@@ -106,6 +125,8 @@ github.com/temoto/robotstxt v1.1.2 h1:W2pOjSJ6SWvldyEuiFXNxz3xZ8aiWX5LbfDiOFd7Fx
github.com/temoto/robotstxt v1.1.2/go.mod h1:+1AmkuG3IYkh1kv0d2qEB9Le88ehNO0zwOr3ujewlOo=
github.com/tidwall/btree v1.6.0 h1:LDZfKfQIBHGHWSwckhXI0RPSXzlo+KYdjK7FWSqOzzg=
github.com/tidwall/btree v1.6.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww=
github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
github.com/veqryn/slog-dedup v0.5.0 h1:2pc4va3q8p7Tor1SjVvi1ZbVK/oKNPgsqG15XFEt0iM=
github.com/veqryn/slog-dedup v0.5.0/go.mod h1:/iQU008M3qFa5RovtfiHiODxJFvxZLjWRG/qf/zKFHw=
github.com/xujiajun/mmap-go v1.0.1 h1:7Se7ss1fLPPRW+ePgqGpCkfGIZzJV6JPq9Wq9iv/WHc=
@@ -123,6 +144,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=

main.go: 46 lines changed

@@ -47,6 +47,7 @@ import (
"github.com/glidea/zenfeed/pkg/storage/feed/block/index/primary"
"github.com/glidea/zenfeed/pkg/storage/feed/block/index/vector"
"github.com/glidea/zenfeed/pkg/storage/kv"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/telemetry/log"
telemetryserver "github.com/glidea/zenfeed/pkg/telemetry/server"
timeutil "github.com/glidea/zenfeed/pkg/util/time"
@@ -122,18 +123,19 @@ type App struct {
conf *config.App
telemetry telemetryserver.Server
kvStorage kv.Storage
llmFactory llm.Factory
rewriter rewrite.Rewriter
feedStorage feed.Storage
api api.API
http http.Server
mcp mcp.Server
rss rss.Server
scraperMgr scrape.Manager
scheduler schedule.Scheduler
notifier notify.Notifier
notifyChan chan *rule.Result
kvStorage kv.Storage
llmFactory llm.Factory
rewriter rewrite.Rewriter
feedStorage feed.Storage
objectStorage object.Storage
api api.API
http http.Server
mcp mcp.Server
rss rss.Server
scraperMgr scrape.Manager
scheduler schedule.Scheduler
notifier notify.Notifier
notifyChan chan *rule.Result
}
// newApp creates a new application instance.
@@ -164,6 +166,9 @@ func (a *App) setup() error {
if err := a.setupKVStorage(); err != nil {
return errors.Wrap(err, "setup kv storage")
}
if err := a.setupObjectStorage(); err != nil {
return errors.Wrap(err, "setup object storage")
}
if err := a.setupLLMFactory(); err != nil {
return errors.Wrap(err, "setup llm factory")
}
@@ -251,7 +256,8 @@ func (a *App) setupLLMFactory() (err error) {
// setupRewriter initializes the Rewriter factory.
func (a *App) setupRewriter() (err error) {
a.rewriter, err = rewrite.NewFactory().New(component.Global, a.conf, rewrite.Dependencies{
LLMFactory: a.llmFactory,
LLMFactory: a.llmFactory,
ObjectStorage: a.objectStorage,
})
if err != nil {
return err
@@ -282,6 +288,18 @@ func (a *App) setupFeedStorage() (err error) {
return nil
}
// setupObjectStorage initializes the Object storage.
func (a *App) setupObjectStorage() (err error) {
a.objectStorage, err = object.NewFactory().New(component.Global, a.conf, object.Dependencies{})
if err != nil {
return err
}
a.configMgr.Subscribe(a.objectStorage)
return nil
}
// setupTelemetryServer initializes the Telemetry server.
func (a *App) setupTelemetryServer() (err error) {
a.telemetry, err = telemetryserver.NewFactory().New(component.Global, a.conf, telemetryserver.Dependencies{})
@@ -419,7 +437,7 @@ func (a *App) run(ctx context.Context) error {
log.Info(ctx, "starting application components...")
if err := component.Run(ctx,
component.Group{a.configMgr},
component.Group{a.llmFactory, a.telemetry},
component.Group{a.llmFactory, a.objectStorage, a.telemetry},
component.Group{a.rewriter},
component.Group{a.feedStorage},
component.Group{a.kvStorage},


@@ -90,6 +90,7 @@ type LLM struct {
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty" desc:"The API key of the LLM. It is required when api.llm is set."`
Model string `yaml:"model,omitempty" json:"model,omitempty" desc:"The model of the LLM. e.g. gpt-4o-mini. Can not be empty with embedding_model at same time when api.llm is set."`
EmbeddingModel string `yaml:"embedding_model,omitempty" json:"embedding_model,omitempty" desc:"The embedding model of the LLM. e.g. text-embedding-3-small. Can not be empty with model at same time when api.llm is set. NOTE: Once used, do not modify it directly, instead, add a new LLM configuration."`
TTSModel string `yaml:"tts_model,omitempty" json:"tts_model,omitempty" desc:"The TTS model of the LLM."`
Temperature float32 `yaml:"temperature,omitempty" json:"temperature,omitempty" desc:"The temperature (0-2) of the LLM. Default: 0.0"`
}
@@ -101,8 +102,9 @@ type Scrape struct {
}
type Storage struct {
Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."`
Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."`
Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."`
Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."`
Object ObjectStorage `yaml:"object,omitempty" json:"object,omitempty" desc:"The object storage config."`
}
type FeedStorage struct {
@@ -113,6 +115,14 @@ type FeedStorage struct {
BlockDuration timeutil.Duration `yaml:"block_duration,omitempty" json:"block_duration,omitempty" desc:"How long to keep the feed storage block. Block is time-based, like Prometheus TSDB Block. Default: 25h"`
}
type ObjectStorage struct {
Endpoint string `yaml:"endpoint,omitempty" json:"endpoint,omitempty" desc:"The endpoint of the object storage."`
AccessKeyID string `yaml:"access_key_id,omitempty" json:"access_key_id,omitempty" desc:"The access key id of the object storage."`
SecretAccessKey string `yaml:"secret_access_key,omitempty" json:"secret_access_key,omitempty" desc:"The secret access key of the object storage."`
Bucket string `yaml:"bucket,omitempty" json:"bucket,omitempty" desc:"The bucket of the object storage."`
BucketURL string `yaml:"bucket_url,omitempty" json:"bucket_url,omitempty" desc:"The public URL of the object storage bucket."`
}
type ScrapeSource struct {
Interval timeutil.Duration `yaml:"interval,omitempty" json:"interval,omitempty" desc:"How often to scrape this source. Default: global interval"`
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the source. It is required."`
@@ -137,7 +147,8 @@ type RewriteRule struct {
}
type RewriteRuleTransform struct {
ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."`
ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."`
ToPodcast *RewriteRuleTransformToPodcast `yaml:"to_podcast,omitempty" json:"to_podcast,omitempty" desc:"The transform config to transform the source text to podcast."`
}
type RewriteRuleTransformToText struct {
@@ -146,6 +157,20 @@ type RewriteRuleTransformToText struct {
Prompt string `yaml:"prompt,omitempty" json:"prompt,omitempty" desc:"The prompt to transform the source text. The source text will be injected into the prompt above. And you can use go template syntax to refer some built-in prompts, like {{ .summary }}. Available built-in prompts: category, tags, score, comment_confucius, summary, summary_html_snippet."`
}
type RewriteRuleTransformToPodcast struct {
LLM string `yaml:"llm,omitempty" json:"llm,omitempty" desc:"The LLM name to use. Default is the default LLM in llms section."`
EstimateMaximumDuration timeutil.Duration `yaml:"estimate_maximum_duration,omitempty" json:"estimate_maximum_duration,omitempty" desc:"The estimated maximum duration of the podcast. It will affect the length of the generated transcript. e.g. 5m. Default is 5m."`
TranscriptAdditionalPrompt string `yaml:"transcript_additional_prompt,omitempty" json:"transcript_additional_prompt,omitempty" desc:"The additional prompt to add to the transcript. It is optional."`
TTSLLM string `yaml:"tts_llm,omitempty" json:"tts_llm,omitempty" desc:"The LLM name to use for TTS. Only supports gemini now. Default is the default LLM in llms section."`
Speakers []RewriteRuleTransformToPodcastSpeaker `yaml:"speakers,omitempty" json:"speakers,omitempty" desc:"The speakers to use. It is required, at least one speaker is needed."`
}
type RewriteRuleTransformToPodcastSpeaker struct {
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the speaker. It is required."`
Role string `yaml:"role,omitempty" json:"role,omitempty" desc:"The role description of the speaker. You can think of it as a character setting."`
Voice string `yaml:"voice,omitempty" json:"voice,omitempty" desc:"The voice of the speaker. It is required."`
}
type SchedulsRule struct {
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the rule. It is required."`
Query string `yaml:"query,omitempty" json:"query,omitempty" desc:"The semantic query to get the feeds. NOTE it is optional"`

pkg/llm/gemini.go (new file): 248 lines

@@ -0,0 +1,248 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"path/filepath"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/wav"
)
type gemini struct {
*component.Base[Config, struct{}]
text
hc *http.Client
embeddingSpliter embeddingSpliter
}
func newGemini(c *Config) LLM {
config := oai.DefaultConfig(c.APIKey)
config.BaseURL = c.Endpoint + "/openai" // OpenAI compatible endpoint.
client := oai.NewClientWithConfig(config)
embeddingSpliter := newEmbeddingSpliter(1536, 64)
base := component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/gemini",
Instance: c.Name,
Config: c,
})
return &gemini{
Base: base,
text: &openaiText{
Base: base,
client: client,
},
hc: &http.Client{},
embeddingSpliter: embeddingSpliter,
}
}
func (g *gemini) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
ctx = telemetry.StartWith(ctx, append(g.TelemetryLabels(), telemetrymodel.KeyOperation, "WAV")...)
defer func() { telemetry.End(ctx, err) }()
if g.Config().TTSModel == "" {
return nil, errors.New("tts model is not set")
}
reqPayload, err := buildWAVRequestPayload(text, speakers)
if err != nil {
return nil, errors.Wrap(err, "build wav request payload")
}
pcmData, err := g.doWAVRequest(ctx, reqPayload)
if err != nil {
return nil, errors.Wrap(err, "do wav request")
}
return streamWAV(pcmData), nil
}
func (g *gemini) doWAVRequest(ctx context.Context, reqPayload *geminiRequest) ([]byte, error) {
config := g.Config()
body, err := json.Marshal(reqPayload)
if err != nil {
return nil, errors.Wrap(err, "marshal tts request")
}
url := config.Endpoint + "/models/" + config.TTSModel + ":generateContent"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "new tts request")
}
req.Header.Set("x-goog-api-key", config.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := g.hc.Do(req)
if err != nil {
return nil, errors.Wrap(err, "do tts request")
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
errMsg, _ := io.ReadAll(resp.Body)
return nil, errors.Errorf("tts request failed with status %d: %s", resp.StatusCode, string(errMsg))
}
var ttsResp geminiResponse
if err := json.NewDecoder(resp.Body).Decode(&ttsResp); err != nil {
return nil, errors.Wrap(err, "decode tts response")
}
if len(ttsResp.Candidates) == 0 || len(ttsResp.Candidates[0].Content.Parts) == 0 || ttsResp.Candidates[0].Content.Parts[0].InlineData == nil {
return nil, errors.New("no audio data in tts response")
}
audioDataB64 := ttsResp.Candidates[0].Content.Parts[0].InlineData.Data
pcmData, err := base64.StdEncoding.DecodeString(audioDataB64)
if err != nil {
return nil, errors.Wrap(err, "decode base64")
}
return pcmData, nil
}
func buildWAVRequestPayload(text string, speakers []Speaker) (*geminiRequest, error) {
reqPayload := geminiRequest{
Contents: []*geminiRequestContent{{Parts: []*geminiRequestPart{{Text: text}}}},
Config: &geminiRequestConfig{
ResponseModalities: []string{"AUDIO"},
SpeechConfig: &geminiRequestSpeechConfig{},
},
}
switch len(speakers) {
case 0:
return nil, errors.New("no speakers")
case 1:
reqPayload.Config.SpeechConfig.VoiceConfig = &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: speakers[0].Voice},
}
default:
multiSpeakerConfig := &geminiRequestMultiSpeakerVoiceConfig{}
for _, s := range speakers {
multiSpeakerConfig.SpeakerVoiceConfigs = append(multiSpeakerConfig.SpeakerVoiceConfigs, &geminiRequestSpeakerVoiceConfig{
Speaker: s.Name,
VoiceConfig: &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: s.Voice},
},
})
}
reqPayload.Config.SpeechConfig.MultiSpeakerVoiceConfig = multiSpeakerConfig
}
return &reqPayload, nil
}
func streamWAV(pcmData []byte) io.ReadCloser {
pipeReader, pipeWriter := io.Pipe()
go func() {
defer func() { _ = pipeWriter.Close() }()
if err := wav.WriteHeader(pipeWriter, geminiWavHeader, uint32(len(pcmData))); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write wav header"))
return
}
if _, err := io.Copy(pipeWriter, bytes.NewReader(pcmData)); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write pcm data"))
return
}
}()
return pipeReader
}
var geminiWavHeader = &wav.Header{
SampleRate: 24000,
BitDepth: 16,
NumChannels: 1,
}
type geminiRequest struct {
Contents []*geminiRequestContent `json:"contents"`
Config *geminiRequestConfig `json:"generationConfig"`
}
type geminiRequestContent struct {
Parts []*geminiRequestPart `json:"parts"`
}
type geminiRequestPart struct {
Text string `json:"text"`
}
type geminiRequestConfig struct {
ResponseModalities []string `json:"responseModalities"`
SpeechConfig *geminiRequestSpeechConfig `json:"speechConfig"`
}
type geminiRequestSpeechConfig struct {
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
MultiSpeakerVoiceConfig *geminiRequestMultiSpeakerVoiceConfig `json:"multiSpeakerVoiceConfig,omitempty"`
}
type geminiRequestVoiceConfig struct {
PrebuiltVoiceConfig *geminiRequestPrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"`
}
type geminiRequestPrebuiltVoiceConfig struct {
VoiceName string `json:"voiceName,omitempty"`
}
type geminiRequestMultiSpeakerVoiceConfig struct {
SpeakerVoiceConfigs []*geminiRequestSpeakerVoiceConfig `json:"speakerVoiceConfigs,omitempty"`
}
type geminiRequestSpeakerVoiceConfig struct {
Speaker string `json:"speaker,omitempty"`
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
}
type geminiResponse struct {
Candidates []*geminiResponseCandidate `json:"candidates"`
}
type geminiResponseCandidate struct {
Content *geminiResponseContent `json:"content"`
}
type geminiResponseContent struct {
Parts []*geminiResponsePart `json:"parts"`
}
type geminiResponsePart struct {
InlineData *geminiResponseInlineData `json:"inlineData"`
}
type geminiResponseInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // Base64 encoded.
}


@@ -18,6 +18,7 @@ package llm
import (
"bytes"
"context"
"io"
"reflect"
"strconv"
"sync"
@@ -42,19 +43,33 @@ import (
// --- Interface code block ---
type LLM interface {
component.Component
text
audio
}
type text interface {
String(ctx context.Context, messages []string) (string, error)
EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error)
Embedding(ctx context.Context, text string) ([]float32, error)
}
type audio interface {
WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error)
}
type Speaker struct {
Name string
Voice string
}
type Config struct {
Name string
Default bool
Provider ProviderType
Endpoint string
APIKey string
Model, EmbeddingModel string
Temperature float32
Name string
Default bool
Provider ProviderType
Endpoint string
APIKey string
Model, EmbeddingModel, TTSModel string
Temperature float32
}
type ProviderType string
@@ -72,7 +87,7 @@ var defaultEndpoints = map[ProviderType]string{
ProviderTypeOpenAI: "https://api.openai.com/v1",
ProviderTypeOpenRouter: "https://openrouter.ai/api/v1",
ProviderTypeDeepSeek: "https://api.deepseek.com/v1",
ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta/openai",
ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta",
ProviderTypeVolc: "https://ark.cn-beijing.volces.com/api/v3",
ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1",
}
@@ -97,8 +112,8 @@ func (c *Config) Validate() error { //nolint:cyclop
if c.APIKey == "" {
return errors.New("api key is required")
}
if c.Model == "" && c.EmbeddingModel == "" {
return errors.New("model or embedding model is required")
if c.Model == "" && c.EmbeddingModel == "" && c.TTSModel == "" {
return errors.New("model or embedding model or tts model is required")
}
if c.Temperature < 0 || c.Temperature > 2 {
return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature)
@@ -182,6 +197,7 @@ func (c *FactoryConfig) From(app *config.App) {
APIKey: llm.APIKey,
Model: llm.Model,
EmbeddingModel: llm.EmbeddingModel,
TTSModel: llm.TTSModel,
Temperature: llm.Temperature,
})
}
@@ -207,12 +223,9 @@ func NewFactory(
) (Factory, error) {
if len(mockOn) > 0 {
mf := &mockFactory{}
getCall := mf.On("Get", mock.Anything)
getCall.Run(func(args mock.Arguments) {
m := &mockLLM{}
component.MockOptions(mockOn).Apply(&m.Mock)
getCall.Return(m, nil)
})
m := &mockLLM{}
component.MockOptions(mockOn).Apply(&m.Mock)
mf.On("Get", mock.Anything).Return(m)
mf.On("Reload", mock.Anything).Return(nil)
return mf, nil
@@ -307,11 +320,6 @@ func (f *factory) Get(name string) LLM {
continue
}
if f.llms[name] == nil {
llm := f.new(&llmC)
f.llms[name] = llm
}
return f.llms[name]
}
@@ -320,8 +328,12 @@ func (f *factory) Get(name string) LLM {
func (f *factory) new(c *Config) LLM {
switch c.Provider {
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
case ProviderTypeGemini:
return newCached(newGemini(c), f.Dependencies().KVStorage)
default:
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
}
@@ -333,14 +345,17 @@ func (f *factory) initLLMs() {
llms = make(map[string]LLM, len(config.LLMs))
defaultLLM LLM
)
for _, llmC := range config.LLMs {
llm := f.new(&llmC)
llms[llmC.Name] = llm
if llmC.Name == config.defaultLLM {
defaultLLM = llm
}
}
f.llms = llms
f.defaultLLM = defaultLLM
}
@@ -482,12 +497,27 @@ func (m *mockLLM) String(ctx context.Context, messages []string) (string, error)
func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) {
args := m.Called(ctx, labels)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).([][]float32), args.Error(1)
}
func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) {
args := m.Called(ctx, text)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).([]float32), args.Error(1)
}
func (m *mockLLM) WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error) {
args := m.Called(ctx, text, speakers)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).(io.ReadCloser), args.Error(1)
}


@@ -18,6 +18,7 @@ package llm
import (
"context"
"encoding/json"
"io"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
@@ -31,9 +32,7 @@ import (
type openai struct {
*component.Base[Config, struct{}]
client *oai.Client
embeddingSpliter embeddingSpliter
text
}
func newOpenAI(c *Config) LLM {
@@ -42,18 +41,34 @@ func newOpenAI(c *Config) LLM {
client := oai.NewClientWithConfig(config)
embeddingSpliter := newEmbeddingSpliter(1536, 64)
base := component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/openai",
Instance: c.Name,
Config: c,
})
return &openai{
Base: component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/openai",
Instance: c.Name,
Config: c,
}),
client: client,
embeddingSpliter: embeddingSpliter,
Base: base,
text: &openaiText{
Base: base,
client: client,
embeddingSpliter: embeddingSpliter,
},
}
}
func (o *openai) String(ctx context.Context, messages []string) (value string, err error) {
func (o *openai) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
return nil, errors.New("not supported")
}
type openaiText struct {
*component.Base[Config, struct{}]
client *oai.Client
embeddingSpliter embeddingSpliter
}
func (o *openaiText) String(ctx context.Context, messages []string) (value string, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...)
defer func() { telemetry.End(ctx, err) }()
@@ -91,7 +106,7 @@ func (o *openai) String(ctx context.Context, messages []string) (value string, e
return resp.Choices[0].Message.Content, nil
}
func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
func (o *openaiText) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...)
defer func() { telemetry.End(ctx, err) }()
@@ -117,7 +132,7 @@ func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (valu
return vecs, nil
}
func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) {
func (o *openaiText) Embedding(ctx context.Context, s string) (value []float32, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...)
defer func() { telemetry.End(ctx, err) }()


@@ -17,9 +17,13 @@ package rewrite
import (
"context"
"fmt"
"html/template"
"io"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
@@ -30,10 +34,12 @@ import (
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/llm/prompt"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/buffer"
"github.com/glidea/zenfeed/pkg/util/crawl"
hashutil "github.com/glidea/zenfeed/pkg/util/hash"
)
// --- Interface code block ---
@@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) {
}
type Dependencies struct {
LLMFactory llm.Factory
LLMFactory llm.Factory // NOTE: String() with cache.
ObjectStorage object.Storage
}
type Rule struct {
@@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
}
// Transform.
if r.Transform != nil {
if r.Transform.ToText == nil {
return errors.New("to_text is required when transform is set")
if r.Transform != nil { //nolint:nestif
if r.Transform.ToText != nil && r.Transform.ToPodcast != nil {
return errors.New("to_text and to_podcast can not be set at same time")
}
if r.Transform.ToText == nil && r.Transform.ToPodcast == nil {
return errors.New("either to_text or to_podcast must be set when transform is set")
}
switch r.Transform.ToText.Type {
case ToTextTypePrompt:
if r.Transform.ToText.Prompt == "" {
return errors.New("to text prompt is required for prompt type")
if r.Transform.ToText != nil {
switch r.Transform.ToText.Type {
case ToTextTypePrompt:
if r.Transform.ToText.Prompt == "" {
return errors.New("to text prompt is required for prompt type")
}
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
if err != nil {
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
}
buf := buffer.Get()
defer buffer.Put(buf)
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
}
r.Transform.ToText.promptRendered = buf.String()
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
// No specific validation for crawl type here, as the source text itself is the URL.
default:
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
}
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
if err != nil {
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
}
if r.Transform.ToPodcast != nil {
if len(r.Transform.ToPodcast.Speakers) == 0 {
return errors.New("at least one speaker is required for to_podcast")
}
buf := buffer.Get()
defer buffer.Put(buf)
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
}
r.Transform.ToText.promptRendered = buf.String()
r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers))
var speakerDescs []string
var speakerNames []string
for i, s := range r.Transform.ToPodcast.Speakers {
if s.Name == "" {
return errors.New("speaker name is required")
}
if s.Voice == "" {
return errors.New("speaker voice is required")
}
r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice}
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
// No specific validation for crawl type here, as the source text itself is the URL.
default:
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
desc := s.Name
if s.Role != "" {
desc += " (" + s.Role + ")"
}
speakerDescs = append(speakerDescs, desc)
speakerNames = append(speakerNames, s.Name)
}
speakersDesc := "- " + strings.Join(speakerDescs, "\n- ")
exampleSpeaker1 := speakerNames[0]
exampleSpeaker2 := exampleSpeaker1
if len(speakerNames) > 1 {
exampleSpeaker2 = speakerNames[1]
}
promptSegments := []string{
"Please convert the following article into a podcast dialogue script.",
"The speakers are:\n" + speakersDesc,
}
if r.Transform.ToPodcast.EstimateMaximumDuration > 0 {
wordsPerMinute := 200
totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes())
estimatedWords := totalMinutes * wordsPerMinute
promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes))
}
if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" {
promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt)
}
promptSegments = append(promptSegments,
"The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.",
"Do NOT include any other text, explanations, or formatting before or after the script.",
"Do NOT use background music in the script.",
"Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').",
fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2),
"Now, convert the article.",
)
r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n")
r.Transform.ToPodcast.speakersDesc = speakersDesc
}
}
@@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
r.matchRE = re
// Action.
if r.Action == "" {
r.Action = ActionCreateOrUpdateLabel
}
switch r.Action {
case ActionCreateOrUpdateLabel:
if r.Label == "" {
return errors.New("label is required for create or update label action")
@@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) {
r.If = c.If
r.SourceLabel = c.SourceLabel
r.SkipTooShortThreshold = c.SkipTooShortThreshold
if c.Transform != nil { //nolint:nestif
t := &Transform{}
if c.Transform.ToText != nil {
toText := &ToText{
@@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) {
}
t.ToText = toText
}
if c.Transform.ToPodcast != nil {
toPodcast := &ToPodcast{
LLM: c.Transform.ToPodcast.LLM,
EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration),
TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt,
TTSLLM: c.Transform.ToPodcast.TTSLLM,
}
if toPodcast.EstimateMaximumDuration == 0 {
toPodcast.EstimateMaximumDuration = 3 * time.Minute
}
for _, s := range c.Transform.ToPodcast.Speakers {
toPodcast.Speakers = append(toPodcast.Speakers, Speaker{
Name: s.Name,
Role: s.Role,
Voice: s.Voice,
})
}
t.ToPodcast = toPodcast
}
r.Transform = t
}
r.Match = c.Match
@@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) {
}
type Transform struct {
ToText *ToText
ToPodcast *ToPodcast
}
type ToText struct {
@@ -220,6 +314,24 @@ type ToText struct {
promptRendered string
}
type ToPodcast struct {
LLM string
EstimateMaximumDuration time.Duration
TranscriptAdditionalPrompt string
TTSLLM string
Speakers []Speaker
transcriptPrompt string
speakersDesc string
speakers []llm.Speaker
}
type Speaker struct {
Name string
Role string
Voice string
}
type ToTextType string
const (
@@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
}
// Transform text if configured.
text, err := r.transform(ctx, rule.Transform, sourceText)
if err != nil {
return nil, errors.Wrap(err, "transform")
}
// Check if text matches the rule.
@@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
return labels, nil
}
func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) {
if transform == nil {
return sourceText, nil
}
if transform.ToText != nil {
return r.transformText(ctx, transform.ToText, sourceText)
}
if transform.ToPodcast != nil {
return r.transformPodcast(ctx, transform.ToPodcast, sourceText)
}
return sourceText, nil
}
// transformText transforms text using configured LLM or by crawling a URL.
func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) {
switch toText.Type {
case ToTextTypeCrawl:
return r.transformTextCrawl(ctx, r.crawler, text)
case ToTextTypeCrawlByJina:
return r.transformTextCrawl(ctx, r.jinaCrawler, text)
case ToTextTypePrompt:
return r.transformTextPrompt(ctx, toText, text)
default:
return r.transformTextPrompt(ctx, toText, text)
}
}
@@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler
}
// transformTextPrompt transforms text using configured LLM.
func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) {
// Get LLM instance.
llm := r.Dependencies().LLMFactory.Get(toText.LLM)
// Call completion.
result, err := llm.String(ctx, []string{
toText.promptRendered,
text, // TODO: consider placing this first so different rewrite rules can share the model cache.
})
if err != nil {
@@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string {
return text
}
var audioKey = func(transcript, ext string) string {
hash := hashutil.Sum64(transcript)
file := strconv.FormatUint(hash, 10) + "." + ext
return "podcasts/" + file
}
func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) {
transcript, err := r.generateTranscript(ctx, toPodcast, sourceText)
if err != nil {
return "", errors.Wrap(err, "generate podcast transcript")
}
audioKey := audioKey(transcript, "wav")
url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey)
switch {
case err == nil:
// The previous run may have been canceled by a reload; if the audio already exists, return it directly.
return url, nil
case errors.Is(err, object.ErrNotFound):
// Not found, generate new audio.
default:
return "", errors.Wrap(err, "get audio")
}
audioStream, err := r.generateAudio(ctx, toPodcast, transcript)
if err != nil {
return "", errors.Wrap(err, "generate podcast audio")
}
defer func() {
if closeErr := audioStream.Close(); closeErr != nil && err == nil {
err = errors.Wrap(closeErr, "close audio stream")
}
}()
url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav")
if err != nil {
return "", errors.Wrap(err, "store podcast audio")
}
return url, nil
}
func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) {
transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM).
String(ctx, []string{toPodcast.transcriptPrompt, sourceText})
if err != nil {
return "", errors.Wrap(err, "llm completion")
}
return toPodcast.speakersDesc +
"\n\nFollowed by the actual dialogue script:\n" +
transcript, nil
}
func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) {
audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM).
WAV(ctx, transcript, toPodcast.speakers)
if err != nil {
return nil, errors.Wrap(err, "calling tts llm")
}
return audioStream, nil
}
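A rough sketch of the caching scheme used by transformPodcast above: the object key is derived from a 64-bit hash of the transcript, so identical transcripts reuse the stored WAV instead of calling the TTS model again. This is illustrative only; FNV-1a stands in for the project's hashutil.Sum64, whose algorithm is not shown in this diff, so the digits will differ from the real keys.

package main

import (
    "fmt"
    "hash/fnv"
    "strconv"
)

// audioKeySketch mirrors audioKey above, with FNV-1a standing in for hashutil.Sum64.
func audioKeySketch(transcript, ext string) string {
    h := fnv.New64a()
    _, _ = h.Write([]byte(transcript))
    return "podcasts/" + strconv.FormatUint(h.Sum64(), 10) + "." + ext
}

func main() {
    key := audioKeySketch("Alice: Today we are discussing the article's main points.", "wav")
    fmt.Println(key) // e.g. podcasts/1234567890.wav (digits depend on the hash)
    // The rewriter first calls ObjectStorage.Get(ctx, key); on object.ErrNotFound it
    // generates the audio via the TTS LLM and stores it with Put(ctx, key, stream, "audio/wav").
}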
type mockRewriter struct {
component.Mock
}

View File

@@ -2,6 +2,8 @@ package rewrite
import (
"context"
"io"
"strings"
"testing"
. "github.com/onsi/gomega"
@@ -12,6 +14,7 @@ import (
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/test"
)
@@ -19,8 +22,9 @@ func TestLabels(t *testing.T) {
RegisterTestingT(t)
type givenDetail struct {
config *Config
llmMock func(m *mock.Mock)
objectStorageMock func(m *mock.Mock)
}
type whenDetail struct {
inputLabels model.Labels
@@ -173,7 +177,7 @@ func TestLabels(t *testing.T) {
},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform text: llm completion: LLM failed"),
err: errors.New("transform: llm completion: LLM failed"),
isErr: true,
},
},
@@ -220,22 +224,163 @@ func TestLabels(t *testing.T) {
isErr: false,
},
},
{
Scenario: "Successfully generate podcast from content",
Given: "a rule to convert content to a podcast with all dependencies mocked to succeed",
When: "processing labels with content to be converted to a podcast",
Then: "should return labels with a new podcast_url label",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{
LLM: "mock-llm-transcript",
TTSLLM: "mock-llm-tts",
Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}},
},
},
Action: ActionCreateOrUpdateLabel,
Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")).
Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav").
Return("http://storage.example.com/podcast.wav", nil).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{
inputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
},
},
ThenExpected: thenExpected{
outputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
{Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"},
},
isErr: false,
},
},
{
Scenario: "Fail podcast generation due to transcription LLM error",
Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to transcription failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to TTS LLM error",
Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to TTS failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to object storage error",
Given: "a rule to convert content to a podcast, but object storage is mocked to fail",
When: "processing labels",
Then: "should return an error related to storage failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: store podcast audio: storage failed"),
isErr: true,
},
},
}
for _, tt := range tests {
t.Run(tt.Scenario, func(t *testing.T) {
// Given.
var mockLLMFactory llm.Factory
var mockLLMInstance *mock.Mock
llmMockOption := component.MockOption(func(m *mock.Mock) {
mockLLMInstance = m
if tt.GivenDetail.llmMock != nil {
tt.GivenDetail.llmMock(m)
}
})
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption)
Expect(err).NotTo(HaveOccurred())
var mockObjectStorage object.Storage
var mockObjectStorageInstance *mock.Mock
objectStorageMockOption := component.MockOption(func(m *mock.Mock) {
mockObjectStorageInstance = m
if tt.GivenDetail.objectStorageMock != nil {
tt.GivenDetail.objectStorageMock(m)
}
})
mockObjectStorageFactory := object.NewFactory(objectStorageMockOption)
mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{})
Expect(err).NotTo(HaveOccurred())
// Manually validate config to compile regex and render templates.
@@ -252,7 +397,8 @@ func TestLabels(t *testing.T) {
Instance: "test",
Config: tt.GivenDetail.config,
Dependencies: Dependencies{
LLMFactory: mockLLMFactory, // Pass the mock factory
ObjectStorage: mockObjectStorage,
},
}),
}
@@ -280,10 +426,12 @@ func TestLabels(t *testing.T) {
Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels))
}
// Verify mock calls if stubs were provided.
if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil {
mockLLMInstance.AssertExpectations(t)
}
if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil {
mockObjectStorageInstance.AssertExpectations(t)
}
})
}

View File

@@ -216,6 +216,10 @@ func (m *manager) reload(config *Config) (err error) {
func (m *manager) runOrRestartScrapers(config *Config, newScrapers map[string]scraper.Scraper) error {
for i := range config.Scrapers {
c := &config.Scrapers[i]
if err := c.Validate(); err != nil {
return errors.Wrapf(err, "validate scraper %s", c.Name)
}
if err := m.runOrRestartScraper(c, newScrapers); err != nil {
return errors.Wrapf(err, "run or restart scraper %s", c.Name)
}

View File

@@ -69,6 +69,11 @@ func (c *Config) Validate() error {
if c.Name == "" {
return errors.New("name cannot be empty")
}
if c.RSS != nil {
if err := c.RSS.Validate(); err != nil {
return errors.Wrap(err, "invalid RSS config")
}
}
return nil
}

View File

@@ -244,7 +244,7 @@ func TestNew(t *testing.T) {
WhenDetail: whenDetail{},
ThenExpected: thenExpected{
isErr: true,
wantErrMsg: "creating source: invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
wantErrMsg: "invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
},
},
{
@@ -264,7 +264,7 @@ func TestNew(t *testing.T) {
WhenDetail: whenDetail{},
ThenExpected: thenExpected{
isErr: true,
wantErrMsg: "creating source: source not supported", // Error from newReader
wantErrMsg: "source not supported", // Error from newReader
},
},
}

View File

@@ -0,0 +1,239 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package object
import (
"context"
"io"
"net/url"
"reflect"
"strings"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/pkg/errors"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/config"
"github.com/glidea/zenfeed/pkg/telemetry"
"github.com/glidea/zenfeed/pkg/telemetry/log"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
)
// --- Interface code block ---
type Storage interface {
component.Component
config.Watcher
Put(ctx context.Context, key string, body io.Reader, contentType string) (url string, err error)
Get(ctx context.Context, key string) (url string, err error)
}
var ErrNotFound = errors.New("not found")
type Config struct {
Endpoint string
AccessKeyID string
SecretAccessKey string
Bucket string
BucketURL string
}
func (c *Config) Validate() error {
if c.Endpoint == "" {
return errors.New("endpoint is required")
}
c.Endpoint = strings.TrimPrefix(c.Endpoint, "https://") // S3 endpoint should not have https:// prefix.
c.Endpoint = strings.TrimPrefix(c.Endpoint, "http://")
if c.AccessKeyID == "" {
return errors.New("access key id is required")
}
if c.SecretAccessKey == "" {
return errors.New("secret access key is required")
}
if c.Bucket == "" {
return errors.New("bucket is required")
}
if c.BucketURL == "" {
return errors.New("bucket url is required")
}
return nil
}
func (c *Config) From(app *config.App) *Config {
*c = Config{
Endpoint: app.Storage.Object.Endpoint,
AccessKeyID: app.Storage.Object.AccessKeyID,
SecretAccessKey: app.Storage.Object.SecretAccessKey,
Bucket: app.Storage.Object.Bucket,
BucketURL: app.Storage.Object.BucketURL,
}
return c
}
type Dependencies struct{}
// --- Factory code block ---
type Factory component.Factory[Storage, config.App, Dependencies]
func NewFactory(mockOn ...component.MockOption) Factory {
if len(mockOn) > 0 {
return component.FactoryFunc[Storage, config.App, Dependencies](
func(instance string, config *config.App, dependencies Dependencies) (Storage, error) {
m := &mockStorage{}
component.MockOptions(mockOn).Apply(&m.Mock)
return m, nil
},
)
}
return component.FactoryFunc[Storage, config.App, Dependencies](new)
}
func new(instance string, app *config.App, dependencies Dependencies) (Storage, error) {
config := &Config{}
config.From(app)
if err := config.Validate(); err != nil {
return nil, errors.Wrap(err, "validate config")
}
client, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKeyID, config.SecretAccessKey, ""),
Secure: true,
})
if err != nil {
return nil, errors.Wrap(err, "new minio client")
}
u, err := url.Parse(config.BucketURL)
if err != nil {
return nil, errors.Wrap(err, "parse public url")
}
return &s3{
Base: component.New(&component.BaseConfig[Config, Dependencies]{
Name: "ObjectStorage",
Instance: instance,
Config: config,
Dependencies: dependencies,
}),
client: client,
bucketURL: u,
}, nil
}
// --- Implementation code block ---
type s3 struct {
*component.Base[Config, Dependencies]
client *minio.Client
bucketURL *url.URL
}
func (s *s3) Put(ctx context.Context, key string, body io.Reader, contentType string) (publicURL string, err error) {
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Put")...)
defer func() { telemetry.End(ctx, err) }()
bucket := s.Config().Bucket
if _, err := s.client.PutObject(ctx, bucket, key, body, -1, minio.PutObjectOptions{
ContentType: contentType,
}); err != nil {
return "", errors.Wrap(err, "put object")
}
return s.bucketURL.JoinPath(key).String(), nil
}
func (s *s3) Get(ctx context.Context, key string) (publicURL string, err error) {
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Get")...)
defer func() { telemetry.End(ctx, err) }()
bucket := s.Config().Bucket
if _, err := s.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}); err != nil {
errResponse := minio.ToErrorResponse(err)
if errResponse.Code == minio.NoSuchKey {
return "", ErrNotFound
}
return "", errors.Wrap(err, "stat object")
}
return s.bucketURL.JoinPath(key).String(), nil
}
func (s *s3) Reload(app *config.App) (err error) {
ctx := telemetry.StartWith(s.Context(), append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Reload")...)
defer func() { telemetry.End(ctx, err) }()
newConfig := &Config{}
newConfig.From(app)
if err := newConfig.Validate(); err != nil {
return errors.Wrap(err, "validate config")
}
if reflect.DeepEqual(s.Config(), newConfig) {
log.Debug(ctx, "object storage config not changed")
return nil
}
client, err := minio.New(newConfig.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(newConfig.AccessKeyID, newConfig.SecretAccessKey, ""),
Secure: true,
})
if err != nil {
return errors.Wrap(err, "new minio client")
}
u, err := url.Parse(newConfig.BucketURL)
if err != nil {
return errors.Wrap(err, "parse public url")
}
s.client = client
s.bucketURL = u
s.SetConfig(newConfig)
log.Info(ctx, "object storage reloaded")
return nil
}
// --- Mock code block ---
type mockStorage struct {
component.Mock
}
func (m *mockStorage) Put(ctx context.Context, key string, body io.Reader, contentType string) (string, error) {
args := m.Called(ctx, key, body, contentType)
return args.String(0), args.Error(1)
}
func (m *mockStorage) Get(ctx context.Context, key string) (string, error) {
args := m.Called(ctx, key)
return args.String(0), args.Error(1)
}
func (m *mockStorage) Reload(app *config.App) error {
args := m.Called(app)
return args.Error(0)
}
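A small standalone sketch of the URL handling above, using placeholder values: Validate strips any scheme from the endpoint (the minio client expects a bare host and is created with Secure: true), and Put/Get build the public URL by joining bucket_url with the object key.

package main

import (
    "fmt"
    "net/url"
    "strings"
)

func main() {
    endpoint := "https://s3.example.com" // placeholder endpoint
    endpoint = strings.TrimPrefix(endpoint, "https://")
    endpoint = strings.TrimPrefix(endpoint, "http://")
    fmt.Println(endpoint) // s3.example.com

    bucketURL, err := url.Parse("https://cdn.example.com/zenfeed") // placeholder bucket_url
    if err != nil {
        panic(err)
    }
    // Same construction as s3.Put/Get: public URL = bucket_url joined with the object key.
    fmt.Println(bucketURL.JoinPath("podcasts/123.wav").String())
    // https://cdn.example.com/zenfeed/podcasts/123.wav
}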

View File

@@ -122,6 +122,32 @@ func ReadUint32(r io.Reader) (uint32, error) {
return binary.LittleEndian.Uint32(b), nil
}
// WriteUint16 writes a uint16 using a pooled buffer.
func WriteUint16(w io.Writer, v uint16) error {
bp := smallBufPool.Get().(*[]byte)
defer smallBufPool.Put(bp)
b := *bp
binary.LittleEndian.PutUint16(b, v)
_, err := w.Write(b[:2])
return err
}
// ReadUint16 reads a uint16 using a pooled buffer.
func ReadUint16(r io.Reader) (uint16, error) {
bp := smallBufPool.Get().(*[]byte)
defer smallBufPool.Put(bp)
b := (*bp)[:2]
// Read exactly 2 bytes into the slice.
if _, err := io.ReadFull(r, b); err != nil {
return 0, errors.Wrap(err, "read uint16")
}
return binary.LittleEndian.Uint16(b), nil
}
// WriteFloat32 writes a float32 using a pooled buffer.
func WriteFloat32(w io.Writer, v float32) error {
return WriteUint32(w, math.Float32bits(v))
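A tiny round-trip example for the new WriteUint16/ReadUint16 helpers, assuming the import path shown in wav.go below; the value 16 is an arbitrary placeholder such as a bit-depth field.

package main

import (
    "bytes"
    "fmt"

    binaryutil "github.com/glidea/zenfeed/pkg/util/binary"
)

func main() {
    var buf bytes.Buffer
    if err := binaryutil.WriteUint16(&buf, 16); err != nil { // two little-endian bytes: 0x10 0x00
        panic(err)
    }
    v, err := binaryutil.ReadUint16(&buf)
    if err != nil {
        panic(err)
    }
    fmt.Println(v, buf.Len()) // 16 0: the value round-trips and both bytes are consumed
}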

100
pkg/util/wav/wav.go Normal file
View File

@@ -0,0 +1,100 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package wav
import (
"io"
"github.com/pkg/errors"
binaryutil "github.com/glidea/zenfeed/pkg/util/binary"
)
// Header contains the WAV header information.
type Header struct {
SampleRate uint32
BitDepth uint16
NumChannels uint16
}
// WriteHeader writes the WAV header to a writer.
// pcmDataSize is the size of the raw PCM data.
func WriteHeader(w io.Writer, h *Header, pcmDataSize uint32) error {
// RIFF Header.
if err := writeRIFFHeader(w, pcmDataSize); err != nil {
return errors.Wrap(err, "write RIFF header")
}
// fmt chunk.
if err := writeFMTChunk(w, h); err != nil {
return errors.Wrap(err, "write fmt chunk")
}
// data chunk.
if _, err := w.Write([]byte("data")); err != nil {
return errors.Wrap(err, "write data chunk marker")
}
if err := binaryutil.WriteUint32(w, pcmDataSize); err != nil {
return errors.Wrap(err, "write pcm data size")
}
return nil
}
func writeRIFFHeader(w io.Writer, pcmDataSize uint32) error {
if _, err := w.Write([]byte("RIFF")); err != nil {
return errors.Wrap(err, "write RIFF")
}
if err := binaryutil.WriteUint32(w, uint32(36+pcmDataSize)); err != nil {
return errors.Wrap(err, "write file size")
}
if _, err := w.Write([]byte("WAVE")); err != nil {
return errors.Wrap(err, "write WAVE")
}
return nil
}
func writeFMTChunk(w io.Writer, h *Header) error {
if _, err := w.Write([]byte("fmt ")); err != nil {
return errors.Wrap(err, "write fmt")
}
if err := binaryutil.WriteUint32(w, uint32(16)); err != nil { // PCM chunk size.
return errors.Wrap(err, "write pcm chunk size")
}
if err := binaryutil.WriteUint16(w, uint16(1)); err != nil { // PCM format.
return errors.Wrap(err, "write pcm format")
}
if err := binaryutil.WriteUint16(w, h.NumChannels); err != nil {
return errors.Wrap(err, "write num channels")
}
if err := binaryutil.WriteUint32(w, h.SampleRate); err != nil {
return errors.Wrap(err, "write sample rate")
}
byteRate := h.SampleRate * uint32(h.NumChannels) * uint32(h.BitDepth) / 8
if err := binaryutil.WriteUint32(w, byteRate); err != nil {
return errors.Wrap(err, "write byte rate")
}
blockAlign := h.NumChannels * h.BitDepth / 8
if err := binaryutil.WriteUint16(w, blockAlign); err != nil {
return errors.Wrap(err, "write block align")
}
if err := binaryutil.WriteUint16(w, h.BitDepth); err != nil {
return errors.Wrap(err, "write bit depth")
}
return nil
}
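Usage sketch for the new wav package (module path as shown above): write the 44-byte header followed by the raw samples to get a playable file. The 24 kHz mono format and the silent one-second PCM payload are arbitrary placeholders.

package main

import (
    "os"

    "github.com/glidea/zenfeed/pkg/util/wav"
)

func main() {
    pcm := make([]byte, 24000*2) // 1 second of silent 16-bit mono PCM at 24 kHz (placeholder)

    f, err := os.Create("out.wav")
    if err != nil {
        panic(err)
    }
    defer f.Close()

    h := &wav.Header{SampleRate: 24000, BitDepth: 16, NumChannels: 1}
    if err := wav.WriteHeader(f, h, uint32(len(pcm))); err != nil {
        panic(err)
    }
    if _, err := f.Write(pcm); err != nil {
        panic(err)
    }
}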

161
pkg/util/wav/wav_test.go Normal file
View File

@@ -0,0 +1,161 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package wav
import (
"bytes"
"testing"
. "github.com/onsi/gomega"
"github.com/glidea/zenfeed/pkg/test"
)
func TestWriteHeader(t *testing.T) {
RegisterTestingT(t)
type givenDetail struct{}
type whenDetail struct {
header *Header
pcmDataSize uint32
}
type thenExpected struct {
expectedBytes []byte
expectError bool
}
tests := []test.Case[givenDetail, whenDetail, thenExpected]{
{
Scenario: "Standard CD quality audio",
Given: "a header for CD quality audio and a non-zero data size",
When: "writing the header",
Then: "should produce a valid 44-byte WAV header and no error",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 44100,
BitDepth: 16,
NumChannels: 2,
},
pcmDataSize: 176400,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x34, 0xB1, 0x02, 0x00, // ChunkSize = 36 + 176400 = 176436
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1 (PCM)
0x02, 0x00, // NumChannels = 2
0x44, 0xAC, 0x00, 0x00, // SampleRate = 44100
0x10, 0xB1, 0x02, 0x00, // ByteRate = 176400
0x04, 0x00, // BlockAlign = 4
0x10, 0x00, // BitsPerSample = 16
'd', 'a', 't', 'a',
0x10, 0xB1, 0x02, 0x00, // Subchunk2Size = 176400
},
expectError: false,
},
},
{
Scenario: "Mono audio for speech",
Given: "a header for mono speech audio and a non-zero data size",
When: "writing the header",
Then: "should produce a valid 44-byte WAV header and no error",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 16000,
BitDepth: 16,
NumChannels: 1,
},
pcmDataSize: 32000,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x24, 0x7D, 0x00, 0x00, // ChunkSize = 36 + 32000 = 32036
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1
0x01, 0x00, // NumChannels = 1
0x80, 0x3E, 0x00, 0x00, // SampleRate = 16000
0x00, 0x7D, 0x00, 0x00, // ByteRate = 32000
0x02, 0x00, // BlockAlign = 2
0x10, 0x00, // BitsPerSample = 16
'd', 'a', 't', 'a',
0x00, 0x7D, 0x00, 0x00, // Subchunk2Size = 32000
},
expectError: false,
},
},
{
Scenario: "8-bit mono audio with zero data size",
Given: "a header for 8-bit mono audio and a zero data size",
When: "writing the header for an empty file",
Then: "should produce a valid 44-byte WAV header with data size 0",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 8000,
BitDepth: 8,
NumChannels: 1,
},
pcmDataSize: 0,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x24, 0x00, 0x00, 0x00, // ChunkSize = 36 + 0 = 36
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1
0x01, 0x00, // NumChannels = 1
0x40, 0x1F, 0x00, 0x00, // SampleRate = 8000
0x40, 0x1F, 0x00, 0x00, // ByteRate = 8000
0x01, 0x00, // BlockAlign = 1
0x08, 0x00, // BitsPerSample = 8
'd', 'a', 't', 'a',
0x00, 0x00, 0x00, 0x00, // Subchunk2Size = 0
},
expectError: false,
},
},
}
for _, tt := range tests {
t.Run(tt.Scenario, func(t *testing.T) {
// Given.
var buf bytes.Buffer
// When.
err := WriteHeader(&buf, tt.WhenDetail.header, tt.WhenDetail.pcmDataSize)
// Then.
if tt.ThenExpected.expectError {
Expect(err).To(HaveOccurred())
} else {
Expect(err).NotTo(HaveOccurred())
Expect(buf.Bytes()).To(Equal(tt.ThenExpected.expectedBytes))
}
})
}
}