add podcast

This commit is contained in:
glidea
2025-07-08 18:13:26 +08:00
parent 263fcbbfaf
commit 2de0cf77fc
21 changed files with 1545 additions and 145 deletions

View File

@@ -17,9 +17,13 @@ package rewrite
import (
"context"
"fmt"
"html/template"
"io"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
@@ -30,10 +34,12 @@ import (
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/llm/prompt"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/buffer"
"github.com/glidea/zenfeed/pkg/util/crawl"
hashutil "github.com/glidea/zenfeed/pkg/util/hash"
)
// --- Interface code block ---
@@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) {
}
type Dependencies struct {
LLMFactory llm.Factory
LLMFactory llm.Factory // NOTE: String() with cache.
ObjectStorage object.Storage
}
type Rule struct {
@@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
}
// Transform.
if r.Transform != nil {
if r.Transform.ToText == nil {
return errors.New("to_text is required when transform is set")
if r.Transform != nil { //nolint:nestif
if r.Transform.ToText != nil && r.Transform.ToPodcast != nil {
return errors.New("to_text and to_podcast can not be set at same time")
}
if r.Transform.ToText == nil && r.Transform.ToPodcast == nil {
return errors.New("either to_text or to_podcast must be set when transform is set")
}
switch r.Transform.ToText.Type {
case ToTextTypePrompt:
if r.Transform.ToText.Prompt == "" {
return errors.New("to text prompt is required for prompt type")
if r.Transform.ToText != nil {
switch r.Transform.ToText.Type {
case ToTextTypePrompt:
if r.Transform.ToText.Prompt == "" {
return errors.New("to text prompt is required for prompt type")
}
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
if err != nil {
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
}
buf := buffer.Get()
defer buffer.Put(buf)
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
}
r.Transform.ToText.promptRendered = buf.String()
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
// No specific validation for crawl type here, as the source text itself is the URL.
default:
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
}
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
if err != nil {
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
}
if r.Transform.ToPodcast != nil {
if len(r.Transform.ToPodcast.Speakers) == 0 {
return errors.New("at least one speaker is required for to_podcast")
}
buf := buffer.Get()
defer buffer.Put(buf)
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
}
r.Transform.ToText.promptRendered = buf.String()
r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers))
var speakerDescs []string
var speakerNames []string
for i, s := range r.Transform.ToPodcast.Speakers {
if s.Name == "" {
return errors.New("speaker name is required")
}
if s.Voice == "" {
return errors.New("speaker voice is required")
}
r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice}
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
// No specific validation for crawl type here, as the source text itself is the URL.
default:
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
desc := s.Name
if s.Role != "" {
desc += " (" + s.Role + ")"
}
speakerDescs = append(speakerDescs, desc)
speakerNames = append(speakerNames, s.Name)
}
speakersDesc := "- " + strings.Join(speakerDescs, "\n- ")
exampleSpeaker1 := speakerNames[0]
exampleSpeaker2 := exampleSpeaker1
if len(speakerNames) > 1 {
exampleSpeaker2 = speakerNames[1]
}
promptSegments := []string{
"Please convert the following article into a podcast dialogue script.",
"The speakers are:\n" + speakersDesc,
}
if r.Transform.ToPodcast.EstimateMaximumDuration > 0 {
wordsPerMinute := 200
totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes())
estimatedWords := totalMinutes * wordsPerMinute
promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes))
}
if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" {
promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt)
}
promptSegments = append(promptSegments,
"The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.",
"Do NOT include any other text, explanations, or formatting before or after the script.",
"Do NOT use background music in the script.",
"Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').",
fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2),
"Now, convert the article.",
)
r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n")
r.Transform.ToPodcast.speakersDesc = speakersDesc
}
}
@@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
r.matchRE = re
// Action.
switch r.Action {
case "":
if r.Action == "" {
r.Action = ActionCreateOrUpdateLabel
}
switch r.Action {
case ActionCreateOrUpdateLabel:
if r.Label == "" {
return errors.New("label is required for create or update label action")
@@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) {
r.If = c.If
r.SourceLabel = c.SourceLabel
r.SkipTooShortThreshold = c.SkipTooShortThreshold
if c.Transform != nil {
if c.Transform != nil { //nolint:nestif
t := &Transform{}
if c.Transform.ToText != nil {
toText := &ToText{
@@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) {
}
t.ToText = toText
}
if c.Transform.ToPodcast != nil {
toPodcast := &ToPodcast{
LLM: c.Transform.ToPodcast.LLM,
EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration),
TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt,
TTSLLM: c.Transform.ToPodcast.TTSLLM,
}
if toPodcast.EstimateMaximumDuration == 0 {
toPodcast.EstimateMaximumDuration = 3 * time.Minute
}
for _, s := range c.Transform.ToPodcast.Speakers {
toPodcast.Speakers = append(toPodcast.Speakers, Speaker{
Name: s.Name,
Role: s.Role,
Voice: s.Voice,
})
}
t.ToPodcast = toPodcast
}
r.Transform = t
}
r.Match = c.Match
@@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) {
}
type Transform struct {
ToText *ToText
ToText *ToText
ToPodcast *ToPodcast
}
type ToText struct {
@@ -220,6 +314,24 @@ type ToText struct {
promptRendered string
}
type ToPodcast struct {
LLM string
EstimateMaximumDuration time.Duration
TranscriptAdditionalPrompt string
TTSLLM string
Speakers []Speaker
transcriptPrompt string
speakersDesc string
speakers []llm.Speaker
}
type Speaker struct {
Name string
Role string
Voice string
}
type ToTextType string
const (
@@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
}
// Transform text if configured.
text := sourceText
if rule.Transform != nil && rule.Transform.ToText != nil {
transformed, err := r.transformText(ctx, rule.Transform, sourceText)
if err != nil {
return nil, errors.Wrap(err, "transform text")
}
text = transformed
text, err := r.transform(ctx, rule.Transform, sourceText)
if err != nil {
return nil, errors.Wrap(err, "transform")
}
// Check if text matches the rule.
@@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
return labels, nil
}
func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) {
if transform == nil {
return sourceText, nil
}
if transform.ToText != nil {
return r.transformText(ctx, transform.ToText, sourceText)
}
if transform.ToPodcast != nil {
return r.transformPodcast(ctx, transform.ToPodcast, sourceText)
}
return sourceText, nil
}
// transformText transforms text using configured LLM or by crawling a URL.
func (r *rewriter) transformText(ctx context.Context, transform *Transform, text string) (string, error) {
switch transform.ToText.Type {
func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) {
switch toText.Type {
case ToTextTypeCrawl:
return r.transformTextCrawl(ctx, r.crawler, text)
case ToTextTypeCrawlByJina:
return r.transformTextCrawl(ctx, r.jinaCrawler, text)
case ToTextTypePrompt:
return r.transformTextPrompt(ctx, transform, text)
return r.transformTextPrompt(ctx, toText, text)
default:
return r.transformTextPrompt(ctx, transform, text)
return r.transformTextPrompt(ctx, toText, text)
}
}
@@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler
}
// transformTextPrompt transforms text using configured LLM.
func (r *rewriter) transformTextPrompt(ctx context.Context, transform *Transform, text string) (string, error) {
func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) {
// Get LLM instance.
llm := r.Dependencies().LLMFactory.Get(transform.ToText.LLM)
llm := r.Dependencies().LLMFactory.Get(toText.LLM)
// Call completion.
result, err := llm.String(ctx, []string{
transform.ToText.promptRendered,
toText.promptRendered,
text, // TODO: may place to first line to hit the model cache in different rewrite rules.
})
if err != nil {
@@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string {
return text
}
var audioKey = func(transcript, ext string) string {
hash := hashutil.Sum64(transcript)
file := strconv.FormatUint(hash, 10) + "." + ext
return "podcasts/" + file
}
func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) {
transcript, err := r.generateTranscript(ctx, toPodcast, sourceText)
if err != nil {
return "", errors.Wrap(err, "generate podcast transcript")
}
audioKey := audioKey(transcript, "wav")
url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey)
switch {
case err == nil:
// May canceled at last time by reload, fast return.
return url, nil
case errors.Is(err, object.ErrNotFound):
// Not found, generate new audio.
default:
return "", errors.Wrap(err, "get audio")
}
audioStream, err := r.generateAudio(ctx, toPodcast, transcript)
if err != nil {
return "", errors.Wrap(err, "generate podcast audio")
}
defer func() {
if closeErr := audioStream.Close(); closeErr != nil {
err = errors.Wrap(err, "close audio stream")
}
}()
url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav")
if err != nil {
return "", errors.Wrap(err, "store podcast audio")
}
return url, nil
}
func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) {
transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM).
String(ctx, []string{toPodcast.transcriptPrompt, sourceText})
if err != nil {
return "", errors.Wrap(err, "llm completion")
}
return toPodcast.speakersDesc +
"\n\nFollowed by the actual dialogue script:\n" +
transcript, nil
}
func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) {
audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM).
WAV(ctx, transcript, toPodcast.speakers)
if err != nil {
return nil, errors.Wrap(err, "calling tts llm")
}
return audioStream, nil
}
type mockRewriter struct {
component.Mock
}

