add podcast
This commit is contained in:
@@ -17,9 +17,13 @@ package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -30,10 +34,12 @@ import (
|
||||
"github.com/glidea/zenfeed/pkg/llm"
|
||||
"github.com/glidea/zenfeed/pkg/llm/prompt"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/storage/object"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry"
|
||||
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
|
||||
"github.com/glidea/zenfeed/pkg/util/buffer"
|
||||
"github.com/glidea/zenfeed/pkg/util/crawl"
|
||||
hashutil "github.com/glidea/zenfeed/pkg/util/hash"
|
||||
)
|
||||
|
||||
// --- Interface code block ---
|
||||
@@ -68,7 +74,8 @@ func (c *Config) From(app *config.App) {
|
||||
}
|
||||
|
||||
type Dependencies struct {
|
||||
LLMFactory llm.Factory
|
||||
LLMFactory llm.Factory // NOTE: String() with cache.
|
||||
ObjectStorage object.Storage
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
@@ -120,32 +127,98 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
|
||||
}
|
||||
|
||||
// Transform.
|
||||
if r.Transform != nil {
|
||||
if r.Transform.ToText == nil {
|
||||
return errors.New("to_text is required when transform is set")
|
||||
if r.Transform != nil { //nolint:nestif
|
||||
if r.Transform.ToText != nil && r.Transform.ToPodcast != nil {
|
||||
return errors.New("to_text and to_podcast can not be set at same time")
|
||||
}
|
||||
if r.Transform.ToText == nil && r.Transform.ToPodcast == nil {
|
||||
return errors.New("either to_text or to_podcast must be set when transform is set")
|
||||
}
|
||||
|
||||
switch r.Transform.ToText.Type {
|
||||
case ToTextTypePrompt:
|
||||
if r.Transform.ToText.Prompt == "" {
|
||||
return errors.New("to text prompt is required for prompt type")
|
||||
if r.Transform.ToText != nil {
|
||||
switch r.Transform.ToText.Type {
|
||||
case ToTextTypePrompt:
|
||||
if r.Transform.ToText.Prompt == "" {
|
||||
return errors.New("to text prompt is required for prompt type")
|
||||
}
|
||||
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
|
||||
buf := buffer.Get()
|
||||
defer buffer.Put(buf)
|
||||
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
|
||||
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
r.Transform.ToText.promptRendered = buf.String()
|
||||
|
||||
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
|
||||
// No specific validation for crawl type here, as the source text itself is the URL.
|
||||
default:
|
||||
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
|
||||
}
|
||||
tmpl, err := template.New("").Parse(r.Transform.ToText.Prompt)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "parse prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast != nil {
|
||||
if len(r.Transform.ToPodcast.Speakers) == 0 {
|
||||
return errors.New("at least one speaker is required for to_podcast")
|
||||
}
|
||||
|
||||
buf := buffer.Get()
|
||||
defer buffer.Put(buf)
|
||||
if err := tmpl.Execute(buf, prompt.Builtin); err != nil {
|
||||
return errors.Wrapf(err, "execute prompt template %s", r.Transform.ToText.Prompt)
|
||||
}
|
||||
r.Transform.ToText.promptRendered = buf.String()
|
||||
r.Transform.ToPodcast.speakers = make([]llm.Speaker, len(r.Transform.ToPodcast.Speakers))
|
||||
var speakerDescs []string
|
||||
var speakerNames []string
|
||||
for i, s := range r.Transform.ToPodcast.Speakers {
|
||||
if s.Name == "" {
|
||||
return errors.New("speaker name is required")
|
||||
}
|
||||
if s.Voice == "" {
|
||||
return errors.New("speaker voice is required")
|
||||
}
|
||||
r.Transform.ToPodcast.speakers[i] = llm.Speaker{Name: s.Name, Voice: s.Voice}
|
||||
|
||||
case ToTextTypeCrawl, ToTextTypeCrawlByJina:
|
||||
// No specific validation for crawl type here, as the source text itself is the URL.
|
||||
default:
|
||||
return errors.Errorf("unknown transform type: %s", r.Transform.ToText.Type)
|
||||
desc := s.Name
|
||||
if s.Role != "" {
|
||||
desc += " (" + s.Role + ")"
|
||||
}
|
||||
speakerDescs = append(speakerDescs, desc)
|
||||
speakerNames = append(speakerNames, s.Name)
|
||||
}
|
||||
|
||||
speakersDesc := "- " + strings.Join(speakerDescs, "\n- ")
|
||||
exampleSpeaker1 := speakerNames[0]
|
||||
exampleSpeaker2 := exampleSpeaker1
|
||||
if len(speakerNames) > 1 {
|
||||
exampleSpeaker2 = speakerNames[1]
|
||||
}
|
||||
|
||||
promptSegments := []string{
|
||||
"Please convert the following article into a podcast dialogue script.",
|
||||
"The speakers are:\n" + speakersDesc,
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast.EstimateMaximumDuration > 0 {
|
||||
wordsPerMinute := 200
|
||||
totalMinutes := int(r.Transform.ToPodcast.EstimateMaximumDuration.Minutes())
|
||||
estimatedWords := totalMinutes * wordsPerMinute
|
||||
promptSegments = append(promptSegments, fmt.Sprintf("The script should be approximately %d words to fit within a %d-minute duration. If the original content is not sufficient, the script can be shorter as appropriate.", estimatedWords, totalMinutes))
|
||||
}
|
||||
|
||||
if r.Transform.ToPodcast.TranscriptAdditionalPrompt != "" {
|
||||
promptSegments = append(promptSegments, "Additional instructions: "+r.Transform.ToPodcast.TranscriptAdditionalPrompt)
|
||||
}
|
||||
|
||||
promptSegments = append(promptSegments,
|
||||
"The output format MUST be a script where each line starts with the speaker's name followed by a colon and a space.",
|
||||
"Do NOT include any other text, explanations, or formatting before or after the script.",
|
||||
"Do NOT use background music in the script.",
|
||||
"Do NOT include any greetings or farewells (e.g., 'Hello everyone', 'Welcome to our show', 'Goodbye').",
|
||||
fmt.Sprintf("Example of the required format:\n%s: Today we are discussing the article's main points.\n%s: Let's start with the first one.", exampleSpeaker1, exampleSpeaker2),
|
||||
"Now, convert the article.",
|
||||
)
|
||||
|
||||
r.Transform.ToPodcast.transcriptPrompt = strings.Join(promptSegments, "\n\n")
|
||||
r.Transform.ToPodcast.speakersDesc = speakersDesc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,9 +233,10 @@ func (r *Rule) Validate() error { //nolint:cyclop,gocognit,funlen
|
||||
r.matchRE = re
|
||||
|
||||
// Action.
|
||||
switch r.Action {
|
||||
case "":
|
||||
if r.Action == "" {
|
||||
r.Action = ActionCreateOrUpdateLabel
|
||||
}
|
||||
switch r.Action {
|
||||
case ActionCreateOrUpdateLabel:
|
||||
if r.Label == "" {
|
||||
return errors.New("label is required for create or update label action")
|
||||
@@ -179,7 +253,7 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
r.If = c.If
|
||||
r.SourceLabel = c.SourceLabel
|
||||
r.SkipTooShortThreshold = c.SkipTooShortThreshold
|
||||
if c.Transform != nil {
|
||||
if c.Transform != nil { //nolint:nestif
|
||||
t := &Transform{}
|
||||
if c.Transform.ToText != nil {
|
||||
toText := &ToText{
|
||||
@@ -192,6 +266,25 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
}
|
||||
t.ToText = toText
|
||||
}
|
||||
if c.Transform.ToPodcast != nil {
|
||||
toPodcast := &ToPodcast{
|
||||
LLM: c.Transform.ToPodcast.LLM,
|
||||
EstimateMaximumDuration: time.Duration(c.Transform.ToPodcast.EstimateMaximumDuration),
|
||||
TranscriptAdditionalPrompt: c.Transform.ToPodcast.TranscriptAdditionalPrompt,
|
||||
TTSLLM: c.Transform.ToPodcast.TTSLLM,
|
||||
}
|
||||
if toPodcast.EstimateMaximumDuration == 0 {
|
||||
toPodcast.EstimateMaximumDuration = 3 * time.Minute
|
||||
}
|
||||
for _, s := range c.Transform.ToPodcast.Speakers {
|
||||
toPodcast.Speakers = append(toPodcast.Speakers, Speaker{
|
||||
Name: s.Name,
|
||||
Role: s.Role,
|
||||
Voice: s.Voice,
|
||||
})
|
||||
}
|
||||
t.ToPodcast = toPodcast
|
||||
}
|
||||
r.Transform = t
|
||||
}
|
||||
r.Match = c.Match
|
||||
@@ -203,7 +296,8 @@ func (r *Rule) From(c *config.RewriteRule) {
|
||||
}
|
||||
|
||||
type Transform struct {
|
||||
ToText *ToText
|
||||
ToText *ToText
|
||||
ToPodcast *ToPodcast
|
||||
}
|
||||
|
||||
type ToText struct {
|
||||
@@ -220,6 +314,24 @@ type ToText struct {
|
||||
promptRendered string
|
||||
}
|
||||
|
||||
// ToPodcast configures the transformation of source text into podcast audio.
// The exported fields are copied from the rule configuration; the unexported
// fields are derived from them by Rule.Validate.
type ToPodcast struct {
	// LLM is the name handed to the LLM factory to obtain the model that
	// writes the dialogue transcript.
	LLM string
	// EstimateMaximumDuration, when positive, is used to tell the transcript
	// model roughly how many words to produce (200 words per minute).
	// Rule.From substitutes 3 minutes when the configured value is zero.
	EstimateMaximumDuration time.Duration
	// TranscriptAdditionalPrompt, when non-empty, is appended to the
	// transcript prompt as additional instructions.
	TranscriptAdditionalPrompt string
	// TTSLLM is the name handed to the LLM factory to obtain the model that
	// synthesizes the WAV audio.
	TTSLLM string
	// Speakers lists the podcast speakers. Validation requires at least one
	// entry, each with a non-empty Name and Voice.
	Speakers []Speaker

	// transcriptPrompt is the full transcript prompt assembled during validation.
	transcriptPrompt string
	// speakersDesc is the "- name (role)" bullet list built during validation;
	// it is also prefixed to the generated transcript.
	speakersDesc string
	// speakers mirrors Speakers as llm.Speaker values for the WAV call.
	speakers []llm.Speaker
}
|
||||
|
||||
// Speaker describes a single podcast speaker.
//
// Name identifies the speaker in the script and must be non-empty. Role is an
// optional description that validation appends to the name in parentheses.
// Voice is the voice passed on to the audio model and must be non-empty.
type Speaker struct {
	Name, Role, Voice string
}
|
||||
|
||||
type ToTextType string
|
||||
|
||||
const (
|
||||
@@ -310,13 +422,9 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
|
||||
}
|
||||
|
||||
// Transform text if configured.
|
||||
text := sourceText
|
||||
if rule.Transform != nil && rule.Transform.ToText != nil {
|
||||
transformed, err := r.transformText(ctx, rule.Transform, sourceText)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transform text")
|
||||
}
|
||||
text = transformed
|
||||
text, err := r.transform(ctx, rule.Transform, sourceText)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "transform")
|
||||
}
|
||||
|
||||
// Check if text matches the rule.
|
||||
@@ -338,18 +446,34 @@ func (r *rewriter) Labels(ctx context.Context, labels model.Labels) (rewritten m
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) transform(ctx context.Context, transform *Transform, sourceText string) (string, error) {
|
||||
if transform == nil {
|
||||
return sourceText, nil
|
||||
}
|
||||
|
||||
if transform.ToText != nil {
|
||||
return r.transformText(ctx, transform.ToText, sourceText)
|
||||
}
|
||||
|
||||
if transform.ToPodcast != nil {
|
||||
return r.transformPodcast(ctx, transform.ToPodcast, sourceText)
|
||||
}
|
||||
|
||||
return sourceText, nil
|
||||
}
|
||||
|
||||
// transformText transforms text using configured LLM or by crawling a URL.
|
||||
func (r *rewriter) transformText(ctx context.Context, transform *Transform, text string) (string, error) {
|
||||
switch transform.ToText.Type {
|
||||
func (r *rewriter) transformText(ctx context.Context, toText *ToText, text string) (string, error) {
|
||||
switch toText.Type {
|
||||
case ToTextTypeCrawl:
|
||||
return r.transformTextCrawl(ctx, r.crawler, text)
|
||||
case ToTextTypeCrawlByJina:
|
||||
return r.transformTextCrawl(ctx, r.jinaCrawler, text)
|
||||
|
||||
case ToTextTypePrompt:
|
||||
return r.transformTextPrompt(ctx, transform, text)
|
||||
return r.transformTextPrompt(ctx, toText, text)
|
||||
default:
|
||||
return r.transformTextPrompt(ctx, transform, text)
|
||||
return r.transformTextPrompt(ctx, toText, text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,13 +487,13 @@ func (r *rewriter) transformTextCrawl(ctx context.Context, crawler crawl.Crawler
|
||||
}
|
||||
|
||||
// transformTextPrompt transforms text using configured LLM.
|
||||
func (r *rewriter) transformTextPrompt(ctx context.Context, transform *Transform, text string) (string, error) {
|
||||
func (r *rewriter) transformTextPrompt(ctx context.Context, toText *ToText, text string) (string, error) {
|
||||
// Get LLM instance.
|
||||
llm := r.Dependencies().LLMFactory.Get(transform.ToText.LLM)
|
||||
llm := r.Dependencies().LLMFactory.Get(toText.LLM)
|
||||
|
||||
// Call completion.
|
||||
result, err := llm.String(ctx, []string{
|
||||
transform.ToText.promptRendered,
|
||||
toText.promptRendered,
|
||||
text, // TODO: may place to first line to hit the model cache in different rewrite rules.
|
||||
})
|
||||
if err != nil {
|
||||
@@ -388,6 +512,71 @@ func (r *rewriter) transformTextHack(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
var audioKey = func(transcript, ext string) string {
|
||||
hash := hashutil.Sum64(transcript)
|
||||
file := strconv.FormatUint(hash, 10) + "." + ext
|
||||
|
||||
return "podcasts/" + file
|
||||
}
|
||||
|
||||
func (r *rewriter) transformPodcast(ctx context.Context, toPodcast *ToPodcast, sourceText string) (url string, err error) {
|
||||
transcript, err := r.generateTranscript(ctx, toPodcast, sourceText)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "generate podcast transcript")
|
||||
}
|
||||
|
||||
audioKey := audioKey(transcript, "wav")
|
||||
url, err = r.Dependencies().ObjectStorage.Get(ctx, audioKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
// May canceled at last time by reload, fast return.
|
||||
return url, nil
|
||||
case errors.Is(err, object.ErrNotFound):
|
||||
// Not found, generate new audio.
|
||||
default:
|
||||
return "", errors.Wrap(err, "get audio")
|
||||
}
|
||||
|
||||
audioStream, err := r.generateAudio(ctx, toPodcast, transcript)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "generate podcast audio")
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := audioStream.Close(); closeErr != nil {
|
||||
err = errors.Wrap(err, "close audio stream")
|
||||
}
|
||||
}()
|
||||
|
||||
url, err = r.Dependencies().ObjectStorage.Put(ctx, audioKey, audioStream, "audio/wav")
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "store podcast audio")
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) generateTranscript(ctx context.Context, toPodcast *ToPodcast, sourceText string) (string, error) {
|
||||
transcript, err := r.Dependencies().LLMFactory.Get(toPodcast.LLM).
|
||||
String(ctx, []string{toPodcast.transcriptPrompt, sourceText})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "llm completion")
|
||||
}
|
||||
|
||||
return toPodcast.speakersDesc +
|
||||
"\n\nFollowed by the actual dialogue script:\n" +
|
||||
transcript, nil
|
||||
}
|
||||
|
||||
func (r *rewriter) generateAudio(ctx context.Context, toPodcast *ToPodcast, transcript string) (io.ReadCloser, error) {
|
||||
audioStream, err := r.Dependencies().LLMFactory.Get(toPodcast.TTSLLM).
|
||||
WAV(ctx, transcript, toPodcast.speakers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "calling tts llm")
|
||||
}
|
||||
|
||||
return audioStream, nil
|
||||
}
|
||||
|
||||
type mockRewriter struct {
|
||||
component.Mock
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
"github.com/glidea/zenfeed/pkg/component"
|
||||
"github.com/glidea/zenfeed/pkg/llm"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/storage/object"
|
||||
"github.com/glidea/zenfeed/pkg/test"
|
||||
)
|
||||
|
||||
@@ -19,8 +22,9 @@ func TestLabels(t *testing.T) {
|
||||
RegisterTestingT(t)
|
||||
|
||||
type givenDetail struct {
|
||||
config *Config
|
||||
llmMock func(m *mock.Mock)
|
||||
config *Config
|
||||
llmMock func(m *mock.Mock)
|
||||
objectStorageMock func(m *mock.Mock)
|
||||
}
|
||||
type whenDetail struct {
|
||||
inputLabels model.Labels
|
||||
@@ -173,7 +177,7 @@ func TestLabels(t *testing.T) {
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform text: llm completion: LLM failed"),
|
||||
err: errors.New("transform: llm completion: LLM failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
@@ -220,22 +224,163 @@ func TestLabels(t *testing.T) {
|
||||
isErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Successfully generate podcast from content",
|
||||
Given: "a rule to convert content to a podcast with all dependencies mocked to succeed",
|
||||
When: "processing labels with content to be converted to a podcast",
|
||||
Then: "should return labels with a new podcast_url label",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{
|
||||
LLM: "mock-llm-transcript",
|
||||
TTSLLM: "mock-llm-tts",
|
||||
Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}},
|
||||
},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel,
|
||||
Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("This is the podcast script.", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.AnythingOfType("[]llm.Speaker")).
|
||||
Return(io.NopCloser(strings.NewReader("fake audio data")), nil).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Put", mock.Anything, mock.AnythingOfType("string"), mock.Anything, "audio/wav").
|
||||
Return("http://storage.example.com/podcast.wav", nil).Once()
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{
|
||||
inputLabels: model.Labels{
|
||||
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
|
||||
},
|
||||
},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: model.Labels{
|
||||
{Key: model.LabelContent, Value: "This is a long article to be converted into a podcast."},
|
||||
{Key: "podcast_url", Value: "http://storage.example.com/podcast.wav"},
|
||||
},
|
||||
isErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to transcription LLM error",
|
||||
Given: "a rule to convert content to a podcast, but the transcription LLM is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to transcription failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("", errors.New("transcript failed")).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: generate podcast transcript: llm completion: transcript failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to TTS LLM error",
|
||||
Given: "a rule to convert content to a podcast, but the TTS LLM is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to TTS failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("tts failed")).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: generate podcast audio: calling tts llm: tts failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Scenario: "Fail podcast generation due to object storage error",
|
||||
Given: "a rule to convert content to a podcast, but object storage is mocked to fail",
|
||||
When: "processing labels",
|
||||
Then: "should return an error related to storage failure",
|
||||
GivenDetail: givenDetail{
|
||||
config: &Config{
|
||||
{
|
||||
SourceLabel: model.LabelContent,
|
||||
Transform: &Transform{
|
||||
ToPodcast: &ToPodcast{LLM: "mock-llm-transcript", TTSLLM: "mock-llm-tts", Speakers: []Speaker{{Name: "narrator", Voice: "alloy"}}},
|
||||
},
|
||||
Action: ActionCreateOrUpdateLabel, Label: "podcast_url",
|
||||
},
|
||||
},
|
||||
llmMock: func(m *mock.Mock) {
|
||||
m.On("String", mock.Anything, mock.Anything).Return("script", nil).Once()
|
||||
m.On("WAV", mock.Anything, mock.Anything, mock.Anything).Return(io.NopCloser(strings.NewReader("fake audio")), nil).Once()
|
||||
},
|
||||
objectStorageMock: func(m *mock.Mock) {
|
||||
m.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return("", errors.New("storage failed")).Once()
|
||||
m.On("Get", mock.Anything, mock.AnythingOfType("string")).Return("", object.ErrNotFound).Once()
|
||||
},
|
||||
},
|
||||
WhenDetail: whenDetail{inputLabels: model.Labels{{Key: model.LabelContent, Value: "article"}}},
|
||||
ThenExpected: thenExpected{
|
||||
outputLabels: nil,
|
||||
err: errors.New("transform: store podcast audio: storage failed"),
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.Scenario, func(t *testing.T) {
|
||||
// Given.
|
||||
var mockLLMFactory llm.Factory
|
||||
var mockInstance *mock.Mock // Store the mock instance for assertion
|
||||
|
||||
// Create mock factory and capture the mock.Mock instance.
|
||||
mockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockInstance = m // Capture the mock instance.
|
||||
var mockLLMInstance *mock.Mock
|
||||
llmMockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockLLMInstance = m
|
||||
if tt.GivenDetail.llmMock != nil {
|
||||
tt.GivenDetail.llmMock(m)
|
||||
}
|
||||
})
|
||||
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, mockOption) // Use the factory directly with the option
|
||||
mockLLMFactory, err := llm.NewFactory("", nil, llm.FactoryDependencies{}, llmMockOption)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
var mockObjectStorage object.Storage
|
||||
var mockObjectStorageInstance *mock.Mock
|
||||
objectStorageMockOption := component.MockOption(func(m *mock.Mock) {
|
||||
mockObjectStorageInstance = m
|
||||
if tt.GivenDetail.objectStorageMock != nil {
|
||||
tt.GivenDetail.objectStorageMock(m)
|
||||
}
|
||||
})
|
||||
mockObjectStorageFactory := object.NewFactory(objectStorageMockOption)
|
||||
mockObjectStorage, err = mockObjectStorageFactory.New("test", nil, object.Dependencies{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Manually validate config to compile regex and render templates.
|
||||
@@ -252,7 +397,8 @@ func TestLabels(t *testing.T) {
|
||||
Instance: "test",
|
||||
Config: tt.GivenDetail.config,
|
||||
Dependencies: Dependencies{
|
||||
LLMFactory: mockLLMFactory, // Pass the mock factory
|
||||
LLMFactory: mockLLMFactory, // Pass the mock factory
|
||||
ObjectStorage: mockObjectStorage,
|
||||
},
|
||||
}),
|
||||
}
|
||||
@@ -280,10 +426,12 @@ func TestLabels(t *testing.T) {
|
||||
Expect(outputLabels).To(Equal(tt.ThenExpected.outputLabels))
|
||||
}
|
||||
|
||||
// Verify LLM calls if stubs were provided.
|
||||
if tt.GivenDetail.llmMock != nil && mockInstance != nil {
|
||||
// Assert expectations on the captured mock instance.
|
||||
mockInstance.AssertExpectations(t)
|
||||
// Verify mock calls if stubs were provided.
|
||||
if tt.GivenDetail.llmMock != nil && mockLLMInstance != nil {
|
||||
mockLLMInstance.AssertExpectations(t)
|
||||
}
|
||||
if tt.GivenDetail.objectStorageMock != nil && mockObjectStorageInstance != nil {
|
||||
mockObjectStorageInstance.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user