init
This commit is contained in:
131
pkg/llm/embedding_spliter.go
Normal file
131
pkg/llm/embedding_spliter.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
)
|
||||
|
||||
// embeddingSpliter splits a label set so that every label value fits within
// an embedding model's per-input token budget.
type embeddingSpliter interface {
	// Split returns one or more label sets derived from ls. Label values
	// whose estimated token count exceeds the limit are chunked (with
	// overlap) and each chunk is combined with the remaining short labels.
	Split(ls model.Labels) ([]model.Labels, error)
}
|
||||
|
||||
func newEmbeddingSpliter(maxLabelValueTokens, overlapTokens int) embeddingSpliter {
|
||||
if maxLabelValueTokens <= 0 {
|
||||
maxLabelValueTokens = 1024
|
||||
}
|
||||
if overlapTokens <= 0 {
|
||||
overlapTokens = 64
|
||||
}
|
||||
if overlapTokens > maxLabelValueTokens {
|
||||
overlapTokens = maxLabelValueTokens / 10
|
||||
}
|
||||
|
||||
return &embeddingSpliterImpl{maxLabelValueTokens: maxLabelValueTokens, overlapTokens: overlapTokens}
|
||||
}
|
||||
|
||||
// embeddingSpliterImpl implements embeddingSpliter using a character-based
// token estimation heuristic (no real tokenizer is involved).
type embeddingSpliterImpl struct {
	maxLabelValueTokens int // Max estimated tokens allowed per label-value chunk.
	overlapTokens       int // Estimated tokens shared between consecutive chunks.
}
|
||||
|
||||
func (e *embeddingSpliterImpl) Split(ls model.Labels) ([]model.Labels, error) {
|
||||
var (
|
||||
short = make(model.Labels, 0, len(ls))
|
||||
long = make(model.Labels, 0, 1)
|
||||
longTokens = make([]int, 0, 1)
|
||||
)
|
||||
for _, l := range ls {
|
||||
tokens := e.estimateTokens(l.Value)
|
||||
if tokens <= e.maxLabelValueTokens {
|
||||
short = append(short, l)
|
||||
} else {
|
||||
long = append(long, l)
|
||||
longTokens = append(longTokens, tokens)
|
||||
}
|
||||
}
|
||||
if len(long) == 0 {
|
||||
return []model.Labels{ls}, nil
|
||||
}
|
||||
|
||||
var (
|
||||
common = short
|
||||
splits = make([]model.Labels, 0, len(long)*2)
|
||||
)
|
||||
for i := range long {
|
||||
parts := e.split(long[i].Value, longTokens[i])
|
||||
for _, p := range parts {
|
||||
com := slices.Clone(common)
|
||||
s := append(com, model.Label{Key: long[i].Key, Value: p})
|
||||
splits = append(splits, s)
|
||||
}
|
||||
}
|
||||
|
||||
return splits, nil
|
||||
}
|
||||
|
||||
func (e *embeddingSpliterImpl) split(value string, tokens int) []string {
|
||||
var (
|
||||
results = make([]string, 0)
|
||||
chars = []rune(value)
|
||||
)
|
||||
|
||||
// Estimate the number of characters per token
|
||||
avgCharsPerToken := float64(len(chars)) / float64(tokens)
|
||||
// Calculate the approximate number of characters corresponding to maxLabelValueTokens tokens.
|
||||
charsPerSegment := int(float64(e.maxLabelValueTokens) * avgCharsPerToken)
|
||||
|
||||
// The number of characters corresponding to a fixed overlap of 64 tokens.
|
||||
overlapChars := int(float64(e.overlapTokens) * avgCharsPerToken)
|
||||
|
||||
// Actual step length = segment length - overlap.
|
||||
charStep := charsPerSegment - overlapChars
|
||||
|
||||
for start := 0; start < len(chars); {
|
||||
end := min(start+charsPerSegment, len(chars))
|
||||
|
||||
segment := string(chars[start:end])
|
||||
results = append(results, segment)
|
||||
|
||||
if end == len(chars) {
|
||||
break
|
||||
}
|
||||
start += charStep
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (e *embeddingSpliterImpl) estimateTokens(text string) int {
|
||||
latinChars := 0
|
||||
otherChars := 0
|
||||
|
||||
for _, r := range text {
|
||||
if r <= 127 {
|
||||
latinChars++
|
||||
} else {
|
||||
otherChars++
|
||||
}
|
||||
}
|
||||
|
||||
// Rough estimate:
|
||||
// - English and punctuation: about 0.25 tokens/char (4 characters ≈ 1 token).
|
||||
// - Chinese and other non-Latin characters: about 1.5 tokens/char.
|
||||
return int(math.Round(float64(latinChars)/4 + float64(otherChars)*3/2))
|
||||
}
|
||||
158
pkg/llm/embedding_spliter_test.go
Normal file
158
pkg/llm/embedding_spliter_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/test"
|
||||
)
|
||||
|
||||
// TestEmbeddingSpliter_Split covers the three interesting paths of Split:
// passthrough when nothing exceeds the limit, chunking a long Latin value,
// and chunking a long non-Latin (Chinese) value. The expected chunk contents
// are hard-coded against the character-based token heuristic (ASCII ≈ 1/4
// token per char, non-Latin ≈ 3/2 tokens per char), so changing that
// heuristic will require updating these fixtures.
func TestEmbeddingSpliter_Split(t *testing.T) {
	RegisterTestingT(t)

	// givenDetail holds the constructor arguments for newEmbeddingSpliter.
	type givenDetail struct {
		maxLabelValueTokens int
		overlapTokens       int
	}
	// whenDetail holds the input passed to Split.
	type whenDetail struct {
		labels model.Labels
	}
	// thenExpected holds the expected splits, or an expected error substring.
	type thenExpected struct {
		splits []model.Labels
		err    string
	}

	tests := []test.Case[givenDetail, whenDetail, thenExpected]{
		{
			Scenario: "Split labels with all short values",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with all values under token limit",
			Then:     "should return original labels as single split",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 1024,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "description", Value: "Short description"},
				},
			},
			ThenExpected: thenExpected{
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "description", Value: "Short description"},
					},
				},
			},
		},
		{
			Scenario: "Split labels with one long value",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with one value exceeding token limit",
			Then:     "should split the long value and combine with common labels",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 10, // Small limit to force splitting.
				overlapTokens:       1,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "content", Value: "This is a long content that exceeds the token limit and needs to be split into multiple parts"},
				},
			},
			ThenExpected: thenExpected{
				// Each split repeats the short "title" label; consecutive
				// content chunks overlap ("the" / "t") per overlapTokens.
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "This is a long content that exceeds the "},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "the token limit and needs to be split in"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "t into multiple parts"},
					},
				},
			},
		},
		{
			Scenario: "Handle non-Latin characters",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with non-Latin characters",
			Then:     "should correctly estimate tokens and split accordingly",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 10, // Small limit to force splitting.
				overlapTokens:       2,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "content", Value: "中文内容需要被分割因为它超过了令牌限制"}, // Chinese content that needs to be split.
				},
			},
			ThenExpected: thenExpected{
				// Non-Latin chars count ~1.5 tokens each, so chunks are much
				// shorter (6 runes) and overlap by one rune.
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "中文内容需要"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "要被分割因为"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "为它超过了令"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "令牌限制"},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.Scenario, func(t *testing.T) {
			// Given.
			spliter := newEmbeddingSpliter(tt.GivenDetail.maxLabelValueTokens, tt.GivenDetail.overlapTokens)

			// When.
			splits, err := spliter.Split(tt.WhenDetail.labels)

			// Then.
			if tt.ThenExpected.err != "" {
				Expect(err).NotTo(BeNil())
				Expect(err.Error()).To(ContainSubstring(tt.ThenExpected.err))
			} else {
				Expect(err).To(BeNil())
				Expect(len(splits)).To(Equal(len(tt.ThenExpected.splits)))

				for i, expectedSplit := range tt.ThenExpected.splits {
					Expect(splits[i]).To(Equal(expectedSplit))
				}
			}
		})
	}
}
|
||||
420
pkg/llm/llm.go
Normal file
420
pkg/llm/llm.go
Normal file
@@ -0,0 +1,420 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/component"
|
||||
"github.com/glidea/zenfeed/pkg/config"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/storage/kv"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry/log"
|
||||
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
|
||||
"github.com/glidea/zenfeed/pkg/util/hash"
|
||||
)
|
||||
|
||||
// --- Interface code block ---
|
||||
// LLM is a chat/embedding client for a large language model provider.
type LLM interface {
	component.Component
	// String sends messages as user-role chat messages and returns the
	// model's text completion.
	String(ctx context.Context, messages []string) (string, error)
	// EmbeddingLabels embeds a label set; oversized label values are split,
	// producing one vector per resulting split.
	EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error)
	// Embedding returns the embedding vector for text.
	Embedding(ctx context.Context, text string) ([]float32, error)
}
|
||||
|
||||
// Config describes one named LLM backend.
type Config struct {
	Name    string // Unique name used to look the instance up via Factory.Get.
	Default bool   // Marks this instance as the factory default.

	Provider ProviderType // Backend provider; empty defaults to OpenAI (see Validate).
	Endpoint string       // API base URL; empty defaults per provider (see defaultEndpoints).
	APIKey   string       // Required credential for the provider API.

	// Chat model and embedding model names; at least one must be set.
	Model, EmbeddingModel string
	// Sampling temperature in [0, 2].
	Temperature float32
}
|
||||
|
||||
// ProviderType identifies a supported LLM API provider.
type ProviderType string