View File

@@ -2,6 +2,8 @@ package rewrite
import (
"context"
"io"
"strings"
"testing"
. "github.com/onsi/gomega"
@@ -12,6 +14,7 @@ import (
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/llm"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/object"
"github.com/glidea/zenfeed/pkg/test"
)
@@ -19,8 +22,9 @@ func TestLabels(t *testing.T) {
RegisterTestingT(t)
type givenDetail struct {
config *Config
llmMock func(m *mock.Mock)
config *Config
llmMock func(m *mock.Mock)
objectStorageMock func(m *mock.Mock)
}
type whenDetail struct {
inputLabels model.Labels
@@ -173,7 +177,7 @@ func TestLabels(t *testing.T) {
},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform text: llm completion: LLM failed"),
err: errors.New("transform: llm completion: LLM failed"),
isErr: true,
},
},
@@ -220,22 +224,163 @@ func TestLabels(t *testing.T) {
isErr: false,
},
},
{
Scenario: "Successfully generate podcast from content",
Given: "a rule to convert content to a podcast with all dependencies mocked to succeed",
When: "processing labels with content to be converted to a podcast",
Then: "should return labels with a new podcast_url label",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{
LLM: "mock-llm-transcript",
TTSLLM: "mock-llm-tts",
Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}},
},
},
Action: ActionCreateOrUpdateLabel,
Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")).
Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav").
Return("http://storage.example.com/podcast.wav", nil).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{
inputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
},
},
ThenExpected: thenExpected{
outputLabels: model.Labels{
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
{Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"},
},
isErr: false,
},
},
{
Scenario: "Fail podcast generation due to transcription LLM error",
Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to transcription failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to TTS LLM error",
Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail",
When: "processing labels",
Then: "should return an error related to TTS failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"),
isErr: true,
},
},
{
Scenario: "Fail podcast generation due to object storage error",
Given: "a rule to convert content to a podcast, but object storage is mocked to fail",
When: "processing labels",
Then: "should return an error related to storage failure",
GivenDetail: givenDetail{
config: &Config{
{
SourceLabel: model.LabelContent,
Transform: &Transform{
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
},
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
},
},
llmMock: func(m *mock.Mock) {
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once()
},
objectStorageMock: func(m *mock.Mock) {
m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once()
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
},
},
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
ThenExpected: thenExpected{
outputLabels: nil,
err: errors.New("transform: store podcast audio: storage failed"),
isErr: true,
},
},
}
for _, tt := range tests {
t.Run(tt.Scenario, func(t *testing.T) {
// Given.
var mockLLMFactory llm.Factory
var mockInstance *mock.Mock // Store the mock instance for assertion
// Create mock factory and capture the mock.Mock instance.
mockOption := component.MockOption(func(m *mock.Mock) {
mockInstance = m // Capture the mock instance.
var mockLLMInstance *mock.Mock
llmMockOption := component.MockOption(func(m *mock.Mock) {
mockLLMInstance = m
if tt.GivenDetail.llmMock != nil {
tt.GivenDetail.llmMock(m)
}
})
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, mockOption) // Use the factory directly with the option
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption)
Expect(err).NotTo(HaveOccurred())
var mockObjectStorage object.Storage
var mockObjectStorageInstance *mock.Mock
objectStorageMockOption := component.MockOption(func(m *mock.Mock) {
mockObjectStorageInstance = m
if tt.GivenDetail.objectStorageMock != nil {
tt.GivenDetail.objectStorageMock(m)
}
})
mockObjectStorageFactory := object.NewFactory(objectStorageMockOption)
mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{})
Expect(err).NotTo(HaveOccurred())
// Manually validate config to compile regex and render templates.
@@ -252,7 +397,8 @@ func TestLabels(t *testing.T) {
Instance: "test",
Config: tt.GivenDetail.config,
Dependencies: Dependencies{
LLMFactory: mockLLMFactory, // Pass the mock factory
LLMFactory: mockLLMFactory, // Pass the mock factory
ObjectStorage: mockObjectStorage,
},
}),
}
@@ -280,10 +426,12 @@ func TestLabels(t *testing.T) {
Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels))
}
// Verify LLM calls if stubs were provided.
if tt.GivenDetail.llmMock != nil && mockInstance != nil {
// Assert expectations on the captured mock instance.
mockInstance.AssertExpectations(t)
// Verify mock calls if stubs were provided.
if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil {
mockLLMInstance.AssertExpectations(t)
}
if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil {
mockObjectStorageInstance.AssertExpectations(t)
}
})
}