Files
zenfeed/pkg/llm/gemini.go
2025-07-09 13:58:41 +08:00

249 lines
7.1 KiB
Go

// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"path/filepath"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/wav"
)
type gemini struct {
*component.Base[Config, struct{}]
text
hc *http.Client
embeddingSpliter embeddingSpliter
}
func newGemini(c *Config) LLM {
config := oai.DefaultConfig(c.APIKey)
config.BaseURL = filepath.Join(c.Endpoint, "openai") // OpenAI compatible endpoint.
client := oai.NewClientWithConfig(config)
embeddingSpliter := newEmbeddingSpliter(1536, 64)
base := component.New(&component.BaseConfig[Config, struct{}]{
Name: "LLM/gemini",
Instance: c.Name,
Config: c,
})
return &gemini{
Base: base,
text: &openaiText{
Base: base,
client: client,
},
hc: &http.Client{},
embeddingSpliter: embeddingSpliter,
}
}
func (g *gemini) WAV(ctx context.Context, text string, speakers []Speaker) (r io.ReadCloser, err error) {
ctx = telemetry.StartWith(ctx, append(g.TelemetryLabels(), telemetrymodel.KeyOperation, "WAV")...)
defer func() { telemetry.End(ctx, err) }()
if g.Config().TTSModel == "" {
return nil, errors.New("tts model is not set")
}
reqPayload, err := buildWAVRequestPayload(text, speakers)
if err != nil {
return nil, errors.Wrap(err, "build wav request payload")
}
pcmData, err := g.doWAVRequest(ctx, reqPayload)
if err != nil {
return nil, errors.Wrap(err, "do wav request")
}
return streamWAV(pcmData), nil
}
func (g *gemini) doWAVRequest(ctx context.Context, reqPayload *geminiRequest) ([]byte, error) {
config := g.Config()
body, err := json.Marshal(reqPayload)
if err != nil {
return nil, errors.Wrap(err, "marshal tts request")
}
url := config.Endpoint + "/models/" + config.TTSModel + ":generateContent"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "new tts request")
}
req.Header.Set("x-goog-api-key", config.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := g.hc.Do(req)
if err != nil {
return nil, errors.Wrap(err, "do tts request")
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
errMsg, _ := io.ReadAll(resp.Body)
return nil, errors.Errorf("tts request failed with status %d: %s", resp.StatusCode, string(errMsg))
}
var ttsResp geminiResponse
if err := json.NewDecoder(resp.Body).Decode(&ttsResp); err != nil {
return nil, errors.Wrap(err, "decode tts response")
}
if len(ttsResp.Candidates) == 0 || len(ttsResp.Candidates[0].Content.Parts) == 0 || ttsResp.Candidates[0].Content.Parts[0].InlineData == nil {
return nil, errors.New("no audio data in tts response")
}
audioDataB64 := ttsResp.Candidates[0].Content.Parts[0].InlineData.Data
pcmData, err := base64.StdEncoding.DecodeString(audioDataB64)
if err != nil {
return nil, errors.Wrap(err, "decode base64")
}
return pcmData, nil
}
func buildWAVRequestPayload(text string, speakers []Speaker) (*geminiRequest, error) {
reqPayload := geminiRequest{
Contents: []*geminiRequestContent{{Parts: []*geminiRequestPart{{Text: text}}}},
Config: &geminiRequestConfig{
ResponseModalities: []string{"AUDIO"},
SpeechConfig: &geminiRequestSpeechConfig{},
},
}
switch len(speakers) {
case 0:
return nil, errors.New("no speakers")
case 1:
reqPayload.Config.SpeechConfig.VoiceConfig = &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: speakers[0].Voice},
}
default:
multiSpeakerConfig := &geminiRequestMultiSpeakerVoiceConfig{}
for _, s := range speakers {
multiSpeakerConfig.SpeakerVoiceConfigs = append(multiSpeakerConfig.SpeakerVoiceConfigs, &geminiRequestSpeakerVoiceConfig{
Speaker: s.Name,
VoiceConfig: &geminiRequestVoiceConfig{
PrebuiltVoiceConfig: &geminiRequestPrebuiltVoiceConfig{VoiceName: s.Voice},
},
})
}
reqPayload.Config.SpeechConfig.MultiSpeakerVoiceConfig = multiSpeakerConfig
}
return &reqPayload, nil
}
func streamWAV(pcmData []byte) io.ReadCloser {
pipeReader, pipeWriter := io.Pipe()
go func() {
defer func() { _ = pipeWriter.Close() }()
if err := wav.WriteHeader(pipeWriter, geminiWavHeader, uint32(len(pcmData))); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write wav header"))
return
}
if _, err := io.Copy(pipeWriter, bytes.NewReader(pcmData)); err != nil {
pipeWriter.CloseWithError(errors.Wrap(err, "write pcm data"))
return
}
}()
return pipeReader
}
var geminiWavHeader = &wav.Header{
SampleRate: 24000,
BitDepth: 16,
NumChannels: 1,
}
type geminiRequest struct {
Contents []*geminiRequestContent `json:"contents"`
Config *geminiRequestConfig `json:"generationConfig"`
}
type geminiRequestContent struct {
Parts []*geminiRequestPart `json:"parts"`
}
type geminiRequestPart struct {
Text string `json:"text"`
}
type geminiRequestConfig struct {
ResponseModalities []string `json:"responseModalities"`
SpeechConfig *geminiRequestSpeechConfig `json:"speechConfig"`
}
type geminiRequestSpeechConfig struct {
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
MultiSpeakerVoiceConfig *geminiRequestMultiSpeakerVoiceConfig `json:"multiSpeakerVoiceConfig,omitempty"`
}
type geminiRequestVoiceConfig struct {
PrebuiltVoiceConfig *geminiRequestPrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"`
}
type geminiRequestPrebuiltVoiceConfig struct {
VoiceName string `json:"voiceName,omitempty"`
}
type geminiRequestMultiSpeakerVoiceConfig struct {
SpeakerVoiceConfigs []*geminiRequestSpeakerVoiceConfig `json:"speakerVoiceConfigs,omitempty"`
}
type geminiRequestSpeakerVoiceConfig struct {
Speaker string `json:"speaker,omitempty"`
VoiceConfig *geminiRequestVoiceConfig `json:"voiceConfig,omitempty"`
}
type geminiResponse struct {
Candidates []*geminiResponseCandidate `json:"candidates"`
}
type geminiResponseCandidate struct {
Content *geminiResponseContent `json:"content"`
}
type geminiResponseContent struct {
Parts []*geminiResponsePart `json:"parts"`
}
type geminiResponsePart struct {
InlineData *geminiResponseInlineData `json:"inlineData"`
}
type geminiResponseInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // Base64 encoded.
}