add podcast
This commit is contained in:
@@ -90,6 +90,7 @@ type LLM struct {
|
||||
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty" desc:"The API key of the LLM. It is required when api.llm is set."`
|
||||
Model string `yaml:"model,omitempty" json:"model,omitempty" desc:"The model of the LLM. e.g. gpt-4o-mini. Can not be empty with embedding_model at same time when api.llm is set."`
|
||||
EmbeddingModel string `yaml:"embedding_model,omitempty" json:"embedding_model,omitempty" desc:"The embedding model of the LLM. e.g. text-embedding-3-small. Can not be empty with model at same time when api.llm is set. NOTE: Once used, do not modify it directly, instead, add a new LLM configuration."`
|
||||
TTSModel string `yaml:"tts_model,omitempty" json:"tts_model,omitempty" desc:"The TTS model of the LLM."`
|
||||
Temperature float32 `yaml:"temperature,omitempty" json:"temperature,omitempty" desc:"The temperature (0-2) of the LLM. Default: 0.0"`
|
||||
}
|
||||
|
||||
@@ -101,8 +102,9 @@ type Scrape struct {
|
||||
}
|
||||
|
||||
type Storage struct {
|
||||
Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."`
|
||||
Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."`
|
||||
Dir string `yaml:"dir,omitempty" json:"dir,omitempty" desc:"The base directory of the all storages. Default: ./data. It can not be changed after the app is running."`
|
||||
Feed FeedStorage `yaml:"feed,omitempty" json:"feed,omitempty" desc:"The feed storage config."`
|
||||
Object ObjectStorage `yaml:"object,omitempty" json:"object,omitempty" desc:"The object storage config."`
|
||||
}
|
||||
|
||||
type FeedStorage struct {
|
||||
@@ -113,6 +115,14 @@ type FeedStorage struct {
|
||||
BlockDuration timeutil.Duration `yaml:"block_duration,omitempty" json:"block_duration,omitempty" desc:"How long to keep the feed storage block. Block is time-based, like Prometheus TSDB Block. Default: 25h"`
|
||||
}
|
||||
|
||||
// ObjectStorage holds the connection settings of the object storage:
// where it lives, the credentials to use, and the target bucket.
type ObjectStorage struct {
	// Endpoint is the address of the object storage service.
	Endpoint string `yaml:"endpoint,omitempty" json:"endpoint,omitempty" desc:"The endpoint of the object storage."`

	// AccessKeyID and SecretAccessKey are the credential pair.
	AccessKeyID string `yaml:"access_key_id,omitempty" json:"access_key_id,omitempty" desc:"The access key id of the object storage."`

	SecretAccessKey string `yaml:"secret_access_key,omitempty" json:"secret_access_key,omitempty" desc:"The secret access key of the object storage."`

	// Bucket is the bucket name; BucketURL is its public URL.
	Bucket string `yaml:"bucket,omitempty" json:"bucket,omitempty" desc:"The bucket of the object storage."`

	BucketURL string `yaml:"bucket_url,omitempty" json:"bucket_url,omitempty" desc:"The public URL of the object storage bucket."`
}
|
||||
|
||||
type ScrapeSource struct {
|
||||
Interval timeutil.Duration `yaml:"interval,omitempty" json:"interval,omitempty" desc:"How often to scrape this source. Default: global interval"`
|
||||
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the source. It is required."`
|
||||
@@ -137,7 +147,8 @@ type RewriteRule struct {
|
||||
}
|
||||
|
||||
type RewriteRuleTransform struct {
|
||||
ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."`
|
||||
ToText *RewriteRuleTransformToText `yaml:"to_text,omitempty" json:"to_text,omitempty" desc:"The transform config to transform the source text to text."`
|
||||
ToPodcast *RewriteRuleTransformToPodcast `yaml:"to_podcast,omitempty" json:"to_podcast,omitempty" desc:"The transform config to transform the source text to podcast."`
|
||||
}
|
||||
|
||||
type RewriteRuleTransformToText struct {
|
||||
@@ -146,6 +157,20 @@ type RewriteRuleTransformToText struct {
|
||||
Prompt string `yaml:"prompt,omitempty" json:"prompt,omitempty" desc:"The prompt to transform the source text. The source text will be injected into the prompt above. And you can use go template syntax to refer some built-in prompts, like {{ .summary }}. Available built-in prompts: category, tags, score, comment_confucius, summary, summary_html_snippet."`
|
||||
}
|
||||
|
||||
type RewriteRuleTransformToPodcast struct {
|
||||
LLM string `yaml:"llm,omitempty" json:"llm,omitempty" desc:"The LLM name to use. Default is the default LLM in llms section."`
|
||||
EstimateMaximumDuration timeutil.Duration `yaml:"estimate_maximum_duration,omitempty" json:"estimate_maximum_duration,omitempty" desc:"The estimated maximum duration of the podcast. It will affect the length of the generated transcript. e.g. 5m. Default is 5m."`
|
||||
TranscriptAdditionalPrompt string `yaml:"transcript_additional_prompt,omitempty" json:"transcript_additional_prompt,omitempty" desc:"The additional prompt to add to the transcript. It is optional."`
|
||||
TTSLLM string `yaml:"tts_llm,omitempty" json:"tts_llm,omitempty" desc:"The LLM name to use for TTS. Only supports gemini now. Default is the default LLM in llms section."`
|
||||
Speakers []RewriteRuleTransformToPodcastSpeaker `yaml:"speakers,omitempty" json:"speakers,omitempty" desc:"The speakers to use. It is required, at least one speaker is needed."`
|
||||
}
|
||||
|
||||
// RewriteRuleTransformToPodcastSpeaker describes one podcast speaker.
type RewriteRuleTransformToPodcastSpeaker struct {
	// Name identifies the speaker. Required.
	Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the speaker. It is required."`

	// Role is an optional character description of the speaker.
	Role string `yaml:"role,omitempty" json:"role,omitempty" desc:"The role description of the speaker. You can think of it as a character setting."`

	// Voice is the voice used for the speaker. Required.
	Voice string `yaml:"voice,omitempty" json:"voice,omitempty" desc:"The voice of the speaker. It is required."`
}
|
||||
|
||||
type SchedulsRule struct {
|
||||
Name string `yaml:"name,omitempty" json:"name,omitempty" desc:"The name of the rule. It is required."`
|
||||
Query string `yaml:"query,omitempty" json:"query,omitempty" desc:"The semantic query to get the feeds. NOTE it is optional"`
|
||||
|
||||
248
pkg/llm/gemini.go
Normal file
248
pkg/llm/gemini.go
Normal file
@@ -0,0 +1,248 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package llm
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/pkg/errors"
	oai "github.com/sashabaranov/go-openai"

	"github.com/glidea/zenfeed/pkg/component"
	"github.com/glidea/zenfeed/pkg/telemetry"
	telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
	"github.com/glidea/zenfeed/pkg/util/wav"
)
|
||||
|
||||
type gemini struct {
|
||||
*component.Base[Config, struct{}]
|
||||
text
|
||||
hc *http.Client
|
||||
|
||||
embeddingSpliter embeddingSpliter
|
||||
}
|
||||
|
||||
func newGemini(c *Config) LLM {
|
||||
config := oai.DefaultConfig(c.APIKey)
|
||||
config.BaseURL = filepath.Join(c.Endpoint, "openai") // OpenAI compatible endpoint.
|
||||
client := oai.NewClientWithConfig(config)
|
||||
embeddingSpliter := newEmbeddingSpliter(1536, 64)
|
||||
|
||||
base := component.New(&component.BaseConfig[Config, struct{}]{
|
||||
Name: "LLM/gemini",
|
||||
Instance: c.Name,
|
||||
Config: c,
|
||||
})
|
||||
|
||||
return &gemini{
|
||||
Base: base,
|
||||
text: &openaiText{
|
||||
Base: base,
|
||||
client: client,
|
||||
},
|
||||
hc: &http.Client{},
|
||||
embeddingSpliter: embeddingSpliter,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gemini) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(g.TelemetryLabels(), telemetrymodel.KeyOperation, "WAV")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
|
||||
if g.Config().TTSModel == "" {
|
||||
return nil, errors.New("tts model is not set")
|
||||
}
|
||||
|
||||
reqPayload, err := buildWAVRequestPayload(text, speakers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "build wav request payload")
|
||||
}
|
||||
|
||||
pcmData, err := g.doWAVRequest(ctx, reqPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "do wav request")
|
||||
}
|
||||
|
||||
return streamWAV(pcmData), nil
|
||||
}
|
||||
|
||||
func (g *gemini) doWAVRequest(ctx context.Context, reqPayload *geminiRequest) ([]byte, error) {
|
||||
config := g.Config()
|
||||
body, err := json.Marshal(reqPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal tts request")
|
||||
}
|
||||
|
||||
url := config.Endpoint + "/models/" + config.TTSModel + ":generateContent"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "new tts request")
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", config.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := g.hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "do tts request")
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errMsg, _ := io.ReadAll(resp.Body)
|
||||
|
||||
return nil, errors.Errorf("tts request failed with status %d: %s", resp.StatusCode, string(errMsg))
|
||||
}
|
||||
|
||||
var ttsResp geminiResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ttsResp); err != nil {
|
||||
return nil, errors.Wrap(err, "decode tts response")
|
||||
}
|
||||
if len(ttsResp.Candidates) == 0 || len(ttsResp.Candidates[0].Content.Parts) == 0 || ttsResp.Candidates[0].Content.Parts[0].InlineData == nil {
|
||||
return nil, errors.New("no audio data in tts response")
|
||||
}
|
||||
|
||||
audioDataB64 := ttsResp.Candidates[0].Content.Parts[0].InlineData.Data
|
||||
pcmData, err := base64.StdEncoding.DecodeString(audioDataB64)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode base64")
|
||||
}
|
||||
|
||||
return pcmData, nil
|
||||
}
|
||||
|
||||
func buildWAVRequestPayload(text string, speakers []Speaker) (*geminiRequest, error) {
|
||||
reqPayload := geminiRequest{
|
||||
Contents: []*geminiRequestContent{{Parts: []*geminiRequestPart{{Text: text}}}},
|
||||
Config: &geminiRequestConfig{
|
||||
ResponseModalities: []string{"AUDIO"},
|
||||
SpeechConfig: &geminiRequestSpeechConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
switch len(speakers) {
|
||||
case 0:
|
||||
return nil, errors.New("no speakers")
|
||||
case 1:
|
||||
reqPayload.Config.SpeechConfig.VoiceConfig = &geminiRequestVoiceConfig{
|
||||
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: speakers[0].Voice},
|
||||
}
|
||||
default:
|
||||
multiSpeakerConfig := &geminiRequestMultiSpeakerVoiceConfig{}
|
||||
for _, s := range speakers {
|
||||
multiSpeakerConfig.SpeakerVoiceConfigs = append(multiSpeakerConfig.SpeakerVoiceConfigs, &geminiRequestSpeakerVoiceConfig{
|
||||
Speaker: s.Name,
|
||||
VoiceConfig: &geminiRequestVoiceConfig{
|
||||
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: s.Voice},
|
||||
},
|
||||
})
|
||||
}
|
||||
reqPayload.Config.SpeechConfig.MultiSpeakerVoiceConfig = multiSpeakerConfig
|
||||
}
|
||||
|
||||
return &reqPayload, nil
|
||||
}
|
||||
|
||||
func streamWAV(pcmData []byte) io.ReadCloser {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
go func() {
|
||||
defer func() { _ = pipeWriter.Close() }()
|
||||
if err := wav.WriteHeader(pipeWriter, geminiWavHeader, uint32(len(pcmData))); err != nil {
|
||||
pipeWriter.CloseWithError(errors.Wrap(err, "write wav header"))
|
||||
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(pipeWriter, bytes.NewReader(pcmData)); err != nil {
|
||||
pipeWriter.CloseWithError(errors.Wrap(err, "write pcm data"))
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return pipeReader
|
||||
}
|
||||
|
||||
var geminiWavHeader = &wav.Header{
|
||||
SampleRate: 24000,
|
||||
BitDepth: 16,
|
||||
NumChannels: 1,
|
||||
}
|
||||
|
||||
// JSON shapes of the audio generation request, from the outermost object down
// to the voice name.
type (
	// geminiRequest is the request body: the contents to speak plus the
	// generation config.
	geminiRequest struct {
		Contents []*geminiRequestContent `json:"contents"`
		Config   *geminiRequestConfig    `json:"generationConfig"`
	}

	// geminiRequestContent is one content entry, made of parts.
	geminiRequestContent struct {
		Parts []*geminiRequestPart `json:"parts"`
	}

	// geminiRequestPart is a single text part.
	geminiRequestPart struct {
		Text string `json:"text"`
	}

	// geminiRequestConfig selects the response modalities and the speech
	// settings.
	geminiRequestConfig struct {
		ResponseModalities []string                   `json:"responseModalities"`
		SpeechConfig       *geminiRequestSpeechConfig `json:"speechConfig"`
	}

	// geminiRequestSpeechConfig carries either a single voice or a
	// multi-speaker voice setup; unset members are omitted from the JSON.
	geminiRequestSpeechConfig struct {
		VoiceConfig             *geminiRequestVoiceConfig             `json:"voiceConfig,omitempty"`
		MultiSpeakerVoiceConfig *geminiRequestMultiSpeakerVoiceConfig `json:"multiSpeakerVoiceConfig,omitempty"`
	}

	// geminiRequestVoiceConfig wraps a prebuilt voice selection.
	geminiRequestVoiceConfig struct {
		PrebuiltVoiceConfig *geminiRequestPrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"`
	}

	// geminiRequestPrebuiltVoiceConfig names the prebuilt voice.
	geminiRequestPrebuiltVoiceConfig struct {
		VoiceName string `json:"voiceName,omitempty"`
	}

	// geminiRequestMultiSpeakerVoiceConfig lists one voice config per speaker.
	geminiRequestMultiSpeakerVoiceConfig struct {
		SpeakerVoiceConfigs []*geminiRequestSpeakerVoiceConfig `json:"speakerVoiceConfigs,omitempty"`
	}

	// geminiRequestSpeakerVoiceConfig binds a speaker name to a voice.
	geminiRequestSpeakerVoiceConfig struct {
		Speaker     string                    `json:"speaker,omitempty"`
		VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
	}
)
|
||||
|
||||
// JSON shapes of the audio generation response, from the outermost object
// down to the inline audio payload.
type (
	// geminiResponse is the response body.
	geminiResponse struct {
		Candidates []*geminiResponseCandidate `json:"candidates"`
	}

	// geminiResponseCandidate is one generated candidate.
	geminiResponseCandidate struct {
		Content *geminiResponseContent `json:"content"`
	}

	// geminiResponseContent is the content of a candidate, made of parts.
	geminiResponseContent struct {
		Parts []*geminiResponsePart `json:"parts"`
	}

	// geminiResponsePart is one part; its inline data may be absent.
	geminiResponsePart struct {
		InlineData *geminiResponseInlineData `json:"inlineData"`
	}

	// geminiResponseInlineData is an inline payload with its MIME type.
	geminiResponseInlineData struct {
		MimeType string `json:"mimeType"`
		Data     string `json:"data"` // Base64 encoded.
	}
)
|
||||
@@ -18,6 +18,7 @@ package llm
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -42,19 +43,33 @@ import (
|
||||
// --- Interface code block ---
|
||||
type LLM interface {
|
||||
component.Component
|
||||
text
|
||||
audio
|
||||
}
|
||||
|
||||
type text interface {
|
||||
String(ctx context.Context, messages []string) (string, error)
|
||||
EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error)
|
||||
Embedding(ctx context.Context, text string) ([]float32, error)
|
||||
}
|
||||
|
||||
type audio interface {
|
||||
WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
// Speaker pairs a speaker name with the voice used to render that speaker.
type Speaker struct {
	// Name is the speaker's name.
	Name string

	// Voice is the voice assigned to the speaker.
	Voice string
}
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Default bool
|
||||
Provider ProviderType
|
||||
Endpoint string
|
||||
APIKey string
|
||||
Model, EmbeddingModel string
|
||||
Temperature float32
|
||||
Name string
|
||||
Default bool
|
||||
Provider ProviderType
|
||||
Endpoint string
|
||||
APIKey string
|
||||
Model, EmbeddingModel, TTSModel string
|
||||
Temperature float32
|
||||
}
|
||||
|
||||
type ProviderType string
|
||||
@@ -72,7 +87,7 @@ var defaultEndpoints = map[ProviderType]string{
|
||||
ProviderTypeOpenAI: "https://api.openai.com/v1",
|
||||
ProviderTypeOpenRouter: "https://openrouter.ai/api/v1",
|
||||
ProviderTypeDeepSeek: "https://api.deepseek.com/v1",
|
||||
ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
ProviderTypeGemini: "https://generativelanguage.googleapis.com/v1beta",
|
||||
ProviderTypeVolc: "https://ark.cn-beijing.volces.com/api/v3",
|
||||
ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1",
|
||||
}
|
||||
@@ -97,8 +112,8 @@ func (c *Config) Validate() error { //nolint:cyclop
|
||||
if c.APIKey == "" {
|
||||
return errors.New("api key is required")
|
||||
}
|
||||
if c.Model == "" && c.EmbeddingModel == "" {
|
||||
return errors.New("model or embedding model is required")
|
||||
if c.Model == "" && c.EmbeddingModel == "" && c.TTSModel == "" {
|
||||
return errors.New("model or embedding model or tts model is required")
|
||||
}
|
||||
if c.Temperature < 0 || c.Temperature > 2 {
|
||||
return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature)
|
||||
@@ -182,6 +197,7 @@ func (c *FactoryConfig) From(app *config.App) {
|
||||
APIKey: llm.APIKey,
|
||||
Model: llm.Model,
|
||||
EmbeddingModel: llm.EmbeddingModel,
|
||||
TTSModel: llm.TTSModel,
|
||||
Temperature: llm.Temperature,
|
||||
})
|
||||
}
|
||||
@@ -207,12 +223,9 @@ func NewFactory(
|
||||
) (Factory, error) {
|
||||
if len(mockOn) > 0 {
|
||||
mf := &mockFactory{}
|
||||
getCall := mf.On("Get", mock.Anything)
|
||||
getCall.Run(func(args mock.Arguments) {
|
||||
m := &mockLLM{}
|
||||
component.MockOptions(mockOn).Apply(&m.Mock)
|
||||
getCall.Return(m, nil)
|
||||
})
|
||||
m := &mockLLM{}
|
||||
component.MockOptions(mockOn).Apply(&m.Mock)
|
||||
mf.On("Get", mock.Anything).Return(m)
|
||||
mf.On("Reload", mock.Anything).Return(nil)
|
||||
|
||||
return mf, nil
|
||||
@@ -307,11 +320,6 @@ func (f *factory) Get(name string) LLM {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.llms[name] == nil {
|
||||
llm := f.new(&llmC)
|
||||
f.llms[name] = llm
|
||||
}
|
||||
|
||||
return f.llms[name]
|
||||
}
|
||||
|
||||
@@ -320,8 +328,12 @@ func (f *factory) Get(name string) LLM {
|
||||
|
||||
func (f *factory) new(c *Config) LLM {
|
||||
switch c.Provider {
|
||||
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
|
||||
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
|
||||
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
|
||||
|
||||
case ProviderTypeGemini:
|
||||
return newCached(newGemini(c), f.Dependencies().KVStorage)
|
||||
|
||||
default:
|
||||
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
|
||||
}
|
||||
@@ -333,14 +345,17 @@ func (f *factory) initLLMs() {
|
||||
llms = make(map[string]LLM, len(config.LLMs))
|
||||
defaultLLM LLM
|
||||
)
|
||||
|
||||
for _, llmC := range config.LLMs {
|
||||
llm := f.new(&llmC)
|
||||
|
||||
llms[llmC.Name] = llm
|
||||
|
||||
if llmC.Name == config.defaultLLM {
|
||||
defaultLLM = llm
|
||||
}
|
||||
}
|
||||
|
||||
f.llms = llms
|
||||
f.defaultLLM = defaultLLM
|
||||
}
|
||||
@@ -482,12 +497,27 @@ func (m *mockLLM) String(ctx context.Context, messages []string) (string, error)
|
||||
|
||||
func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) {
|
||||
args := m.Called(ctx, labels)
|
||||
if args.Error(1) != nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).([][]float32), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) {
|
||||
args := m.Called(ctx, text)
|
||||
if args.Error(1) != nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).([]float32), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLLM) WAV(ctx context.Context, text string, speakers []Speaker) (io.ReadCloser, error) {
|
||||
args := m.Called(ctx, text, speakers)
|
||||
if args.Error(1) != nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(io.ReadCloser), args.Error(1)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ package llm
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
oai "github.com/sashabaranov/go-openai"
|
||||
@@ -31,9 +32,7 @@ import (
|
||||
|
||||
type openai struct {
|
||||
*component.Base[Config, struct{}]
|
||||
|
||||
client *oai.Client
|
||||
embeddingSpliter embeddingSpliter
|
||||
text
|
||||
}
|
||||
|
||||
func newOpenAI(c *Config) LLM {
|
||||
@@ -42,18 +41,34 @@ func newOpenAI(c *Config) LLM {
|
||||
client := oai.NewClientWithConfig(config)
|
||||
embeddingSpliter := newEmbeddingSpliter(1536, 64)
|
||||
|
||||
base := component.New(&component.BaseConfig[Config, struct{}]{
|
||||
Name: "LLM/openai",
|
||||
Instance: c.Name,
|
||||
Config: c,
|
||||
})
|
||||
|
||||
return &openai{
|
||||
Base: component.New(&component.BaseConfig[Config, struct{}]{
|
||||
Name: "LLM/openai",
|
||||
Instance: c.Name,
|
||||
Config: c,
|
||||
}),
|
||||
client: client,
|
||||
embeddingSpliter: embeddingSpliter,
|
||||
Base: base,
|
||||
text: &openaiText{
|
||||
Base: base,
|
||||
client: client,
|
||||
embeddingSpliter: embeddingSpliter,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openai) String(ctx context.Context, messages []string) (value string, err error) {
|
||||
func (o *openai) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
type openaiText struct {
|
||||
*component.Base[Config, struct{}]
|
||||
|
||||
client *oai.Client
|
||||
embeddingSpliter embeddingSpliter
|
||||
}
|
||||
|
||||
func (o *openaiText) String(ctx context.Context, messages []string) (value string, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
|
||||
@@ -91,7 +106,7 @@ func (o *openai) String(ctx context.Context, messages []string) (value string, e
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
|
||||
func (o *openaiText) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
|
||||
@@ -117,7 +132,7 @@ func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (valu
|
||||
return vecs, nil
|
||||
}
|
||||
|
||||
func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) {
|
||||
func (o *openaiText) Embedding(ctx context.Context, s string) (value []float32, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
|
||||
|
||||
@@ -17,9 +17,13 @@ package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -30,10 +34,12 @@ import (
|
||||
"github.com/glidea/zenfeed/pkg/llm"
|
||||
"github.com/glidea/zenfeed/pkg/llm/prompt"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/storage/object"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry"
|
||||
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
|
||||
"github.com/glidea/zenfeed/pkg/util/buffer"
|
||||
"github.com/glidea/zenfeed/pkg/util/crawl"
|
||||
hashutil "github.com/glidea/zenfeed/pkg/util/hash"
|
||||
)
|
||||
|
||||
// --- Interface code block ---
|
||||
@@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) {
|
||||
}
|
||||
|
||||
type Dependencies struct {
|
||||
LLMFactory llm.Factory
|
||||
LLMFactory llm.Factory // NOTE: String() with cache.
|
||||
ObjectStorage object.Storage
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
@@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
|
||||
}
|
||||
|
||||
// Transform.
|
||||
if r.Transform != nil {
|
||||
if r.Transform.ToText == nil {
|
||||
return errors.New("to_text is required when transform is set")
|
||||
if r.Transform != nil { //nolint:nestif
|
||||
if r.Transform.ToText != nil && r.Transform.ToPodcast != nil {
|
||||
return errors.New("to_text and to_podcast can not be set at same time")
|
||||
}
|
||||
if r.Transform.ToText == nil && r.Transform.ToPodcast == nil {
|
||||
return errors.New("either to_text or to_podcast must be set when transform is set")
|
||||
}
|
||||
|
||||
switch r.Transform.ToText.Type {
|
||||
case ToTextTypePrompt:
|
||||
if r.Transform.ToText.Prompt == "" {
|
||||
return errors.New("to text prompt is required for prompt type")
|
||||
if r.Transform.ToText != nil {
|
||||
switch r.Transform.ToText.Type {
|
||||
case ToTextTypePrompt:
|
||||
if r.Transform.ToText.Prompt == "" {
|
||||
return errors.New("to text prompt is required for prompt type")
|
||||
}
|
||||
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
|
||||
buf := buffer.Get()
|
||||
defer buffer.Put(buf)
|
||||
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
|
||||
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
r.Transform.ToText.promptRendered = buf.String()
|
||||
|
||||
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
|
||||
// No specific validation for crawl type here, as the source text itself is the URL.
|
||||
default:
|
||||
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
|
||||
}
|
||||
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast != nil {
|
||||
if len(r.Transform.ToPodcast.Speakers) == 0 {
|
||||
return errors.New("at least one speaker is required for to_podcast")
|
||||
}
|
||||
|
||||
buf := buffer.Get()
|
||||
defer buffer.Put(buf)
|
||||
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
|
||||
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
r.Transform.ToText.promptRendered = buf.String()
|
||||
r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers))
|
||||
var speakerDescs []string
|
||||
var speakerNames []string
|
||||
for i, s := range r.Transform.ToPodcast.Speakers {
|
||||
if s.Name == "" {
|
||||
return errors.New("speaker name is required")
|
||||
}
|
||||
if s.Voice == "" {
|
||||
return errors.New("speaker voice is required")
|
||||
}
|
||||
r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice}
|
||||
|
||||
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
|
||||
// No specific validation for crawl type here, as the source text itself is the URL.
|
||||
default:
|
||||
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
|
||||
desc := s.Name
|
||||
if s.Role != "" {
|
||||
desc += " (" + s.Role + ")"
|
||||
}
|
||||
speakerDescs = append(speakerDescs, desc)
|
||||
speakerNames = append(speakerNames, s.Name)
|
||||
}
|
||||
|
||||
speakersDesc := "- " + strings.Join(speakerDescs, "\n- ")
|
||||
exampleSpeaker1 := speakerNames[0]
|
||||
exampleSpeaker2 := exampleSpeaker1
|
||||
if len(speakerNames) > 1 {
|
||||
exampleSpeaker2 = speakerNames[1]
|
||||
}
|
||||
|
||||
promptSegments := []string{
|
||||
"Please convert the following article into a podcast dialogue script.",
|
||||
"The speakers are:\n" + speakersDesc,
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast.EstimateMaximumDuration > 0 {
|
||||
wordsPerMinute := 200
|
||||
totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes())
|
||||
estimatedWords := totalMinutes * wordsPerMinute
|
||||
promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes))
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" {
|
||||
promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt)
|
||||
}
|
||||
|
||||
promptSegments = append(promptSegments,
|
||||
"The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.",
|
||||
"Do NOT include any other text, explanations, or formatting before or after the script.",
|
||||
"Do NOT use background music in the script.",
|
||||
"Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').",
|
||||
fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2),
|
||||
"Now, convert the article.",
|
||||
)
|
||||
|
||||
r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n")
|
||||
r.Transform.ToPodcast.speakersDesc = speakersDesc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
|
||||
r.matchRE = re
|
||||
|
||||
// Action.
|
||||
switch r.Action {
|
||||
case "":
|
||||
if r.Action == "" {
|
||||
r.Action = ActionCreateOrUpdateLabel
|
||||
}
|
||||
switch r.Action {
|
||||
case ActionCreateOrUpdateLabel:
|
||||
if r.Label == "" {
|
||||
return errors.New("label is required for create or update label action")
|
||||
@@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
r.If = c.If
|
||||
r.SourceLabel = c.SourceLabel
|
||||
r.SkipTooShortThreshold = c.SkipTooShortThreshold
|
||||
if c.Transform != nil {
|
||||
if c.Transform != nil { //nolint:nestif
|
||||
t := &Transform{}
|
||||
if c.Transform.ToText != nil {
|
||||
toText := &ToText{
|
||||
@@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
}
|
||||
t.ToText = toText
|
||||
}
|
||||
if c.Transform.ToPodcast != nil {
|
||||
toPodcast := &ToPodcast{
|
||||
LLM: c.Transform.ToPodcast.LLM,
|
||||
EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration),
|
||||
TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt,
|
||||
TTSLLM: c.Transform.ToPodcast.TTSLLM,
|
||||
}
|
||||
if toPodcast.EstimateMaximumDuration == 0 {
|
||||
toPodcast.EstimateMaximumDuration = 3 * time.Minute
|
||||
}
|
||||
for _, s := range c.Transform.ToPodcast.Speakers {
|
||||
toPodcast.Speakers = append(toPodcast.Speakers, Speaker{
|
||||
Name: s.Name,
|
||||
Role: s.Role,
|
||||
Voice: s.Voice,
|
||||
})
|
||||
}
|
||||
t.ToPodcast = toPodcast
|
||||
}
|
||||
r.Transform = t
|
||||
}
|
||||
r.Match = c.Match
|
||||
@@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
}
|
||||
|
||||
type Transform struct {
|
||||
ToText *ToText
|
||||
ToText *ToText
|
||||
ToPodcast *ToPodcast
|
||||
}
|
||||
|
||||
type ToText struct {
|
||||
@@ -220,6 +314,24 @@ type ToText struct {
|
||||
promptRendered string
|
||||
}
|
||||
|
||||
type ToPodcast struct {
|
||||
LLM string
|
||||
EstimateMaximumDuration time.Duration
|
||||
TranscriptAdditionalPrompt string
|
||||
TTSLLM string
|
||||
Speakers []Speaker
|
||||
|
||||
transcriptPrompt string
|
||||
speakersDesc string
|
||||
speakers []llm.Speaker
|
||||
}
|
||||
|
||||
// Speaker describes one participant of a generated podcast.
type Speaker struct {
	// Name identifies the speaker.
	Name string
	// Role describes the part the speaker plays.
	Role string
	// Voice is the voice identifier of the speaker.
	Voice string
}
|
||||
|
||||
type ToTextType string
|
||||
|
||||
const (
|
||||
@@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
|
||||
}
|
||||
|
||||
// Transform text if configured.
|
||||
text := sourceText
|
||||
if rule.Transform != nil && rule.Transform.ToText != nil {
|
||||
transformed, err := r.transformText(ctx, rule.Transform, sourceText)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transform text")
|
||||
}
|
||||
text = transformed
|
||||
text, err := r.transform(ctx, rule.Transform, sourceText)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transform")
|
||||
}
|
||||
|
||||
// Check if text matches the rule.
|
||||
@@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) {
|
||||
if transform == nil {
|
||||
return sourceText, nil
|
||||
}
|
||||
|
||||
if transform.ToText != nil {
|
||||
return r.transformText(ctx, transform.ToText, sourceText)
|
||||
}
|
||||
|
||||
if transform.ToPodcast != nil {
|
||||
return r.transformPodcast(ctx, transform.ToPodcast, sourceText)
|
||||
}
|
||||
|
||||
return sourceText, nil
|
||||
}
|
||||
|
||||
// transformText transforms text using configured LLM or by crawling a URL.
|
||||
func (r *rewriter) transformText(ctx context.Context, transform *Transform, text string) (string, error) {
|
||||
switch transform.ToText.Type {
|
||||
func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) {
|
||||
switch toText.Type {
|
||||
case ToTextTypeCrawl:
|
||||
return r.transformTextCrawl(ctx, r.crawler, text)
|
||||
case ToTextTypeCrawlByJina:
|
||||
return r.transformTextCrawl(ctx, r.jinaCrawler, text)
|
||||
|
||||
case ToTextTypePrompt:
|
||||
return r.transformTextPrompt(ctx, transform, text)
|
||||
return r.transformTextPrompt(ctx, toText, text)
|
||||
default:
|
||||
return r.transformTextPrompt(ctx, transform, text)
|
||||
return r.transformTextPrompt(ctx, toText, text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler
|
||||
}
|
||||
|
||||
// transformTextPrompt transforms text using configured LLM.
|
||||
func (r *rewriter) transformTextPrompt(ctx context.Context, transform *Transform, text string) (string, error) {
|
||||
func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) {
|
||||
// Get LLM instance.
|
||||
llm := r.Dependencies().LLMFactory.Get(transform.ToText.LLM)
|
||||
llm := r.Dependencies().LLMFactory.Get(toText.LLM)
|
||||
|
||||
// Call completion.
|
||||
result, err := llm.String(ctx, []string{
|
||||
transform.ToText.promptRendered,
|
||||
toText.promptRendered,
|
||||
text, // TODO: may place to first line to hit the model cache in different rewrite rules.
|
||||
})
|
||||
if err != nil {
|
||||
@@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
var audioKey = func(transcript, ext string) string {
|
||||
hash := hashutil.Sum64(transcript)
|
||||
file := strconv.FormatUint(hash, 10) + "." + ext
|
||||
|
||||
return "podcasts/" + file
|
||||
}
|
||||
|
||||
func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) {
|
||||
transcript, err := r.generateTranscript(ctx, toPodcast, sourceText)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "generate podcast transcript")
|
||||
}
|
||||
|
||||
audioKey := audioKey(transcript, "wav")
|
||||
url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
// May canceled at last time by reload, fast return.
|
||||
return url, nil
|
||||
case errors.Is(err, object.ErrNotFound):
|
||||
// Not found, generate new audio.
|
||||
default:
|
||||
return "", errors.Wrap(err, "get audio")
|
||||
}
|
||||
|
||||
audioStream, err := r.generateAudio(ctx, toPodcast, transcript)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "generate podcast audio")
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := audioStream.Close(); closeErr != nil {
|
||||
err = errors.Wrap(err, "close audio stream")
|
||||
}
|
||||
}()
|
||||
|
||||
url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav")
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "store podcast audio")
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) {
|
||||
transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM).
|
||||
String(ctx, []string{toPodcast.transcriptPrompt, sourceText})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "llm completion")
|
||||
}
|
||||
|
||||
return toPodcast.speakersDesc +
|
||||
"\n\nFollowed by the actual dialogue script:\n" +
|
||||
transcript, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) {
|
||||
audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM).
|
||||
WAV(ctx, transcript, toPodcast.speakers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "calling tts llm")
|
||||
}
|
||||
|
||||
return audioStream, nil
|
||||
}
|
||||
|
||||
type mockRewriter struct {
|
||||
component.Mock
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
"github.com/glidea/zenfeed/pkg/component"
|
||||
"github.com/glidea/zenfeed/pkg/llm"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/storage/object"
|
||||
"github.com/glidea/zenfeed/pkg/test"
|
||||
)
|
||||
|
||||
@@ -19,8 +22,9 @@ func TestLabels(t *testing.T) {
|
||||
RegisterTestingT(t)
|
||||
|
||||
type givenDetail struct {
|
||||
config *Config
|
||||
llmMock func(m *mock.Mock)
|
||||
config *Config
|
||||
llmMock func(m *mock.Mock)
|
||||
objectStorageMock func(m *mock.Mock)
|
||||
}
|
||||
type whenDetail struct {
|
||||
inputLabels model.Labels
|
||||
@@ -173,7 +177,7 @@ func TestLabels(t *testing.T) {
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform text: llm completion: LLM failed"),
|
||||
err: errors.New("transform: llm completion: LLM failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
@@ -220,22 +224,163 @@ func TestLabels(t *testing.T) {
|
||||
isErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Successfully generate podcast from content",
|
||||
Given: "a rule to convert content to a podcast with all dependencies mocked to succeed",
|
||||
When: "processing labels with content to be converted to a podcast",
|
||||
Then: "should return labels with a new podcast_url label",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{
|
||||
LLM: "mock-llm-transcript",
|
||||
TTSLLM: "mock-llm-tts",
|
||||
Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}},
|
||||
},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel,
|
||||
Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")).
|
||||
Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav").
|
||||
Return("http://storage.example.com/podcast.wav", nil).Once()
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{
|
||||
inputLabels: model.Labels{
|
||||
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
|
||||
},
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: model.Labels{
|
||||
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
|
||||
{Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"},
|
||||
},
|
||||
isErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to transcription LLM error",
|
||||
Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to transcription failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to TTS LLM error",
|
||||
Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to TTS failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to object storage error",
|
||||
Given: "a rule to convert content to a podcast, but object storage is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to storage failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once()
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: store podcast audio: storage failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.Scenario, func(t *testing.T) {
|
||||
// Given.
|
||||
var mockLLMFactory llm.Factory
|
||||
var mockInstance *mock.Mock // Store the mock instance for assertion
|
||||
|
||||
// Create mock factory and capture the mock.Mock instance.
|
||||
mockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockInstance = m // Capture the mock instance.
|
||||
var mockLLMInstance *mock.Mock
|
||||
llmMockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockLLMInstance = m
|
||||
if tt.GivenDetail.llmMock != nil {
|
||||
tt.GivenDetail.llmMock(m)
|
||||
}
|
||||
})
|
||||
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, mockOption) // Use the factory directly with the option
|
||||
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var mockObjectStorage object.Storage
|
||||
var mockObjectStorageInstance *mock.Mock
|
||||
objectStorageMockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockObjectStorageInstance = m
|
||||
if tt.GivenDetail.objectStorageMock != nil {
|
||||
tt.GivenDetail.objectStorageMock(m)
|
||||
}
|
||||
})
|
||||
mockObjectStorageFactory := object.NewFactory(objectStorageMockOption)
|
||||
mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Manually validate config to compile regex and render templates.
|
||||
@@ -252,7 +397,8 @@ func TestLabels(t *testing.T) {
|
||||
Instance: "test",
|
||||
Config: tt.GivenDetail.config,
|
||||
Dependencies: Dependencies{
|
||||
LLMFactory: mockLLMFactory, // Pass the mock factory
|
||||
LLMFactory: mockLLMFactory, // Pass the mock factory
|
||||
ObjectStorage: mockObjectStorage,
|
||||
},
|
||||
}),
|
||||
}
|
||||
@@ -280,10 +426,12 @@ func TestLabels(t *testing.T) {
|
||||
Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels))
|
||||
}
|
||||
|
||||
// Verify LLM calls if stubs were provided.
|
||||
if tt.GivenDetail.llmMock != nil && mockInstance != nil {
|
||||
// Assert expectations on the captured mock instance.
|
||||
mockInstance.AssertExpectations(t)
|
||||
// Verify mock calls if stubs were provided.
|
||||
if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil {
|
||||
mockLLMInstance.AssertExpectations(t)
|
||||
}
|
||||
if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil {
|
||||
mockObjectStorageInstance.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -216,6 +216,10 @@ func (m *manager) reload(config *Config) (err error) {
|
||||
func (m *manager) runOrRestartScrapers(config *Config, newScrapers map[string]scraper.Scraper) error {
|
||||
for i := range config.Scrapers {
|
||||
c := &config.Scrapers[i]
|
||||
if err := c.Validate(); err != nil {
|
||||
return errors.Wrapf(err, "validate scraper %s", c.Name)
|
||||
}
|
||||
|
||||
if err := m.runOrRestartScraper(c, newScrapers); err != nil {
|
||||
return errors.Wrapf(err, "run or restart scraper %s", c.Name)
|
||||
}
|
||||
|
||||
@@ -69,6 +69,11 @@ func (c *Config) Validate() error {
|
||||
if c.Name == "" {
|
||||
return errors.New("name cannot be empty")
|
||||
}
|
||||
if c.RSS != nil {
|
||||
if err := c.RSS.Validate(); err != nil {
|
||||
return errors.Wrap(err, "invalid RSS config")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestNew(t *testing.T) {
|
||||
WhenDetail: whenDetail{},
|
||||
ThenExpected: thenExpected{
|
||||
isErr: true,
|
||||
wantErrMsg: "creating source: invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
|
||||
wantErrMsg: "invalid RSS config: URL must be a valid HTTP/HTTPS URL", // Error from newRSSReader via newReader
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -264,7 +264,7 @@ func TestNew(t *testing.T) {
|
||||
WhenDetail: whenDetail{},
|
||||
ThenExpected: thenExpected{
|
||||
isErr: true,
|
||||
wantErrMsg: "creating source: source not supported", // Error from newReader
|
||||
wantErrMsg: "source not supported", // Error from newReader
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
235
pkg/storage/object/object.go
Normal file
235
pkg/storage/object/object.go
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package object
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/component"
|
||||
"github.com/glidea/zenfeed/pkg/config"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry/log"
|
||||
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
|
||||
)
|
||||
|
||||
// --- Interface code block ---
|
||||
type Storage interface {
|
||||
component.Component
|
||||
config.Watcher
|
||||
Put(ctx context.Context, key string, body io.Reader, contentType string) (url string, err error)
|
||||
Get(ctx context.Context, key string) (url string, err error)
|
||||
}
|
||||
|
||||
// ErrNotFound is returned by Storage.Get when the requested key does not exist.
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Config holds the connection settings of the object storage.
type Config struct {
	// Endpoint is the address of the object storage service.
	Endpoint string
	// AccessKeyID is the access key used for authentication.
	AccessKeyID string
	// SecretAccessKey is the secret key used for authentication.
	SecretAccessKey string
	// Bucket is the bucket objects are stored in.
	Bucket string
	// BucketURL is the base URL that object keys are joined to when building object URLs.
	BucketURL string
}

// Validate checks that every setting is present.
// It reports the first missing field, in declaration order.
func (c *Config) Validate() error {
	required := []struct {
		value   string
		message string
	}{
		{c.Endpoint, "endpoint is required"},
		{c.AccessKeyID, "access key id is required"},
		{c.SecretAccessKey, "secret access key is required"},
		{c.Bucket, "bucket is required"},
		{c.BucketURL, "bucket url is required"},
	}
	for _, field := range required {
		if field.value == "" {
			return errors.New(field.message)
		}
	}

	return nil
}
|
||||
|
||||
func (c *Config) From(app *config.App) *Config {
|
||||
*c = Config{
|
||||
Endpoint: app.Storage.Object.Endpoint,
|
||||
AccessKeyID: app.Storage.Object.AccessKeyID,
|
||||
SecretAccessKey: app.Storage.Object.SecretAccessKey,
|
||||
Bucket: app.Storage.Object.Bucket,
|
||||
BucketURL: app.Storage.Object.BucketURL,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Dependencies lists the components the object storage depends on; it currently has none.
type Dependencies struct{}
|
||||
|
||||
// --- Factory code block ---
|
||||
// Factory creates Storage instances from the app config.
type Factory component.Factory[Storage, config.App, Dependencies]
|
||||
|
||||
func NewFactory(mockOn ...component.MockOption) Factory {
|
||||
if len(mockOn) > 0 {
|
||||
return component.FactoryFunc[Storage, config.App, Dependencies](
|
||||
func(instance string, config *config.App, dependencies Dependencies) (Storage, error) {
|
||||
m := &mockStorage{}
|
||||
component.MockOptions(mockOn).Apply(&m.Mock)
|
||||
|
||||
return m, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return component.FactoryFunc[Storage, config.App, Dependencies](new)
|
||||
}
|
||||
|
||||
func new(instance string, app *config.App, dependencies Dependencies) (Storage, error) {
|
||||
config := &Config{}
|
||||
config.From(app)
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, errors.Wrap(err, "validate config")
|
||||
}
|
||||
|
||||
client, err := minio.New(config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.AccessKeyID, config.SecretAccessKey, ""),
|
||||
Secure: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "new minio client")
|
||||
}
|
||||
|
||||
u, err := url.Parse(config.BucketURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse public url")
|
||||
}
|
||||
|
||||
return &s3{
|
||||
Base: component.New(&component.BaseConfig[Config, Dependencies]{
|
||||
Name: "ObjectStorage",
|
||||
Instance: instance,
|
||||
Config: config,
|
||||
Dependencies: dependencies,
|
||||
}),
|
||||
client: client,
|
||||
bucketURL: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- Implementation code block ---
|
||||
type s3 struct {
|
||||
*component.Base[Config, Dependencies]
|
||||
|
||||
client *minio.Client
|
||||
bucketURL *url.URL
|
||||
}
|
||||
|
||||
func (s *s3) Put(ctx context.Context, key string, body io.Reader, contentType string) (publicURL string, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Put")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
bucket := s.Config().Bucket
|
||||
|
||||
if _, err := s.client.PutObject(ctx, bucket, key, body, -1, minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
}); err != nil {
|
||||
return "", errors.Wrap(err, "put object")
|
||||
}
|
||||
|
||||
return s.bucketURL.JoinPath(key).String(), nil
|
||||
}
|
||||
|
||||
func (s *s3) Get(ctx context.Context, key string) (publicURL string, err error) {
|
||||
ctx = telemetry.StartWith(ctx, append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Get")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
bucket := s.Config().Bucket
|
||||
|
||||
if _, err := s.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}); err != nil {
|
||||
errResponse := minio.ToErrorResponse(err)
|
||||
if errResponse.Code == minio.NoSuchKey {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
return "", errors.Wrap(err, "stat object")
|
||||
}
|
||||
|
||||
return s.bucketURL.JoinPath(key).String(), nil
|
||||
}
|
||||
|
||||
func (s *s3) Reload(app *config.App) (err error) {
|
||||
ctx := telemetry.StartWith(s.Context(), append(s.TelemetryLabels(), telemetrymodel.KeyOperation, "Reload")...)
|
||||
defer func() { telemetry.End(ctx, err) }()
|
||||
|
||||
newConfig := &Config{}
|
||||
newConfig.From(app)
|
||||
if err := newConfig.Validate(); err != nil {
|
||||
return errors.Wrap(err, "validate config")
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(s.Config(), newConfig) {
|
||||
log.Debug(ctx, "object storage config not changed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := minio.New(newConfig.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(newConfig.AccessKeyID, newConfig.SecretAccessKey, ""),
|
||||
Secure: true,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new minio client")
|
||||
}
|
||||
|
||||
u, err := url.Parse(newConfig.BucketURL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "parse public url")
|
||||
}
|
||||
|
||||
s.client = client
|
||||
s.bucketURL = u
|
||||
s.SetConfig(newConfig)
|
||||
|
||||
log.Info(ctx, "object storage reloaded")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Mock code block ---
|
||||
type mockStorage struct {
|
||||
component.Mock
|
||||
}
|
||||
|
||||
func (m *mockStorage) Put(ctx context.Context, key string, body io.Reader, contentType string) (string, error) {
|
||||
args := m.Called(ctx, key, body, contentType)
|
||||
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStorage) Get(ctx context.Context, key string) (string, error) {
|
||||
args := m.Called(ctx, key)
|
||||
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStorage) Reload(app *config.App) error {
|
||||
args := m.Called(app)
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
@@ -122,6 +122,32 @@ func ReadUint32(r io.Reader) (uint32, error) {
|
||||
return binary.LittleEndian.Uint32(b), nil
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16 using a pooled buffer.
|
||||
func WriteUint16(w io.Writer, v uint16) error {
|
||||
bp := smallBufPool.Get().(*[]byte)
|
||||
defer smallBufPool.Put(bp)
|
||||
b := *bp
|
||||
|
||||
binary.LittleEndian.PutUint16(b, v)
|
||||
_, err := w.Write(b[:2])
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadUint16 reads a uint16 using a pooled buffer.
|
||||
func ReadUint16(r io.Reader) (uint16, error) {
|
||||
bp := smallBufPool.Get().(*[]byte)
|
||||
defer smallBufPool.Put(bp)
|
||||
b := (*bp)[:2]
|
||||
|
||||
// Read exactly 2 bytes into the slice.
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return 0, errors.Wrap(err, "read uint16")
|
||||
}
|
||||
|
||||
return binary.LittleEndian.Uint16(b), nil
|
||||
}
|
||||
|
||||
// WriteFloat32 writes a float32 using a pooled buffer.
|
||||
func WriteFloat32(w io.Writer, v float32) error {
|
||||
return WriteUint32(w, math.Float32bits(v))
|
||||
|
||||
100
pkg/util/wav/wav.go
Normal file
100
pkg/util/wav/wav.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package wav
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
binaryutil "github.com/glidea/zenfeed/pkg/util/binary"
|
||||
)
|
||||
|
||||
// Header describes the PCM audio format written into a WAV header.
type Header struct {
	// SampleRate is the number of samples per second.
	SampleRate uint32
	// BitDepth is the number of bits per sample.
	BitDepth uint16
	// NumChannels is the number of audio channels.
	NumChannels uint16
}
|
||||
|
||||
// WriteHeader writes the WAV header to a writer.
|
||||
// pcmDataSize is the size of the raw PCM data.
|
||||
func WriteHeader(w io.Writer, h *Header, pcmDataSize uint32) error {
|
||||
// RIFF Header.
|
||||
if err := writeRIFFHeader(w, pcmDataSize); err != nil {
|
||||
return errors.Wrap(err, "write RIFF header")
|
||||
}
|
||||
|
||||
// fmt chunk.
|
||||
if err := writeFMTChunk(w, h); err != nil {
|
||||
return errors.Wrap(err, "write fmt chunk")
|
||||
}
|
||||
|
||||
// data chunk.
|
||||
if _, err := w.Write([]byte("data")); err != nil {
|
||||
return errors.Wrap(err, "write data chunk marker")
|
||||
}
|
||||
if err := binaryutil.WriteUint32(w, pcmDataSize); err != nil {
|
||||
return errors.Wrap(err, "write pcm data size")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeRIFFHeader(w io.Writer, pcmDataSize uint32) error {
|
||||
if _, err := w.Write([]byte("RIFF")); err != nil {
|
||||
return errors.Wrap(err, "write RIFF")
|
||||
}
|
||||
if err := binaryutil.WriteUint32(w, uint32(36+pcmDataSize)); err != nil {
|
||||
return errors.Wrap(err, "write file size")
|
||||
}
|
||||
if _, err := w.Write([]byte("WAVE")); err != nil {
|
||||
return errors.Wrap(err, "write WAVE")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeFMTChunk(w io.Writer, h *Header) error {
|
||||
if _, err := w.Write([]byte("fmt ")); err != nil {
|
||||
return errors.Wrap(err, "write fmt")
|
||||
}
|
||||
if err := binaryutil.WriteUint32(w, uint32(16)); err != nil { // PCM chunk size.
|
||||
return errors.Wrap(err, "write pcm chunk size")
|
||||
}
|
||||
if err := binaryutil.WriteUint16(w, uint16(1)); err != nil { // PCM format.
|
||||
return errors.Wrap(err, "write pcm format")
|
||||
}
|
||||
if err := binaryutil.WriteUint16(w, h.NumChannels); err != nil {
|
||||
return errors.Wrap(err, "write num channels")
|
||||
}
|
||||
if err := binaryutil.WriteUint32(w, h.SampleRate); err != nil {
|
||||
return errors.Wrap(err, "write sample rate")
|
||||
}
|
||||
byteRate := h.SampleRate * uint32(h.NumChannels) * uint32(h.BitDepth) / 8
|
||||
if err := binaryutil.WriteUint32(w, byteRate); err != nil {
|
||||
return errors.Wrap(err, "write byte rate")
|
||||
}
|
||||
blockAlign := h.NumChannels * h.BitDepth / 8
|
||||
if err := binaryutil.WriteUint16(w, blockAlign); err != nil {
|
||||
return errors.Wrap(err, "write block align")
|
||||
}
|
||||
if err := binaryutil.WriteUint16(w, h.BitDepth); err != nil {
|
||||
return errors.Wrap(err, "write bit depth")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
161
pkg/util/wav/wav_test.go
Normal file
161
pkg/util/wav/wav_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package wav
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/test"
|
||||
)
|
||||
|
||||
func TestWriteHeader(t *testing.T) {
|
||||
RegisterTestingT(t)
|
||||
|
||||
type givenDetail struct{}
|
||||
|
||||
type whenDetail struct {
|
||||
header *Header
|
||||
pcmDataSize uint32
|
||||
}
|
||||
|
||||
type thenExpected struct {
|
||||
expectedBytes []byte
|
||||
expectError bool
|
||||
}
|
||||
|
||||
tests := []test.Case[givenDetail, whenDetail, thenExpected]{
|
||||
{
|
||||
Scenario: "Standard CD quality audio",
|
||||
Given: "a header for CD quality audio and a non-zero data size",
|
||||
When: "writing the header",
|
||||
Then: "should produce a valid 44-byte WAV header and no error",
|
||||
GivenDetail: givenDetail{},
|
||||
WhenDetail: whenDetail{
|
||||
header: &Header{
|
||||
SampleRate: 44100,
|
||||
BitDepth: 16,
|
||||
NumChannels: 2,
|
||||
},
|
||||
pcmDataSize: 176400,
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
expectedBytes: []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x34, 0xB1, 0x02, 0x00, // ChunkSize = 36 + 176400 = 176436
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
|
||||
0x01, 0x00, // AudioFormat = 1 (PCM)
|
||||
0x02, 0x00, // NumChannels = 2
|
||||
0x44, 0xAC, 0x00, 0x00, // SampleRate = 44100
|
||||
0x10, 0xB1, 0x02, 0x00, // ByteRate = 176400
|
||||
0x04, 0x00, // BlockAlign = 4
|
||||
0x10, 0x00, // BitsPerSample = 16
|
||||
'd', 'a', 't', 'a',
|
||||
0x10, 0xB1, 0x02, 0x00, // Subchunk2Size = 176400
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Mono audio for speech",
|
||||
Given: "a header for mono speech audio and a non-zero data size",
|
||||
When: "writing the header",
|
||||
Then: "should produce a valid 44-byte WAV header and no error",
|
||||
GivenDetail: givenDetail{},
|
||||
WhenDetail: whenDetail{
|
||||
header: &Header{
|
||||
SampleRate: 16000,
|
||||
BitDepth: 16,
|
||||
NumChannels: 1,
|
||||
},
|
||||
pcmDataSize: 32000,
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
expectedBytes: []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x24, 0x7D, 0x00, 0x00, // ChunkSize = 36 + 32000 = 32036
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
|
||||
0x01, 0x00, // AudioFormat = 1
|
||||
0x01, 0x00, // NumChannels = 1
|
||||
0x80, 0x3E, 0x00, 0x00, // SampleRate = 16000
|
||||
0x00, 0x7D, 0x00, 0x00, // ByteRate = 32000
|
||||
0x02, 0x00, // BlockAlign = 2
|
||||
0x10, 0x00, // BitsPerSample = 16
|
||||
'd', 'a', 't', 'a',
|
||||
0x00, 0x7D, 0x00, 0x00, // Subchunk2Size = 32000
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "8-bit mono audio with zero data size",
|
||||
Given: "a header for 8-bit mono audio and a zero data size",
|
||||
When: "writing the header for an empty file",
|
||||
Then: "should produce a valid 44-byte WAV header with data size 0",
|
||||
GivenDetail: givenDetail{},
|
||||
WhenDetail: whenDetail{
|
||||
header: &Header{
|
||||
SampleRate: 8000,
|
||||
BitDepth: 8,
|
||||
NumChannels: 1,
|
||||
},
|
||||
pcmDataSize: 0,
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
expectedBytes: []byte{
|
||||
'R', 'I', 'F', 'F',
|
||||
0x24, 0x00, 0x00, 0x00, // ChunkSize = 36 + 0 = 36
|
||||
'W', 'A', 'V', 'E',
|
||||
'f', 'm', 't', ' ',
|
||||
0x10, 0x00, 0x00, 0x00, // Subchunk1Size = 16
|
||||
0x01, 0x00, // AudioFormat = 1
|
||||
0x01, 0x00, // NumChannels = 1
|
||||
0x40, 0x1F, 0x00, 0x00, // SampleRate = 8000
|
||||
0x40, 0x1F, 0x00, 0x00, // ByteRate = 8000
|
||||
0x01, 0x00, // BlockAlign = 1
|
||||
0x08, 0x00, // BitsPerSample = 8
|
||||
'd', 'a', 't', 'a',
|
||||
0x00, 0x00, 0x00, 0x00, // Subchunk2Size = 0
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.Scenario, func(t *testing.T) {
|
||||
// Given.
|
||||
var buf bytes.Buffer
|
||||
|
||||
// When.
|
||||
err := WriteHeader(&buf, tt.WhenDetail.header, tt.WhenDetail.pcmDataSize)
|
||||
|
||||
// Then.
|
||||
if tt.ThenExpected.expectError {
|
||||
Expect(err).To(HaveOccurred())
|
||||
} else {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(buf.Bytes()).To(Equal(tt.ThenExpected.expectedBytes))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user