add podcast

glidea committed on 2025-07-08 18:13:26 +08:00
commit 2de0cf77fc (parent 263fcbbfaf)
21 changed files with 1545 additions and 145 deletions


@@ -90,6 +90,7 @@ type LLM struct {
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty" desc:"The API key of the LLM. It is required when api.llm is set."`
Model string `yaml:"model,omitempty" json:"model,omitempty" desc:"The model of the LLM. e.g. gpt-4o-mini. Can not be empty with embedding_model at same time when api.llm is set."`
EmbeddingModel string `yaml:"embedding_model,omitempty" json:"embedding_model,omitempty" desc:"The embedding model of the LLM. e.g. text-embedding-3-small. Can not be empty with model at same time when api.llm is set. NOTE: Once used, do not modify it directly, instead, add a new LLM configuration."`
TTSModel string `yaml:"tts_model,omitempty" json:"tts_model,omitempty" desc:"The TTS model of the LLM."`
Temperature float32 `yaml:"temperature,omitempty" json:"temperature,omitempty" desc:"The temperature (0-2) of the LLM. Default: 0.0"`
}
@@ -101,8 +102,9 @@ type Scrape struct {
}
type Storage struct {
Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."`
Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."`
Object ObjectStorage `yaml:"object,omitempty" json:"object,omitempty" desc:"The object storage config."`
}
type FeedStorage struct {
@@ -113,6 +115,14 @@ type FeedStorage struct {
BlockDuration timeutil.Duration `yaml:"block_duration,omitempty" json:"block_duration,omitempty" desc:"How long to keep the feed storage block. Block is time-based, like Prometheus TSDB Block. Default: 25h"`
}
type ObjectStorage struct {
Endpoint string `yaml:"endpoint,omitempty" json:"endpoint,omitempty" desc:"The endpoint of the object storage."`
AccessKeyID string `yaml:"access_key_id,omitempty" json:"access_key_id,omitempty" desc:"The access key id of the object storage."`
SecretAccessKey string `yaml:"secret_access_key,omitempty" json:"secret_access_key,omitempty" desc:"The secret access key of the object storage."`
Bucket string `yaml:"bucket,omitempty" json:"bucket,omitempty" desc:"The bucket of the object storage."`
BucketURL string `yaml:"bucket_url,omitempty" json:"bucket_url,omitempty" desc:"The public URL of the object storage bucket."`
}
type ScrapeSource struct {
Interval timeutil.Duration `yaml:"interval,omitempty" json:"interval,omitempty" desc:"How often to scrape this source. Default: global interval"`
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the source. It is required."`
@@ -137,7 +147,8 @@ type RewriteRule struct {
}
type RewriteRuleTransform struct {
ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."`
ToPodcast *RewriteRuleTransformToPodcast `yaml:"to_podcast,omitempty" json:"to_podcast,omitempty" desc:"The transform config to transform the source text to podcast."`
}
type RewriteRuleTransformToText struct {
@@ -146,6 +157,20 @@ type RewriteRuleTransformToText struct {
Prompt string `yaml:"prompt,omitempty" json:"prompt,omitempty" desc:"The prompt to transform the source text. The source text will be injected into the prompt above. And you can use go template syntax to refer some built-in prompts, like {{ .summary }}. Available built-in prompts: category, tags, score, comment_confucius, summary, summary_html_snippet."`
}
type RewriteRuleTransformToPodcast struct {
LLM string `yaml:"llm,omitempty" json:"llm,omitempty" desc:"The LLM name to use. Default is the default LLM in llms section."`
EstimateMaximumDuration timeutil.Duration `yaml:"estimate_maximum_duration,omitempty" json:"estimate_maximum_duration,omitempty" desc:"The estimated maximum duration of the podcast. It will affect the length of the generated transcript. e.g. 5m. Default is 5m."`
TranscriptAdditionalPrompt string `yaml:"transcript_additional_prompt,omitempty" json:"transcript_additional_prompt,omitempty" desc:"The additional prompt to add to the transcript. It is optional."`
TTSLLM string `yaml:"tts_llm,omitempty" json:"tts_llm,omitempty" desc:"The LLM name to use for TTS. Only supports gemini now. Default is the default LLM in llms section."`
Speakers []RewriteRuleTransformToPodcastSpeaker `yaml:"speakers,omitempty" json:"speakers,omitempty" desc:"The speakers to use. It is required, at least one speaker is needed."`
}
type RewriteRuleTransformToPodcastSpeaker struct {
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the speaker. It is required."`
Role string `yaml:"role,omitempty" json:"role,omitempty" desc:"The role description of the speaker. You can think of it as a character setting."`
Voice string `yaml:"voice,omitempty" json:"voice,omitempty" desc:"The voice of the speaker. It is required."`
}
type SchedulsRule struct {
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the rule. It is required."`
Query string `yaml:"query,omitempty" json:"query,omitempty" desc:"The semantic query to get the feeds. NOTE it is optional"`
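For orientation, the pieces above compose into the new podcast pipeline: an llms entry gains tts_model, storage gains an object section for uploading audio, and a rewrite rule gains a to_podcast transform. A minimal sketch using the config structs from this file (every value is an invented example, not taken from the commit; the voice names are assumed Gemini prebuilt voices):

package config

// Hypothetical example only: a transform that turns the source text into a
// two-speaker podcast. EstimateMaximumDuration is left at its default.
var examplePodcastTransform = RewriteRuleTransform{
    ToPodcast: &RewriteRuleTransformToPodcast{
        LLM:    "my-gemini",     // assumed llms entry used to write the transcript
        TTSLLM: "my-gemini-tts", // assumed llms entry whose tts_model is set
        Speakers: []RewriteRuleTransformToPodcastSpeaker{
            {Name: "Alice", Role: "curious host", Voice: "Kore"},
            {Name: "Bob", Role: "domain expert", Voice: "Puck"},
        },
    },
}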

pkg/llm/gemini.go (new file, 248 lines)

@@ -0,0 +1,248 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"path/filepath"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/wav"
)
type gemini struct {
*component.Base[Config, struct{}]
text
hc *http.Client
embeddingSpliter embeddingSpliter
}
func newGemini(c *Config) LLM {
config := oai.DefaultConfig(c.APIKey)
config.BaseURL = filepath.Join(c.Endpoint, "openai") // OpenAI compatible endpoint.
client := oai.NewClientWithConfig(config)
embeddingSpliter := newEmbeddingSpliter(1536, 64)
base := component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/gemini",
Instance: c.Name,
Config: c,
})
return &gemini{
Base: base,
text: &openaiText{
Base: base,
client: client,
},
hc: &http.Client{},
embeddingSpliter: embeddingSpliter,
}
}
func (g *gemini) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
ctx = telemetry.StartWith(ctx, append(g.TelemetryLabels(), telemetrymodel.KeyOperation, "WAV")...)
defer func() { telemetry.End(ctx, err) }()
if g.Config().TTSModel == "" {
return nil, errors.New("tts model is not set")
}
reqPayload, err := buildWAVRequestPayload(text, speakers)
if err != nil {
return nil, errors.Wrap(err, "build wav request payload")
}
pcmData, err := g.doWAVRequest(ctx, reqPayload)
if err != nil {
return nil, errors.Wrap(err, "do wav request")
}
return streamWAV(pcmData), nil
}
func (g *gemini) doWAVRequest(ctx context.Context, reqPayload *geminiRequest) ([]byte, error) {
config := g.Config()
body, err := json.Marshal(reqPayload)
if err != nil {
return nil, errors.Wrap(err, "marshal tts request")
}
url := config.Endpoint + "/models/" + config.TTSModel + ":generateContent"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "new tts request")
}
req.Header.Set("x-goog-api-key", config.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := g.hc.Do(req)
if err != nil {
return nil, errors.Wrap(err, "do tts request")
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
errMsg, _ := io.ReadAll(resp.Body)
return nil, errors.Errorf("tts request failed with status %d: %s", resp.StatusCode, string(errMsg))
}
var ttsResp geminiResponse
if err := json.NewDecoder(resp.Body).Decode(&ttsResp); err != nil {
return nil, errors.Wrap(err, "decode tts response")
}
if len(ttsResp.Candidates) == 0 || len(ttsResp.Candidates[0].Content.Parts) == 0 || ttsResp.Candidates[0].Content.Parts[0].InlineData == nil {
return nil, errors.New("no audio data in tts response")
}
audioDataB64 := ttsResp.Candidates[0].Content.Parts[0].InlineData.Data
pcmData, err := base64.StdEncoding.DecodeString(audioDataB64)
if err != nil {
return nil, errors.Wrap(err, "decode base64")
}
return pcmData, nil
}
func buildWAVRequestPayload(text string, speakers []Speaker) (*geminiRequest, error) {
reqPayload := geminiRequest{
Contents: []*geminiRequestContent{{Parts: []*geminiRequestPart{{Text: text}}}},
Config: &geminiRequestConfig{
ResponseModalities: []string{"AUDIO"},
SpeechConfig: &geminiRequestSpeechConfig{},
},
}
switch len(speakers) {
case 0:
return nil, errors.New("no speakers")
case 1:
reqPayload.Config.SpeechConfig.VoiceConfig = &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: speakers[0].Voice},
}
default:
multiSpeakerConfig := &geminiRequestMultiSpeakerVoiceConfig{}
for _, s := range speakers {
multiSpeakerConfig.SpeakerVoiceConfigs = append(multiSpeakerConfig.SpeakerVoiceConfigs, &geminiRequestSpeakerVoiceConfig{
Speaker: s.Name,
VoiceConfig: &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: s.Voice},
},
})
}
reqPayload.Config.SpeechConfig.MultiSpeakerVoiceConfig = multiSpeakerConfig
}
return &reqPayload, nil
}
func streamWAV(pcmData []byte) io.ReadCloser {
pipeReader, pipeWriter := io.Pipe()
go func() {
defer func() { _ = pipeWriter.Close() }()
if err := wav.WriteHeader(pipeWriter, geminiWavHeader, uint32(len(pcmData))); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write wav header"))
return
}
if _, err := io.Copy(pipeWriter, bytes.NewReader(pcmData)); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write pcm data"))
return
}
}()
return pipeReader
}
var geminiWavHeader = &wav.Header{
SampleRate: 24000,
BitDepth: 16,
NumChannels: 1,
}
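streamWAV above uses io.Pipe so the 44-byte header and the PCM payload reach the caller as one stream, without building a second copy of the audio in memory. A consumption sketch, not part of the commit (the file name and helper function are invented):

package example

import (
    "context"
    "io"
    "os"

    "github.com/glidea/zenfeed/pkg/llm"
)

// savePodcast drains the stream returned by WAV into a local file. g is
// assumed to be an llm.LLM built from a config whose tts_model is set.
func savePodcast(ctx context.Context, g llm.LLM, transcript string, speakers []llm.Speaker) error {
    r, err := g.WAV(ctx, transcript, speakers)
    if err != nil {
        return err
    }
    defer r.Close()

    f, err := os.Create("podcast.wav")
    if err != nil {
        return err
    }
    defer f.Close()

    _, err = io.Copy(f, r) // Header is written first, then the raw PCM bytes.
    return err
}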
type geminiRequest struct {
Contents []*geminiRequestContent `json:"contents"`
Config *geminiRequestConfig `json:"generationConfig"`
}
type geminiRequestContent struct {
Parts []*geminiRequestPart `json:"parts"`
}
type geminiRequestPart struct {
Text string `json:"text"`
}
type geminiRequestConfig struct {
ResponseModalities []string `json:"responseModalities"`
SpeechConfig *geminiRequestSpeechConfig `json:"speechConfig"`
}
type geminiRequestSpeechConfig struct {
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
MultiSpeakerVoiceConfig *geminiRequestMultiSpeakerVoiceConfig `json:"multiSpeakerVoiceConfig,omitempty"`
}
type geminiRequestVoiceConfig struct {
PrebuiltVoiceConfig *geminiRequestPrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"`
}
type geminiRequestPrebuiltVoiceConfig struct {
VoiceName string `json:"voiceName,omitempty"`
}
type geminiRequestMultiSpeakerVoiceConfig struct {
SpeakerVoiceConfigs []*geminiRequestSpeakerVoiceConfig `json:"speakerVoiceConfigs,omitempty"`
}
type geminiRequestSpeakerVoiceConfig struct {
Speaker string `json:"speaker,omitempty"`
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
}
type geminiResponse struct {
Candidates []*geminiResponseCandidate `json:"candidates"`
}
type geminiResponseCandidate struct {
Content *geminiResponseContent `json:"content"`
}
type geminiResponseContent struct {
Parts []*geminiResponsePart `json:"parts"`
}
type geminiResponsePart struct {
InlineData *geminiResponseInlineData `json:"inlineData"`
}
type geminiResponseInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // Base64 encoded.
}


@@ -18,6 +18,7 @@ package llm
import (
"bytes"
"context"
"io"
"reflect"
"strconv"
"sync"
@@ -42,19 +43,33 @@ import (
// --- Interface code block ---
type LLM interface {
component.Component
text
audio
}
type text interface {
String(ctx context.Context, messages []string) (string, error)
EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error)
Embedding(ctx context.Context, text string) ([]float32, error)
}
type audio interface {
WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error)
}
type Speaker struct {
Name string
Voice string
}
type Config struct {
Name string
Default bool
Provider ProviderType
Endpoint string
APIKey string
Model, EmbeddingModel, TTSModel string
Temperature float32
}
type ProviderType string
@@ -72,7 +87,7 @@ var defaultEndpoints = map[ProviderType]string{
ProviderTypeOpenAI: "https://api.openai.com/v1",
ProviderTypeOpenRouter: "https://openrouter.ai/api/v1",
ProviderTypeDeepSeek: "https://api.deepseek.com/v1",
ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta",
ProviderTypeVolc: "https://ark.cn-beijing.volces.com/api/v3",
ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1",
}
@@ -97,8 +112,8 @@ func (c *Config) Validate() error { //nolint:cyclop
if c.APIKey == "" {
return errors.New("api key is required")
}
if c.Model == "" && c.EmbeddingModel == "" && c.TTSModel == "" {
return errors.New("model or embedding model or tts model is required")
}
if c.Temperature < 0 || c.Temperature > 2 {
return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature)
@@ -182,6 +197,7 @@ func (c *FactoryConfig) From(app *config.App) {
APIKey: llm.APIKey,
Model: llm.Model,
EmbeddingModel: llm.EmbeddingModel,
TTSModel: llm.TTSModel,
Temperature: llm.Temperature,
})
}
@@ -207,12 +223,9 @@ func NewFactory(
) (Factory, error) {
if len(mockOn) > 0 {
mf := &mockFactory{}
m := &mockLLM{}
component.MockOptions(mockOn).Apply(&m.Mock)
mf.On("Get", mock.Anything).Return(m)
mf.On("Reload", mock.Anything).Return(nil)
return mf, nil
@@ -307,11 +320,6 @@ func (f *factory) Get(name string) LLM {
continue
}
if f.llms[name] == nil {
llm := f.new(&llmC)
f.llms[name] = llm
}
return f.llms[name]
}
@@ -320,8 +328,12 @@ func (f *factory) Get(name string) LLM {
func (f *factory) new(c *Config) LLM {
switch c.Provider {
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
case ProviderTypeGemini:
return newCached(newGemini(c), f.Dependencies().KVStorage)
default:
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
}
@@ -333,14 +345,17 @@ func (f *factory) initLLMs() {
llms = make(map[string]LLM, len(config.LLMs))
defaultLLM LLM
)
for _, llmC := range config.LLMs {
llm := f.new(&llmC)
llms[llmC.Name] = llm
if llmC.Name == config.defaultLLM {
defaultLLM = llm
}
}
f.llms = llms
f.defaultLLM = defaultLLM
}
@@ -482,12 +497,27 @@ func (m *mockLLM) String(ctx context.Context, messages []string) (string, error)
func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) {
args := m.Called(ctx, labels)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).([][]float32), args.Error(1)
}
func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) {
args := m.Called(ctx, text)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).([]float32), args.Error(1)
}
func (m *mockLLM) WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error) {
args := m.Called(ctx, text, speakers)
if args.Error(1) != nil {
return nil, args.Error(1)
}
return args.Get(0).(io.ReadCloser), args.Error(1)
}


@@ -18,6 +18,7 @@ package llm
import (
"context"
"encoding/json"
"io"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
@@ -31,9 +32,7 @@ import (
type openai struct {
*component.Base[Config, struct{}]
client *oai.Client
embeddingSpliter embeddingSpliter
text
}
func newOpenAI(c *Config) LLM {
@@ -42,18 +41,34 @@ func newOpenAI(c *Config) LLM {
client := oai.NewClientWithConfig(config)
embeddingSpliter := newEmbeddingSpliter(1536, 64)
base := component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/openai",
Instance: c.Name,
Config: c,
})
return &openai{
Base: base,
text: &openaiText{
Base: base,
client: client,
embeddingSpliter: embeddingSpliter,
},
}
}
func (o *openai) String(ctx context.Context, messages []string) (value string, err error) {
func (o *openai) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
return nil, errors.New("not supported")
}
type openaiText struct {
*component.Base[Config, struct{}]
client *oai.Client
embeddingSpliter embeddingSpliter
}
func (o *openaiText) String(ctx context.Context, messages []string) (value string, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...)
defer func() { telemetry.End(ctx, err) }()
@@ -91,7 +106,7 @@ func (o *openai) String(ctx context.Context, messages []string) (value string, e
return resp.Choices[0].Message.Content, nil
}
func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
func (o *openaiText) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...)
defer func() { telemetry.End(ctx, err) }()
@@ -117,7 +132,7 @@ func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (valu
return vecs, nil
}
func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) {
func (o *openaiText) Embedding(ctx context.Context, s string) (value []float32, err error) {
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...)
defer func() { telemetry.End(ctx, err) }()


@@ -17,9 +17,13 @@ package rewrite
import (
"context"
"fmt"
"html/template"
"io"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
@@ -30,10 +34,12 @@ import (
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/llm/prompt"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/buffer"
"github.com/glidea/zenfeed/pkg/util/crawl"
hashutil "github.com/glidea/zenfeed/pkg/util/hash"
)
// --- Interface code block ---
@@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) {
}
type Dependencies struct {
LLMFactory llm.Factory // NOTE: String() with cache.
ObjectStorage object.Storage
}
type Rule struct {
@@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
}
// Transform.
if r.Transform != nil { //nolint:nestif
if r.Transform.ToText != nil && r.Transform.ToPodcast != nil {
return errors.New("to_text and to_podcast can not be set at same time")
}
if r.Transform.ToText == nil && r.Transform.ToPodcast == nil {
return errors.New("either to_text or to_podcast must be set when transform is set")
}
if r.Transform.ToText != nil {
switch r.Transform.ToText.Type {
case ToTextTypePrompt:
if r.Transform.ToText.Prompt == "" {
return errors.New("to text prompt is required for prompt type")
}
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
if err != nil {
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
}
buf := buffer.Get()
defer buffer.Put(buf)
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
}
r.Transform.ToText.promptRendered = buf.String()
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
// No specific validation for crawl type here, as the source text itself is the URL.
default:
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
}
}
if r.Transform.ToPodcast != nil {
if len(r.Transform.ToPodcast.Speakers) == 0 {
return errors.New("at least one speaker is required for to_podcast")
}
r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers))
var speakerDescs []string
var speakerNames []string
for i, s := range r.Transform.ToPodcast.Speakers {
if s.Name == "" {
return errors.New("speaker name is required")
}
if s.Voice == "" {
return errors.New("speaker voice is required")
}
r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice}
desc := s.Name
if s.Role != "" {
desc += " (" + s.Role + ")"
}
speakerDescs = append(speakerDescs, desc)
speakerNames = append(speakerNames, s.Name)
}
speakersDesc := "- " + strings.Join(speakerDescs, "\n- ")
exampleSpeaker1 := speakerNames[0]
exampleSpeaker2 := exampleSpeaker1
if len(speakerNames) > 1 {
exampleSpeaker2 = speakerNames[1]
}
promptSegments := []string{
"Please convert the following article into a podcast dialogue script.",
"The speakers are:\n" + speakersDesc,
}
if r.Transform.ToPodcast.EstimateMaximumDuration > 0 {
wordsPerMinute := 200
totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes())
estimatedWords := totalMinutes * wordsPerMinute
promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes))
}
if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" {
promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt)
}
promptSegments = append(promptSegments,
"The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.",
"Do NOT include any other text, explanations, or formatting before or after the script.",
"Do NOT use background music in the script.",
"Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').",
fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2),
"Now, convert the article.",
)
r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n")
r.Transform.ToPodcast.speakersDesc = speakersDesc
}
}
@@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
r.matchRE = re
// Action.
if r.Action == "" {
r.Action = ActionCreateOrUpdateLabel
}
switch r.Action {
case ActionCreateOrUpdateLabel:
if r.Label == "" {
return errors.New("label is required for create or update label action")
@@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) {
r.If = c.If
r.SourceLabel = c.SourceLabel
r.SkipTooShortThreshold = c.SkipTooShortThreshold
if c.Transform != nil { //nolint:nestif
t := &Transform{}
if c.Transform.ToText != nil {
toText := &ToText{
@@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) {
}
t.ToText = toText
}
if c.Transform.ToPodcast != nil {
toPodcast := &ToPodcast{
LLM: c.Transform.ToPodcast.LLM,
EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration),
TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt,
TTSLLM: c.Transform.ToPodcast.TTSLLM,
}
if toPodcast.EstimateMaximumDuration == 0 {
toPodcast.EstimateMaximumDuration = 3 * time.Minute
}
for _, s := range c.Transform.ToPodcast.Speakers {
toPodcast.Speakers = append(toPodcast.Speakers, Speaker{
Name: s.Name,
Role: s.Role,
Voice: s.Voice,
})
}
t.ToPodcast = toPodcast
}
r.Transform = t
}
r.Match = c.Match
@@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) {
}
type Transform struct {
ToText *ToText
ToPodcast *ToPodcast
}
type ToText struct {
@@ -220,6 +314,24 @@ type ToText struct {
promptRendered string
}
type ToPodcast struct {
LLM string
EstimateMaximumDuration time.Duration
TranscriptAdditionalPrompt string
TTSLLM string
Speakers []Speaker
transcriptPrompt string
speakersDesc string
speakers []llm.Speaker
}
type Speaker struct {
Name string
Role string
Voice string
}
type ToTextType string
const (
@@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
}
// Transform text if configured.
text, err := r.transform(ctx, rule.Transform, sourceText)
if err != nil {
return nil, errors.Wrap(err, "transform")
}
// Check if text matches the rule.
@@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
return labels, nil
}
func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) {
if transform == nil {
return sourceText, nil
}
if transform.ToText != nil {
return r.transformText(ctx, transform.ToText, sourceText)
}
if transform.ToPodcast != nil {
return r.transformPodcast(ctx, transform.ToPodcast, sourceText)
}
return sourceText, nil
}
// transformText transforms text using configured LLM or by crawling a URL.
func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) {
switch toText.Type {
case ToTextTypeCrawl:
return r.transformTextCrawl(ctx, r.crawler, text)
case ToTextTypeCrawlByJina:
return r.transformTextCrawl(ctx, r.jinaCrawler, text)
case ToTextTypePrompt:
return r.transformTextPrompt(ctx, toText, text)
default:
return r.transformTextPrompt(ctx, toText, text)
}
}
@@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler
}
// transformTextPrompt transforms text using configured LLM.
func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) {
// Get LLM instance.
llm := r.Dependencies().LLMFactory.Get(toText.LLM)
// Call completion.
result, err := llm.String(ctx, []string{
toText.promptRendered,
text, // TODO: may place to first line to hit the model cache in different rewrite rules.
})
if err != nil {
@@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string {
return text
}
var audioKey = func(transcript, ext string) string {
hash := hashutil.Sum64(transcript)
file := strconv.FormatUint(hash, 10) + "." + ext
return "podcasts/" + file
}
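Because the key is a hash of the transcript, an unchanged transcript always maps to the same object, which is what lets the Get-before-Put check in transformPodcast below act as a cache across retries and reloads. A tiny in-package sketch (the hash value in the comment is illustrative):

// Same transcript, same key; a different transcript yields a different key.
func exampleAudioKeyIsDeterministic() bool {
    k1 := audioKey("Alice: Hi.\nBob: Hello.", "wav") // e.g. "podcasts/1234567890.wav"
    k2 := audioKey("Alice: Hi.\nBob: Hello.", "wav")
    return k1 == k2 // always true for identical transcripts
}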
func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) {
transcript, err := r.generateTranscript(ctx, toPodcast, sourceText)
if err != nil {
return "", errors.Wrap(err, "generate podcast transcript")
}
audioKey := audioKey(transcript, "wav")
url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey)
switch {
case err == nil:
// The audio may already exist from a previous run that was canceled (e.g. by a reload); fast return.
return url, nil
case errors.Is(err, object.ErrNotFound):
// Not found, generate new audio.
default:
return "", errors.Wrap(err, "get audio")
}
audioStream, err := r.generateAudio(ctx, toPodcast, transcript)
if err != nil {
return "", errors.Wrap(err, "generate podcast audio")
}
defer func() {
if closeErr := audioStream.Close(); closeErr != nil {
err = errors.Wrap(err, "close audio stream")
}
}()
url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav")
if err != nil {
return "", errors.Wrap(err, "store podcast audio")
}
return url, nil
}
func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) {
transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM).
String(ctx, []string{toPodcast.transcriptPrompt, sourceText})
if err != nil {
return "", errors.Wrap(err, "llm completion")
}
return toPodcast.speakersDesc +
"\n\nFollowed by the actual dialogue script:\n" +
transcript, nil
}
func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) {
audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM).
WAV(ctx, transcript, toPodcast.speakers)
if err != nil {
return nil, errors.Wrap(err, "calling tts llm")
}
return audioStream, nil
}
type mockRewriter struct {
component.Mock
}


@@ -2,6 +2,8 @@ package rewrite
import (
"context"
"io"
"strings"
"testing"
. "github.com/onsi/gomega"
@@ -12,6 +14,7 @@ import (
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/test"
)
@@ -19,8 +22,9 @@ func TestLabels(t *testing.T) {
RegisterTestingT(t)
type givenDetail struct {
config *Config
llmMock func(m *mock.Mock)
objectStorageMock func(m *mock.Mock)
}
type whenDetail struct {
inputLabels model.Labels
@@ -173,7 +177,7 @@ func TestLabels(t *testing.T) {
},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform text: llm completion: LLM failed"),
err: errors.New("transform: llm completion: LLM failed"),
isErr: true,
},
},
@@ -220,22 +224,163 @@ func TestLabels(t *testing.T) {
isErr: false,
},
},
{
Scenario: "Successfully generate podcast from content",
Given: "a rule to convert content to a podcast with all dependencies mocked to succeed",
When: "processing labels with content to be converted to a podcast",
Then: "should return labels with a new podcast_url label",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{
LLM: "mock-llm-transcript",
TTSLLM: "mock-llm-tts",
Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}},
},
},
Action: ActionCreateOrUpdateLabel,
Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")).
Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav").
Return("http://storage.example.com/podcast.wav", nil).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{
inputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
},
},
ThenExpected: thenExpected{
outputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
{Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"},
},
isErr: false,
},
},
{
Scenario: "Fail podcast generation due to transcription LLM error",
Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to transcription failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to TTS LLM error",
Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to TTS failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to object storage error",
Given: "a rule to convert content to a podcast, but object storage is mocked to fail",
When: "processing labels",
Then: "should return an error related to storage failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: store podcast audio: storage failed"),
isErr: true,
},
},
}
for _, tt := range tests {
t.Run(tt.Scenario, func(t *testing.T) {
// Given.
var mockLLMInstance *mock.Mock
llmMockOption := component.MockOption(func(m *mock.Mock) {
mockLLMInstance = m
if tt.GivenDetail.llmMock != nil {
tt.GivenDetail.llmMock(m)
}
})
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption)
Expect(err).NotTo(HaveOccurred())
var mockObjectStorage object.Storage
var mockObjectStorageInstance *mock.Mock
objectStorageMockOption := component.MockOption(func(m *mock.Mock) {
mockObjectStorageInstance = m
if tt.GivenDetail.objectStorageMock != nil {
tt.GivenDetail.objectStorageMock(m)
}
})
mockObjectStorageFactory := object.NewFactory(objectStorageMockOption)
mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{})
Expect(err).NotTo(HaveOccurred())
// Manually validate config to compile regex and render templates.
@@ -252,7 +397,8 @@ func TestLabels(t *testing.T) {
Instance: "test",
Config: tt.GivenDetail.config,
Dependencies: Dependencies{
LLMFactory: mockLLMFactory, // Pass the mock factory
ObjectStorage: mockObjectStorage,
},
}),
}
@@ -280,10 +426,12 @@ func TestLabels(t *testing.T) {
Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels))
}
// Verify mock calls if stubs were provided.
if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil {
mockLLMInstance.AssertExpectations(t)
}
if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil {
mockObjectStorageInstance.AssertExpectations(t)
}
})
}


@@ -216,6 +216,10 @@ func (m *manager) reload(config *Config) (err error) {
func (m *manager) runOrRestartScrapers(config *Config, newScrapers map[string]scraper.Scraper) error {
for i := range config.Scrapers {
c := &config.Scrapers[i]
if err := c.Validate(); err != nil {
return errors.Wrapf(err, "validate scraper %s", c.Name)
}
if err := m.runOrRestartScraper(c, newScrapers); err != nil {
return errors.Wrapf(err, "run or restart scraper %s", c.Name)
}


@@ -69,6 +69,11 @@ func (c *Config) Validate() error {
if c.Name == "" {
return errors.New("name cannot be empty")
}
if c.RSS != nil {
if err := c.RSS.Validate(); err != nil {
return errors.Wrap(err, "invalid RSS config")
}
}
return nil
}


@@ -244,7 +244,7 @@ func TestNew(t *testing.T) {
WhenDetail: whenDetail{},
ThenExpected: thenExpected{
isErr: true,
wantErrMsg: "creating source: invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
wantErrMsg: "invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
},
},
{
@@ -264,7 +264,7 @@ func TestNew(t *testing.T) {
WhenDetail: whenDetail{},
ThenExpected: thenExpected{
isErr: true,
wantErrMsg: "creating source: source not supported", // Error from newReader
wantErrMsg: "source not supported", // Error from newReader
},
},
}


@@ -0,0 +1,235 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package object
import (
"context"
"io"
"net/url"
"reflect"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/pkg/errors"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/config"
"github.com/glidea/zenfeed/pkg/telemetry"
"github.com/glidea/zenfeed/pkg/telemetry/log"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
)
// --- Interface code block ---
type Storage interface {
component.Component
config.Watcher
Put(ctx context.Context, key string, body io.Reader, contentType string) (url string, err error)
Get(ctx context.Context, key string) (url string, err error)
}
var ErrNotFound = errors.New("not found")
type Config struct {
Endpoint string
AccessKeyID string
SecretAccessKey string
Bucket string
BucketURL string
}
func (c *Config) Validate() error {
if c.Endpoint == "" {
return errors.New("endpoint is required")
}
if c.AccessKeyID == "" {
return errors.New("access key id is required")
}
if c.SecretAccessKey == "" {
return errors.New("secret access key is required")
}
if c.Bucket == "" {
return errors.New("bucket is required")
}
if c.BucketURL == "" {
return errors.New("bucket url is required")
}
return nil
}
func (c *Config) From(app *config.App) *Config {
*c = Config{
Endpoint: app.Storage.Object.Endpoint,
AccessKeyID: app.Storage.Object.AccessKeyID,
SecretAccessKey: app.Storage.Object.SecretAccessKey,
Bucket: app.Storage.Object.Bucket,
BucketURL: app.Storage.Object.BucketURL,
}
return c
}
type Dependencies struct{}
// --- Factory code block ---
type Factory component.Factory[Storage, config.App, Dependencies]
func NewFactory(mockOn ...component.MockOption) Factory {
if len(mockOn) > 0 {
return component.FactoryFunc[Storage, config.App, Dependencies](
func(instance string, config *config.App, dependencies Dependencies) (Storage, error) {
m := &mockStorage{}
component.MockOptions(mockOn).Apply(&m.Mock)
return m, nil
},
)
}
return component.FactoryFunc[Storage, config.App, Dependencies](new)
}
func new(instance string, app *config.App, dependencies Dependencies) (Storage, error) {
config := &Config{}
config.From(app)
if err := config.Validate(); err != nil {
return nil, errors.Wrap(err, "validate config")
}
client, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKeyID, config.SecretAccessKey, ""),
Secure: true,
})
if err != nil {
return nil, errors.Wrap(err, "new minio client")
}
u, err := url.Parse(config.BucketURL)
if err != nil {
return nil, errors.Wrap(err, "parse public url")
}
return &s3{
Base: component.New(&component.BaseConfig[Config, Dependencies]{
Name: "ObjectStorage",
Instance: instance,
Config: config,
Dependencies: dependencies,
}),
client: client,
bucketURL: u,
}, nil
}
// --- Implementation code block ---
type s3 struct {
*component.Base[Config, Dependencies]
client *minio.Client
bucketURL *url.URL
}
func (s *s3) Put(ctx context.Context, key string, body io.Reader, contentType string) (publicURL string, err error) {
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Put")...)
defer func() { telemetry.End(ctx, err) }()
bucket := s.Config().Bucket
if _, err := s.client.PutObject(ctx, bucket, key, body, -1, minio.PutObjectOptions{
ContentType: contentType,
}); err != nil {
return "", errors.Wrap(err, "put object")
}
return s.bucketURL.JoinPath(key).String(), nil
}
func (s *s3) Get(ctx context.Context, key string) (publicURL string, err error) {
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Get")...)
defer func() { telemetry.End(ctx, err) }()
bucket := s.Config().Bucket
if _, err := s.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}); err != nil {
errResponse := minio.ToErrorResponse(err)
if errResponse.Code == minio.NoSuchKey {
return "", ErrNotFound
}
return "", errors.Wrap(err, "stat object")
}
return s.bucketURL.JoinPath(key).String(), nil
}
func (s *s3) Reload(app *config.App) (err error) {
ctx := telemetry.StartWith(s.Context(), append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Reload")...)
defer func() { telemetry.End(ctx, err) }()
newConfig := &Config{}
newConfig.From(app)
if err := newConfig.Validate(); err != nil {
return errors.Wrap(err, "validate config")
}
if reflect.DeepEqual(s.Config(), newConfig) {
log.Debug(ctx, "object storage config not changed")
return nil
}
client, err := minio.New(newConfig.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(newConfig.AccessKeyID, newConfig.SecretAccessKey, ""),
Secure: true,
})
if err != nil {
return errors.Wrap(err, "new minio client")
}
u, err := url.Parse(newConfig.BucketURL)
if err != nil {
return errors.Wrap(err, "parse public url")
}
s.client = client
s.bucketURL = u
s.SetConfig(newConfig)
log.Info(ctx, "object storage reloaded")
return nil
}
// --- Mock code block ---
type mockStorage struct {
component.Mock
}
func (m *mockStorage) Put(ctx context.Context, key string, body io.Reader, contentType string) (string, error) {
args := m.Called(ctx, key, body, contentType)
return args.String(0), args.Error(1)
}
func (m *mockStorage) Get(ctx context.Context, key string) (string, error) {
args := m.Called(ctx, key)
return args.String(0), args.Error(1)
}
func (m *mockStorage) Reload(app *config.App) error {
args := m.Called(app)
return args.Error(0)
}
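Outside of tests, callers go through the same factory: build a Storage from the app config, Put a stream, and get back the public bucket URL. An illustrative sketch (instance name, key, and content are assumptions; the app config is expected to have storage.object filled in):

package example

import (
    "context"
    "strings"

    "github.com/glidea/zenfeed/pkg/config"
    "github.com/glidea/zenfeed/pkg/storage/object"
)

// uploadExample stores a small text object and returns its public URL,
// which is bucket_url joined with the key.
func uploadExample(ctx context.Context, appConfig *config.App) (string, error) {
    store, err := object.NewFactory().New("main", appConfig, object.Dependencies{})
    if err != nil {
        return "", err
    }
    return store.Put(ctx, "examples/hello.txt", strings.NewReader("hello"), "text/plain")
}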


@@ -122,6 +122,32 @@ func ReadUint32(r io.Reader) (uint32, error) {
return binary.LittleEndian.Uint32(b), nil
}
// WriteUint16 writes a uint16 using a pooled buffer.
func WriteUint16(w io.Writer, v uint16) error {
bp := smallBufPool.Get().(*[]byte)
defer smallBufPool.Put(bp)
b := *bp
binary.LittleEndian.PutUint16(b, v)
_, err := w.Write(b[:2])
return err
}
// ReadUint16 reads a uint16 using a pooled buffer.
func ReadUint16(r io.Reader) (uint16, error) {
bp := smallBufPool.Get().(*[]byte)
defer smallBufPool.Put(bp)
b := (*bp)[:2]
// Read exactly 2 bytes into the slice.
if _, err := io.ReadFull(r, b); err != nil {
return 0, errors.Wrap(err, "read uint16")
}
return binary.LittleEndian.Uint16(b), nil
}
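A quick round-trip sketch for the new helpers (not from the commit; it assumes the package is named binary, as the binaryutil alias used elsewhere suggests):

package binary

import "bytes"

// exampleUint16RoundTrip writes a value and reads it back from the same buffer.
func exampleUint16RoundTrip() (uint16, error) {
    var buf bytes.Buffer
    if err := WriteUint16(&buf, 0x1234); err != nil {
        return 0, err
    }
    return ReadUint16(&buf) // returns 0x1234, nil
}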
// WriteFloat32 writes a float32 using a pooled buffer.
func WriteFloat32(w io.Writer, v float32) error {
return WriteUint32(w, math.Float32bits(v))

pkg/util/wav/wav.go (new file, 100 lines)

@@ -0,0 +1,100 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package wav
import (
"io"
"github.com/pkg/errors"
binaryutil "github.com/glidea/zenfeed/pkg/util/binary"
)
// Header contains the WAV header information.
type Header struct {
SampleRate uint32
BitDepth uint16
NumChannels uint16
}
// WriteHeader writes the WAV header to a writer.
// pcmDataSize is the size of the raw PCM data.
func WriteHeader(w io.Writer, h *Header, pcmDataSize uint32) error {
// RIFF Header.
if err := writeRIFFHeader(w, pcmDataSize); err != nil {
return errors.Wrap(err, "write RIFF header")
}
// fmt chunk.
if err := writeFMTChunk(w, h); err != nil {
return errors.Wrap(err, "write fmt chunk")
}
// data chunk.
if _, err := w.Write([]byte("data")); err != nil {
return errors.Wrap(err, "write data chunk marker")
}
if err := binaryutil.WriteUint32(w, pcmDataSize); err != nil {
return errors.Wrap(err, "write pcm data size")
}
return nil
}
func writeRIFFHeader(w io.Writer, pcmDataSize uint32) error {
if _, err := w.Write([]byte("RIFF")); err != nil {
return errors.Wrap(err, "write RIFF")
}
if err := binaryutil.WriteUint32(w, uint32(36+pcmDataSize)); err != nil {
return errors.Wrap(err, "write file size")
}
if _, err := w.Write([]byte("WAVE")); err != nil {
return errors.Wrap(err, "write WAVE")
}
return nil
}
func writeFMTChunk(w io.Writer, h *Header) error {
if _, err := w.Write([]byte("fmt ")); err != nil {
return errors.Wrap(err, "write fmt")
}
if err := binaryutil.WriteUint32(w, uint32(16)); err != nil { // PCM chunk size.
return errors.Wrap(err, "write pcm chunk size")
}
if err := binaryutil.WriteUint16(w, uint16(1)); err != nil { // PCM format.
return errors.Wrap(err, "write pcm format")
}
if err := binaryutil.WriteUint16(w, h.NumChannels); err != nil {
return errors.Wrap(err, "write num channels")
}
if err := binaryutil.WriteUint32(w, h.SampleRate); err != nil {
return errors.Wrap(err, "write sample rate")
}
byteRate := h.SampleRate * uint32(h.NumChannels) * uint32(h.BitDepth) / 8
if err := binaryutil.WriteUint32(w, byteRate); err != nil {
return errors.Wrap(err, "write byte rate")
}
blockAlign := h.NumChannels * h.BitDepth / 8
if err := binaryutil.WriteUint16(w, blockAlign); err != nil {
return errors.Wrap(err, "write block align")
}
if err := binaryutil.WriteUint16(w, h.BitDepth); err != nil {
return errors.Wrap(err, "write bit depth")
}
return nil
}
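A short usage sketch, not part of the commit: pairing WriteHeader with raw PCM yields a playable file, here in the same 24 kHz, 16-bit, mono format the gemini TTS client assumes (the function name is invented):

package example

import (
    "bytes"

    "github.com/glidea/zenfeed/pkg/util/wav"
)

// buildWAV prepends a header to raw PCM samples and returns the full file bytes.
func buildWAV(pcm []byte) ([]byte, error) {
    var buf bytes.Buffer
    h := &wav.Header{SampleRate: 24000, BitDepth: 16, NumChannels: 1}
    if err := wav.WriteHeader(&buf, h, uint32(len(pcm))); err != nil {
        return nil, err
    }
    buf.Write(pcm) // Raw samples follow the 44-byte header directly.
    return buf.Bytes(), nil
}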

pkg/util/wav/wav_test.go (new file, 161 lines)

@@ -0,0 +1,161 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package wav
import (
"bytes"
"testing"
. "github.com/onsi/gomega"
"github.com/glidea/zenfeed/pkg/test"
)
func TestWriteHeader(t *testing.T) {
RegisterTestingT(t)
type givenDetail struct{}
type whenDetail struct {
header *Header
pcmDataSize uint32
}
type thenExpected struct {
expectedBytes []byte
expectError bool
}
tests := []test.Case[givenDetail, whenDetail, thenExpected]{
{
Scenario: "Standard CD quality audio",
Given: "a header for CD quality audio and a non-zero data size",
When: "writing the header",
Then: "should produce a valid 44-byte WAV header and no error",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 44100,
BitDepth: 16,
NumChannels: 2,
},
pcmDataSize: 176400,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x34, 0xB1, 0x02, 0x00, // ChunkSize = 36 + 176400 = 176436
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1 (PCM)
0x02, 0x00, // NumChannels = 2
0x44, 0xAC, 0x00, 0x00, // SampleRate = 44100
0x10, 0xB1, 0x02, 0x00, // ByteRate = 176400
0x04, 0x00, // BlockAlign = 4
0x10, 0x00, // BitsPerSample = 16
'd', 'a', 't', 'a',
0x10, 0xB1, 0x02, 0x00, // Subchunk2Size = 176400
},
expectError: false,
},
},
{
Scenario: "Mono audio for speech",
Given: "a header for mono speech audio and a non-zero data size",
When: "writing the header",
Then: "should produce a valid 44-byte WAV header and no error",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 16000,
BitDepth: 16,
NumChannels: 1,
},
pcmDataSize: 32000,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x24, 0x7D, 0x00, 0x00, // ChunkSize = 36 + 32000 = 32036
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1
0x01, 0x00, // NumChannels = 1
0x80, 0x3E, 0x00, 0x00, // SampleRate = 16000
0x00, 0x7D, 0x00, 0x00, // ByteRate = 32000
0x02, 0x00, // BlockAlign = 2
0x10, 0x00, // BitsPerSample = 16
'd', 'a', 't', 'a',
0x00, 0x7D, 0x00, 0x00, // Subchunk2Size = 32000
},
expectError: false,
},
},
{
Scenario: "8-bit mono audio with zero data size",
Given: "a header for 8-bit mono audio and a zero data size",
When: "writing the header for an empty file",
Then: "should produce a valid 44-byte WAV header with data size 0",
GivenDetail: givenDetail{},
WhenDetail: whenDetail{
header: &Header{
SampleRate: 8000,
BitDepth: 8,
NumChannels: 1,
},
pcmDataSize: 0,
},
ThenExpected: thenExpected{
expectedBytes: []byte{
'R', 'I', 'F', 'F',
0x24, 0x00, 0x00, 0x00, // ChunkSize = 36 + 0 = 36
'W', 'A', 'V', 'E',
'f', 'm', 't', ' ',
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
0x01, 0x00, // AudioFormat = 1
0x01, 0x00, // NumChannels = 1
0x40, 0x1F, 0x00, 0x00, // SampleRate = 8000
0x40, 0x1F, 0x00, 0x00, // ByteRate = 8000
0x01, 0x00, // BlockAlign = 1
0x08, 0x00, // BitsPerSample = 8
'd', 'a', 't', 'a',
0x00, 0x00, 0x00, 0x00, // Subchunk2Size = 0
},
expectError: false,
},
},
}
for _, tt := range tests {
t.Run(tt.Scenario, func(t *testing.T) {
// Given.
var buf bytes.Buffer
// When.
err := WriteHeader(&buf, tt.WhenDetail.header, tt.WhenDetail.pcmDataSize)
// Then.
if tt.ThenExpected.expectError {
Expect(err).To(HaveOccurred())
} else {
Expect(err).NotTo(HaveOccurred())
Expect(buf.Bytes()).To(Equal(tt.ThenExpected.expectedBytes))
}
})
}
}