// Supported providers. All of them expose OpenAI-compatible APIs.
const (
	ProviderTypeOpenAI      ProviderType = "openai"
	ProviderTypeOpenRouter  ProviderType = "openrouter"
	ProviderTypeDeepSeek    ProviderType = "deepseek"
	ProviderTypeGemini      ProviderType = "gemini"
	ProviderTypeVolc        ProviderType = "volc" // Rename MaaS to ARK. 😄
	ProviderTypeSiliconFlow ProviderType = "siliconflow"
)
|
||||
|
||||
// defaultEndpoints maps each provider to its OpenAI-compatible API base URL,
// used when Config.Endpoint is left empty.
var defaultEndpoints = map[ProviderType]string{
	ProviderTypeOpenAI:      "https://api.openai.com/v1",
	ProviderTypeOpenRouter:  "https://openrouter.ai/api/v1",
	ProviderTypeDeepSeek:    "https://api.deepseek.com/v1",
	ProviderTypeGemini:      "https://generativelanguage.googleapis.com/v1beta/openai",
	ProviderTypeVolc:        "https://ark.cn-beijing.volces.com/api/v3",
	ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1",
}
|
||||
|
||||
func (c *Config) Validate() error { //nolint:cyclop
|
||||
if c.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
|
||||
switch c.Provider {
|
||||
case "":
|
||||
c.Provider = ProviderTypeOpenAI
|
||||
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek,
|
||||
ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow:
|
||||
default:
|
||||
return errors.Errorf("invalid provider: %s", c.Provider)
|
||||
}
|
||||
|
||||
if c.Endpoint == "" {
|
||||
c.Endpoint = defaultEndpoints[c.Provider]
|
||||
}
|
||||
if c.APIKey == "" {
|
||||
return errors.New("api key is required")
|
||||
}
|
||||
if c.Model == "" && c.EmbeddingModel == "" {
|
||||
return errors.New("model or embedding model is required")
|
||||
}
|
||||
if c.Temperature < 0 || c.Temperature > 2 {
|
||||
return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prometheus counters tracking token usage per component instance and
// operation (String / Embedding).
var (
	// promptTokens counts tokens sent in prompts/inputs.
	promptTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "prompt_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
	// completionTokens counts tokens returned in completions.
	completionTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "completion_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
	// totalTokens counts prompt + completion tokens as reported by the API.
	totalTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "total_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
)
|
||||
|
||||
// --- Factory code block ---
|
||||
// FactoryConfig holds every configured LLM backend.
type FactoryConfig struct {
	LLMs []Config

	// defaultLLM is the resolved name of the default instance;
	// populated by Validate.
	defaultLLM string
}
|
||||
|
||||
// Validate validates every LLM config and resolves the default instance.
// A single configured LLM is forced to be the default; with multiple LLMs,
// at most one may be marked default.
//
// NOTE(review): with multiple LLMs and zero marked default, this passes
// validation and leaves defaultLLM empty, so Get("") will return nil —
// confirm that is intended.
func (c *FactoryConfig) Validate() error {
	if len(c.LLMs) == 0 {
		return errors.New("no llm config")
	}

	// Validate (and default-fill) each entry in place.
	for i := range c.LLMs {
		if err := (&c.LLMs[i]).Validate(); err != nil {
			return errors.Wrapf(err, "validate llm config %s", c.LLMs[i].Name)
		}
	}

	// A lone config is implicitly the default.
	if len(c.LLMs) == 1 {
		c.LLMs[0].Default = true
		c.defaultLLM = c.LLMs[0].Name

		return nil
	}

	defaults := 0
	for _, llm := range c.LLMs {
		if llm.Default {
			c.defaultLLM = llm.Name
			defaults++
		}
	}
	if defaults > 1 {
		return errors.New("multiple llm configs are default")
	}

	return nil
}
|
||||
|
||||
func (c *FactoryConfig) From(app *config.App) {
|
||||
for _, llm := range app.LLMs {
|
||||
c.LLMs = append(c.LLMs, Config{
|
||||
Name: llm.Name,
|
||||
Default: llm.Default,
|
||||
Provider: ProviderType(llm.Provider),
|
||||
Endpoint: llm.Endpoint,
|
||||
APIKey: llm.APIKey,
|
||||
Model: llm.Model,
|
||||
EmbeddingModel: llm.EmbeddingModel,
|
||||
Temperature: llm.Temperature,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FactoryDependencies are the external collaborators the factory injects
// into the LLM instances it builds.
type FactoryDependencies struct {
	// KVStorage backs the completion cache (see cached).
	KVStorage kv.Storage
}
|
||||
|
||||
// Factory is a factory for creating LLM instances.
// If name is empty or not found, Get returns the default instance.
type Factory interface {
	component.Component
	config.Watcher
	// Get returns the LLM named name, or the default when name is empty
	// or unknown.
	Get(name string) LLM
}
|
||||
|
||||
// NewFactory builds a Factory from the app config.
// When mockOn is non-empty it returns a mock factory instead: its Get
// produces a fresh mockLLM per call, configured with the given options,
// and Reload is a no-op returning nil.
func NewFactory(
	instance string,
	app *config.App,
	dependencies FactoryDependencies,
	mockOn ...component.MockOption,
) (Factory, error) {
	if len(mockOn) > 0 {
		mf := &mockFactory{}
		getCall := mf.On("Get", mock.Anything)
		// Re-arm the return value on every Get call so each caller receives
		// a newly configured mockLLM.
		getCall.Run(func(args mock.Arguments) {
			m := &mockLLM{}
			component.MockOptions(mockOn).Apply(&m.Mock)
			getCall.Return(m, nil)
		})
		mf.On("Reload", mock.Anything).Return(nil)

		return mf, nil
	}

	// Build and validate the real config before constructing anything.
	config := &FactoryConfig{}
	config.From(app)
	if err := config.Validate(); err != nil {
		return nil, errors.Wrap(err, "validate config")
	}
	f := &factory{
		Base: component.New(&component.BaseConfig[FactoryConfig, FactoryDependencies]{
			Name:         "LLMFactory",
			Instance:     instance,
			Config:       config,
			Dependencies: dependencies,
		}),
		llms: make(map[string]LLM),
	}
	// Eagerly create all configured LLMs (and resolve the default).
	f.initLLMs()

	return f, nil
}
|
||||
|
||||
// factory is the real Factory implementation.
type factory struct {
	*component.Base[FactoryConfig, FactoryDependencies]

	// defaultLLM may be nil when no default is configured.
	defaultLLM LLM
	// llms maps config name -> instance; mu guards llms, defaultLLM,
	// and config swaps during Reload.
	llms map[string]LLM
	mu   sync.Mutex
}
|
||||
|
||||
// Run starts every pre-created LLM (waiting up to 10s each for readiness),
// marks the factory ready, then blocks until the factory's context is done.
//
// NOTE(review): f.llms is read here without holding f.mu — presumably Run is
// invoked before Reload/Get can race with it; confirm the component
// lifecycle guarantees that ordering.
func (f *factory) Run() error {
	for _, llm := range f.llms {
		if err := component.RunUntilReady(f.Context(), llm, 10*time.Second); err != nil {
			return errors.Wrapf(err, "run llm %s", llm.Name())
		}
	}
	f.MarkReady()
	<-f.Context().Done()

	return nil
}
|
||||
|
||||
func (f *factory) Close() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
for _, llm := range f.llms {
|
||||
_ = llm.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload applies a new app config: it validates the new config, and if it
// differs from the current one, swaps it in, closes all existing LLMs, and
// recreates them. The order (set config → close → recreate) matters:
// initLLMs reads the freshly-set config.
func (f *factory) Reload(app *config.App) error {
	newConfig := &FactoryConfig{}
	newConfig.From(app)
	if err := newConfig.Validate(); err != nil {
		return errors.Wrap(err, "validate config")
	}
	if reflect.DeepEqual(f.Config(), newConfig) {
		log.Debug(f.Context(), "no changes in llm config")

		return nil
	}

	// Reload the LLMs.
	f.mu.Lock()
	defer f.mu.Unlock()
	f.SetConfig(newConfig)

	// Close the old LLMs.
	for _, llm := range f.llms {
		_ = llm.Close()
	}

	// Recreate the LLMs.
	f.initLLMs()

	return nil
}
|
||||
|
||||
func (f *factory) Get(name string) LLM {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if name == "" {
|
||||
return f.defaultLLM
|
||||
}
|
||||
|
||||
for _, llmC := range f.Config().LLMs {
|
||||
if llmC.Name != name {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.llms[name] == nil {
|
||||
llm := f.new(&llmC)
|
||||
f.llms[name] = llm
|
||||
}
|
||||
|
||||
return f.llms[name]
|
||||
}
|
||||
|
||||
return f.defaultLLM
|
||||
}
|
||||
|
||||
func (f *factory) new(c *Config) LLM {
|
||||
switch c.Provider {
|
||||
case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek, ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow: //nolint:lll
|
||||
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
|
||||
default:
|
||||
return newCached(newOpenAI(c), f.Dependencies().KVStorage)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *factory) initLLMs() {
|
||||
var (
|
||||
config = f.Config()
|
||||
llms = make(map[string]LLM, len(config.LLMs))
|
||||
defaultLLM LLM
|
||||
)
|
||||
for _, llmC := range config.LLMs {
|
||||
llm := f.new(&llmC)
|
||||
llms[llmC.Name] = llm
|
||||
|
||||
if llmC.Name == config.defaultLLM {
|
||||
defaultLLM = llm
|
||||
}
|
||||
}
|
||||
f.llms = llms
|
||||
f.defaultLLM = defaultLLM
|
||||
}
|
||||
|
||||
// mockFactory is a testify-backed Factory used when NewFactory is called
// with mock options.
type mockFactory struct {
	component.Mock
}
|
||||
|
||||
func (m *mockFactory) Get(name string) LLM {
|
||||
args := m.Called(name)
|
||||
|
||||
return args.Get(0).(LLM)
|
||||
}
|
||||
|
||||
func (m *mockFactory) Reload(app *config.App) error {
|
||||
args := m.Called(app)
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// --- Implementation code block ---
|
||||
// cached decorates an LLM with a KV-backed cache for String completions.
// Embedding calls are delegated uncached via the embedded LLM.
type cached struct {
	LLM
	kvStorage kv.Storage
}
|
||||
|
||||
func newCached(llm LLM, kvStorage kv.Storage) LLM {
|
||||
return &cached{
|
||||
LLM: llm,
|
||||
kvStorage: kvStorage,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cached) String(ctx context.Context, messages []string) (string, error) {
|
||||
key := hash.Sum64s(messages)
|
||||
keyStr := strconv.FormatUint(key, 10)
|
||||
|
||||
value, err := c.kvStorage.Get(ctx, keyStr)
|
||||
switch {
|
||||
case err == nil:
|
||||
return value, nil
|
||||
case errors.Is(err, kv.ErrNotFound):
|
||||
break
|
||||
default:
|
||||
return "", errors.Wrap(err, "get from kv storage")
|
||||
}
|
||||
|
||||
value, err = c.LLM.String(ctx, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err = c.kvStorage.Set(ctx, keyStr, value, 65*time.Minute); err != nil {
|
||||
log.Error(ctx, err, "set to kv storage")
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// mockLLM is a testify-backed LLM used in tests via the mock factory.
type mockLLM struct {
	component.Mock
}
|
||||
|
||||
func (m *mockLLM) String(ctx context.Context, messages []string) (string, error) {
|
||||
args := m.Called(ctx, messages)
|
||||
|
||||
return args.Get(0).(string), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) {
|
||||
args := m.Called(ctx, labels)
|
||||
|
||||
return args.Get(0).([][]float32), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) {
|
||||
args := m.Called(ctx, text)
|
||||
|
||||
return args.Get(0).([]float32), args.Error(1)
|
||||
}
|
||||
146
pkg/llm/openai.go
Normal file
146
pkg/llm/openai.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (C) 2025 wangyusong
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
oai "github.com/sashabaranov/go-openai"
|
||||
|
||||
"github.com/glidea/zenfeed/pkg/component"
|
||||
"github.com/glidea/zenfeed/pkg/model"
|
||||
"github.com/glidea/zenfeed/pkg/telemetry"
|
||||
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
|
||||
runtimeutil "github.com/glidea/zenfeed/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// openai is an LLM backed by any OpenAI-compatible chat/embedding API.
type openai struct {
	*component.Base[Config, struct{}]

	client           *oai.Client      // OpenAI-compatible HTTP client.
	embeddingSpliter embeddingSpliter // Splits oversized label values before embedding.
}
|
||||
|
||||
func newOpenAI(c *Config) LLM {
|
||||
config := oai.DefaultConfig(c.APIKey)
|
||||
config.BaseURL = c.Endpoint
|
||||
client := oai.NewClientWithConfig(config)
|
||||
embeddingSpliter := newEmbeddingSpliter(2048, 64)
|
||||
|
||||
return &openai{
|
||||
Base: component.New(&component.BaseConfig[Config, struct{}]{
|
||||
Name: "LLM/openai",
|
||||
Instance: c.Name,
|
||||
Config: c,
|
||||
}),
|
||||
client: client,
|
||||
embeddingSpliter: embeddingSpliter,
|
||||
}
|
||||
}
|
||||
|
||||
// String sends messages as user-role chat messages and returns the first
// completion choice's content. Token usage reported by the API is recorded
// in the prompt/completion/total Prometheus counters.
func (o *openai) String(ctx context.Context, messages []string) (value string, err error) {
	// Named result err lets the deferred telemetry.End observe the outcome.
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.Model == "" {
		return "", errors.New("model is not set")
	}
	// Every input string becomes a separate user-role message.
	msg := make([]oai.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		msg = append(msg, oai.ChatCompletionMessage{
			Role:    oai.ChatMessageRoleUser,
			Content: m,
		})
	}

	req := oai.ChatCompletionRequest{
		Model:       config.Model,
		Messages:    msg,
		Temperature: config.Temperature,
	}

	resp, err := o.client.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", errors.Wrap(err, "create chat completion")
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("no completion choices returned")
	}

	// Record token usage per (component, instance, operation).
	lvs := []string{o.Name(), o.Instance(), "String"}
	promptTokens.WithLabelValues(lvs...).Add(float64(resp.Usage.PromptTokens))
	completionTokens.WithLabelValues(lvs...).Add(float64(resp.Usage.CompletionTokens))
	totalTokens.WithLabelValues(lvs...).Add(float64(resp.Usage.TotalTokens))

	return resp.Choices[0].Message.Content, nil
}
|
||||
|
||||
// EmbeddingLabels splits labels into token-bounded label sets, JSON-encodes
// each set, and embeds it; one vector is returned per split, in order.
func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
	// Named result err lets the deferred telemetry.End observe the outcome.
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.EmbeddingModel == "" {
		return nil, errors.New("embedding model is not set")
	}
	splits, err := o.embeddingSpliter.Split(labels)
	if err != nil {
		return nil, errors.Wrap(err, "split embedding")
	}

	vecs := make([][]float32, 0, len(splits))
	for _, split := range splits {
		// Marshal of model.Labels is assumed infallible here (Must1 panics
		// otherwise).
		text := runtimeutil.Must1(json.Marshal(split))
		vec, err := o.Embedding(ctx, string(text))
		if err != nil {
			return nil, errors.Wrap(err, "embedding")
		}
		vecs = append(vecs, vec)
	}

	return vecs, nil
}
|
||||
|
||||
// Embedding returns the embedding vector for s, using the configured
// embedding model. Token usage reported by the API is recorded in the
// Prometheus counters.
func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) {
	// Named result err lets the deferred telemetry.End observe the outcome.
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.EmbeddingModel == "" {
		return nil, errors.New("embedding model is not set")
	}
	vec, err := o.client.CreateEmbeddings(ctx, oai.EmbeddingRequest{
		Input:          []string{s},
		Model:          oai.EmbeddingModel(config.EmbeddingModel),
		EncodingFormat: oai.EmbeddingEncodingFormatFloat,
	})
	if err != nil {
		return nil, errors.Wrap(err, "create embeddings")
	}
	if len(vec.Data) == 0 {
		return nil, errors.New("no embedding data returned")
	}

	// Record token usage; CompletionTokens is typically zero for embeddings.
	lvs := []string{o.Name(), o.Instance(), "Embedding"}
	promptTokens.WithLabelValues(lvs...).Add(float64(vec.Usage.PromptTokens))
	completionTokens.WithLabelValues(lvs...).Add(float64(vec.Usage.CompletionTokens))
	totalTokens.WithLabelValues(lvs...).Add(float64(vec.Usage.TotalTokens))

	return vec.Data[0].Embedding, nil
}
|
||||
Reference in New Issue
Block a user