This commit is contained in:
glidea
2025-04-19 15:50:26 +08:00
commit 8b33df8a05
109 changed files with 24407 additions and 0 deletions

View File

@@ -0,0 +1,131 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"math"
"slices"
"github.com/glidea/zenfeed/pkg/model"
)
// embeddingSpliter splits a label set into one or more label sets so that
// every label value fits within the embedding model's token limit.
type embeddingSpliter interface {
	// Split returns the input as a single set when all values are short
	// enough, otherwise one set per chunk of each oversized value.
	Split(ls model.Labels) ([]model.Labels, error)
}
// newEmbeddingSpliter builds an embeddingSpliter.
// Non-positive maxLabelValueTokens defaults to 1024; non-positive
// overlapTokens defaults to 64. overlapTokens is clamped strictly below
// maxLabelValueTokens: if it were allowed to equal it, split() would compute
// a zero step size and loop forever.
func newEmbeddingSpliter(maxLabelValueTokens, overlapTokens int) embeddingSpliter {
	if maxLabelValueTokens <= 0 {
		maxLabelValueTokens = 1024
	}
	if overlapTokens <= 0 {
		overlapTokens = 64
	}
	// Was '>': overlapTokens == maxLabelValueTokens slipped through and made
	// split()'s charStep zero. Use '>=' so the step stays positive.
	if overlapTokens >= maxLabelValueTokens {
		overlapTokens = maxLabelValueTokens / 10
	}

	return &embeddingSpliterImpl{maxLabelValueTokens: maxLabelValueTokens, overlapTokens: overlapTokens}
}
// embeddingSpliterImpl implements embeddingSpliter using a rough
// character-count-based token estimate (see estimateTokens).
type embeddingSpliterImpl struct {
	maxLabelValueTokens int // max estimated tokens allowed for one label value
	overlapTokens       int // estimated tokens shared between adjacent chunks
}
// Split returns ls unchanged (as a single set) when every value fits within
// maxLabelValueTokens. Otherwise each oversized value is chunked, and each
// chunk is combined with all short labels to form one result set.
func (e *embeddingSpliterImpl) Split(ls model.Labels) ([]model.Labels, error) {
	shared := make(model.Labels, 0, len(ls))
	oversized := make(model.Labels, 0, 1)
	oversizedTokens := make([]int, 0, 1)
	for _, label := range ls {
		estimated := e.estimateTokens(label.Value)
		if estimated > e.maxLabelValueTokens {
			oversized = append(oversized, label)
			oversizedTokens = append(oversizedTokens, estimated)
		} else {
			shared = append(shared, label)
		}
	}

	// Fast path: nothing needs splitting.
	if len(oversized) == 0 {
		return []model.Labels{ls}, nil
	}

	result := make([]model.Labels, 0, len(oversized)*2)
	for i := range oversized {
		chunks := e.split(oversized[i].Value, oversizedTokens[i])
		for _, chunk := range chunks {
			combined := slices.Clone(shared)
			combined = append(combined, model.Label{Key: oversized[i].Key, Value: chunk})
			result = append(result, combined)
		}
	}

	return result, nil
}
// split chunks value into segments of roughly maxLabelValueTokens tokens
// each, with adjacent segments overlapping by roughly overlapTokens tokens so
// context is not lost at chunk boundaries. tokens is the estimated token
// count of the whole value (expected > 0; callers only pass values that
// exceeded maxLabelValueTokens).
func (e *embeddingSpliterImpl) split(value string, tokens int) []string {
	var (
		results = make([]string, 0)
		chars   = []rune(value)
	)

	// Estimate the number of characters per token.
	avgCharsPerToken := float64(len(chars)) / float64(tokens)
	// Approximate number of characters corresponding to maxLabelValueTokens tokens.
	charsPerSegment := int(float64(e.maxLabelValueTokens) * avgCharsPerToken)
	// Number of characters corresponding to overlapTokens tokens.
	overlapChars := int(float64(e.overlapTokens) * avgCharsPerToken)
	// Actual step length = segment length - overlap.
	charStep := charsPerSegment - overlapChars
	// Defensive: truncation above can make the step zero or negative, which
	// would keep the loop below from ever advancing. Force forward progress.
	if charStep < 1 {
		charStep = 1
	}

	for start := 0; start < len(chars); {
		end := min(start+charsPerSegment, len(chars))
		segment := string(chars[start:end])
		results = append(results, segment)
		if end == len(chars) {
			break
		}
		start += charStep
	}

	return results
}
// estimateTokens returns a rough LLM token estimate for text.
// Heuristic: ASCII characters (English text and punctuation) average 4
// characters per token; non-Latin characters (e.g. Chinese) average about
// 1.5 tokens per character.
func (e *embeddingSpliterImpl) estimateTokens(text string) int {
	var ascii, nonASCII int
	for _, r := range text {
		if r > 127 {
			nonASCII++
		} else {
			ascii++
		}
	}

	return int(math.Round(float64(ascii)/4 + float64(nonASCII)*3/2))
}

View File

@@ -0,0 +1,158 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"testing"
. "github.com/onsi/gomega"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/test"
)
// TestEmbeddingSpliter_Split verifies embeddingSpliterImpl.Split in three
// scenarios: all-short values (input passes through untouched), one long
// Latin value (chunked with a 1-token overlap), and a long Chinese value
// (chunked using the non-Latin token estimate with a 2-token overlap).
func TestEmbeddingSpliter_Split(t *testing.T) {
	RegisterTestingT(t)

	// givenDetail configures the spliter under test.
	type givenDetail struct {
		maxLabelValueTokens int
		overlapTokens       int
	}
	// whenDetail is the input label set passed to Split.
	type whenDetail struct {
		labels model.Labels
	}
	// thenExpected is the expected output, or an error substring.
	type thenExpected struct {
		splits []model.Labels
		err    string
	}

	tests := []test.Case[givenDetail, whenDetail, thenExpected]{
		{
			Scenario: "Split labels with all short values",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with all values under token limit",
			Then:     "should return original labels as single split",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 1024,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "description", Value: "Short description"},
				},
			},
			ThenExpected: thenExpected{
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "description", Value: "Short description"},
					},
				},
			},
		},
		{
			Scenario: "Split labels with one long value",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with one value exceeding token limit",
			Then:     "should split the long value and combine with common labels",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 10, // Small limit to force splitting.
				overlapTokens:       1,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "content", Value: "This is a long content that exceeds the token limit and needs to be split into multiple parts"},
				},
			},
			ThenExpected: thenExpected{
				// Each chunk repeats the short "title" label; adjacent chunks
				// overlap by a few characters (the 1-token overlap).
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "This is a long content that exceeds the "},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "the token limit and needs to be split in"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "t into multiple parts"},
					},
				},
			},
		},
		{
			Scenario: "Handle non-Latin characters",
			Given:    "an embedding spliter with max token limit",
			When:     "splitting labels with non-Latin characters",
			Then:     "should correctly estimate tokens and split accordingly",
			GivenDetail: givenDetail{
				maxLabelValueTokens: 10, // Small limit to force splitting.
				overlapTokens:       2,
			},
			WhenDetail: whenDetail{
				labels: model.Labels{
					{Key: "title", Value: "Short title"},
					{Key: "content", Value: "中文内容需要被分割因为它超过了令牌限制"}, // Chinese content that needs to be split.
				},
			},
			ThenExpected: thenExpected{
				splits: []model.Labels{
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "中文内容需要"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "要被分割因为"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "为它超过了令"},
					},
					{
						{Key: "title", Value: "Short title"},
						{Key: "content", Value: "令牌限制"},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.Scenario, func(t *testing.T) {
			// Given.
			spliter := newEmbeddingSpliter(tt.GivenDetail.maxLabelValueTokens, tt.GivenDetail.overlapTokens)

			// When.
			splits, err := spliter.Split(tt.WhenDetail.labels)

			// Then.
			if tt.ThenExpected.err != "" {
				Expect(err).NotTo(BeNil())
				Expect(err.Error()).To(ContainSubstring(tt.ThenExpected.err))
			} else {
				Expect(err).To(BeNil())
				Expect(len(splits)).To(Equal(len(tt.ThenExpected.splits)))
				for i, expectedSplit := range tt.ThenExpected.splits {
					Expect(splits[i]).To(Equal(expectedSplit))
				}
			}
		})
	}
}

420
pkg/llm/llm.go Normal file
View File

@@ -0,0 +1,420 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"context"
"reflect"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/stretchr/testify/mock"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/config"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/storage/kv"
"github.com/glidea/zenfeed/pkg/telemetry/log"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
"github.com/glidea/zenfeed/pkg/util/hash"
)
// --- Interface code block ---
// LLM is a large language model client supporting chat completion and text
// embedding.
type LLM interface {
	component.Component
	// String sends messages as user chat messages and returns the content of
	// the first completion choice.
	String(ctx context.Context, messages []string) (string, error)
	// EmbeddingLabels splits labels to fit the embedding token limit and
	// returns one vector per split.
	EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error)
	// Embedding returns the embedding vector for text.
	Embedding(ctx context.Context, text string) ([]float32, error)
}
// Config configures a single LLM instance. See Validate for defaulting rules.
type Config struct {
	Name     string       // unique name used to look the instance up
	Default  bool         // marks the factory's default instance
	Provider ProviderType // defaults to openai when empty
	Endpoint string       // defaults per provider (see defaultEndpoints)
	APIKey   string
	// At least one of Model (chat) and EmbeddingModel must be set.
	Model, EmbeddingModel string
	Temperature           float32 // must be within [0, 2]
}
// ProviderType identifies an OpenAI-API-compatible provider.
type ProviderType string

const (
	ProviderTypeOpenAI      ProviderType = "openai"
	ProviderTypeOpenRouter  ProviderType = "openrouter"
	ProviderTypeDeepSeek    ProviderType = "deepseek"
	ProviderTypeGemini      ProviderType = "gemini"
	ProviderTypeVolc        ProviderType = "volc" // Rename MaaS to ARK. 😄
	ProviderTypeSiliconFlow ProviderType = "siliconflow"
)

// defaultEndpoints maps each provider to its default API base URL, applied
// when Config.Endpoint is empty (see Config.Validate).
var defaultEndpoints = map[ProviderType]string{
	ProviderTypeOpenAI:      "https://api.openai.com/v1",
	ProviderTypeOpenRouter:  "https://openrouter.ai/api/v1",
	ProviderTypeDeepSeek:    "https://api.deepseek.com/v1",
	ProviderTypeGemini:      "https://generativelanguage.googleapis.com/v1beta/openai",
	ProviderTypeVolc:        "https://ark.cn-beijing.volces.com/api/v3",
	ProviderTypeSiliconFlow: "https://api.siliconflow.cn/v1",
}
// Validate checks the config and fills defaults: an empty Provider becomes
// openai, and an empty Endpoint becomes the provider's default URL.
// It returns an error for a missing name or API key, an unknown provider,
// no model at all, or an out-of-range temperature.
func (c *Config) Validate() error { //nolint:cyclop
	if c.Name == "" {
		return errors.New("name is required")
	}

	if c.Provider == "" {
		c.Provider = ProviderTypeOpenAI
	} else {
		switch c.Provider {
		case ProviderTypeOpenAI, ProviderTypeOpenRouter, ProviderTypeDeepSeek,
			ProviderTypeGemini, ProviderTypeVolc, ProviderTypeSiliconFlow:
			// Known provider; nothing to do.
		default:
			return errors.Errorf("invalid provider: %s", c.Provider)
		}
	}

	if c.Endpoint == "" {
		c.Endpoint = defaultEndpoints[c.Provider]
	}
	if c.APIKey == "" {
		return errors.New("api key is required")
	}
	if c.Model == "" && c.EmbeddingModel == "" {
		return errors.New("model or embedding model is required")
	}
	if c.Temperature < 0 || c.Temperature > 2 {
		return errors.Errorf("invalid temperature: %f, should be in range [0, 2]", c.Temperature)
	}

	return nil
}
// Prometheus counters for LLM token usage, labeled by component, component
// instance, and operation (e.g. "String", "Embedding").
var (
	promptTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "prompt_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
	completionTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "completion_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
	totalTokens = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: model.AppName,
			Subsystem: "llm",
			Name:      "total_tokens",
		},
		[]string{telemetrymodel.KeyComponent, telemetrymodel.KeyComponentInstance, telemetrymodel.KeyOperation},
	)
)
// --- Factory code block ---
// FactoryConfig configures the LLM factory.
type FactoryConfig struct {
	LLMs []Config
	// defaultLLM is the resolved name of the default instance; populated by
	// Validate, not by callers.
	defaultLLM string
}
// Validate checks every LLM config and resolves the default instance name.
// A single config is implicitly the default; with multiple configs exactly
// one must be marked Default.
func (c *FactoryConfig) Validate() error {
	if len(c.LLMs) == 0 {
		return errors.New("no llm config")
	}
	for i := range c.LLMs {
		if err := (&c.LLMs[i]).Validate(); err != nil {
			return errors.Wrapf(err, "validate llm config %s", c.LLMs[i].Name)
		}
	}

	if len(c.LLMs) == 1 {
		c.LLMs[0].Default = true
		c.defaultLLM = c.LLMs[0].Name

		return nil
	}

	defaults := 0
	for _, llm := range c.LLMs {
		if llm.Default {
			c.defaultLLM = llm.Name
			defaults++
		}
	}
	if defaults > 1 {
		return errors.New("multiple llm configs are default")
	}
	// Previously zero defaults passed validation, leaving defaultLLM empty so
	// the factory's default instance would be nil and Get("") would return a
	// nil LLM. Reject that configuration explicitly.
	if defaults == 0 {
		return errors.New("no default llm config")
	}

	return nil
}
// From populates the factory config from the application config, copying one
// Config per configured LLM.
func (c *FactoryConfig) From(app *config.App) {
	for _, l := range app.LLMs {
		item := Config{
			Name:           l.Name,
			Default:        l.Default,
			Provider:       ProviderType(l.Provider),
			Endpoint:       l.Endpoint,
			APIKey:         l.APIKey,
			Model:          l.Model,
			EmbeddingModel: l.EmbeddingModel,
			Temperature:    l.Temperature,
		}
		c.LLMs = append(c.LLMs, item)
	}
}
// FactoryDependencies holds external dependencies injected into the factory.
type FactoryDependencies struct {
	KVStorage kv.Storage // backing store for response caching (see cached)
}

// Factory is a factory for creating LLM instances.
// If name is empty or not found, it will return the default.
type Factory interface {
	component.Component
	config.Watcher
	// Get returns the LLM instance with the given name, or the default.
	Get(name string) LLM
}
// NewFactory creates a Factory from the application config.
// If mockOn options are supplied, a testify-based mock factory is returned
// instead and no config validation occurs; its Get builds a fresh mock LLM
// per call with the given options applied.
func NewFactory(
	instance string,
	app *config.App,
	dependencies FactoryDependencies,
	mockOn ...component.MockOption,
) (Factory, error) {
	if len(mockOn) > 0 {
		mf := &mockFactory{}
		getCall := mf.On("Get", mock.Anything)
		getCall.Run(func(args mock.Arguments) {
			// Re-set the return value on every call so each Get yields a
			// newly configured mock LLM.
			m := &mockLLM{}
			component.MockOptions(mockOn).Apply(&m.Mock)
			getCall.Return(m, nil)
		})
		mf.On("Reload", mock.Anything).Return(nil)

		return mf, nil
	}

	config := &FactoryConfig{}
	config.From(app)
	if err := config.Validate(); err != nil {
		return nil, errors.Wrap(err, "validate config")
	}

	f := &factory{
		Base: component.New(&component.BaseConfig[FactoryConfig, FactoryDependencies]{
			Name:         "LLMFactory",
			Instance:     instance,
			Config:       config,
			Dependencies: dependencies,
		}),
		llms: make(map[string]LLM),
	}
	// Instantiate all configured LLMs eagerly; Run starts them later.
	f.initLLMs()

	return f, nil
}
// factory implements Factory, managing the lifecycle of configured LLMs.
type factory struct {
	*component.Base[FactoryConfig, FactoryDependencies]
	defaultLLM LLM            // instance returned for empty/unknown names
	llms       map[string]LLM // keyed by Config.Name; guarded by mu
	mu         sync.Mutex
}
// Run starts every managed LLM (waiting up to 10s for each to become ready),
// marks the factory ready, then blocks until the factory context is done.
// f.llms is read without the mutex here; callers start Run before concurrent
// Get/Reload traffic is expected.
func (f *factory) Run() error {
	for _, llm := range f.llms {
		if err := component.RunUntilReady(f.Context(), llm, 10*time.Second); err != nil {
			return errors.Wrapf(err, "run llm %s", llm.Name())
		}
	}
	f.MarkReady()
	<-f.Context().Done()

	return nil
}
// Close shuts down every managed LLM under the factory lock.
// Individual close errors are deliberately discarded: shutdown is
// best-effort and always reports success.
func (f *factory) Close() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	for _, instance := range f.llms {
		_ = instance.Close() // best-effort
	}

	return nil
}
// Reload applies a new application config. If the effective LLM config is
// unchanged it is a no-op; otherwise, under the factory lock, it swaps in the
// new config, closes all existing LLM instances, and recreates them.
func (f *factory) Reload(app *config.App) error {
	newConfig := &FactoryConfig{}
	newConfig.From(app)
	if err := newConfig.Validate(); err != nil {
		return errors.Wrap(err, "validate config")
	}
	if reflect.DeepEqual(f.Config(), newConfig) {
		log.Debug(f.Context(), "no changes in llm config")

		return nil
	}

	// Reload the LLMs.
	f.mu.Lock()
	defer f.mu.Unlock()
	f.SetConfig(newConfig)

	// Close the old LLMs.
	for _, llm := range f.llms {
		_ = llm.Close()
	}

	// Recreate the LLMs from the new config.
	f.initLLMs()

	return nil
}
// Get returns the LLM instance named name, creating and caching it lazily if
// it is configured but not yet instantiated. An empty or unknown name yields
// the default instance.
func (f *factory) Get(name string) LLM {
	f.mu.Lock()
	defer f.mu.Unlock()

	if name == "" {
		return f.defaultLLM
	}

	for i := range f.Config().LLMs {
		llmC := f.Config().LLMs[i]
		if llmC.Name != name {
			continue
		}
		if existing := f.llms[name]; existing != nil {
			return existing
		}
		created := f.new(&llmC)
		f.llms[name] = created

		return created
	}

	// Not configured: fall back to the default.
	return f.defaultLLM
}
// new builds an LLM client for the given config, wrapped with the KV-backed
// response cache.
// Every supported provider speaks the OpenAI-compatible API, so the previous
// provider switch — whose case and default branches returned identical
// values — collapses to a single construction path.
func (f *factory) new(c *Config) LLM {
	return newCached(newOpenAI(c), f.Dependencies().KVStorage)
}
// initLLMs (re)builds all LLM instances from the current config and records
// the default one. Callers must hold f.mu when concurrent access is possible.
func (f *factory) initLLMs() {
	cfg := f.Config()
	instances := make(map[string]LLM, len(cfg.LLMs))
	var def LLM
	for i := range cfg.LLMs {
		c := cfg.LLMs[i]
		instance := f.new(&c)
		instances[c.Name] = instance
		if c.Name == cfg.defaultLLM {
			def = instance
		}
	}
	f.llms = instances
	f.defaultLLM = def
}
// mockFactory is a testify-based mock implementation of Factory.
type mockFactory struct {
	component.Mock
}

// Get returns the mocked LLM registered for the call.
func (m *mockFactory) Get(name string) LLM {
	args := m.Called(name)

	return args.Get(0).(LLM)
}

// Reload returns the mocked reload result.
func (m *mockFactory) Reload(app *config.App) error {
	args := m.Called(app)

	return args.Error(0)
}
// --- Implementation code block ---
// cached decorates an LLM with a KV-storage cache for String results.
// Embedding calls are delegated uncached via the embedded LLM.
type cached struct {
	LLM
	kvStorage kv.Storage
}

// newCached wraps llm with response caching backed by kvStorage.
func newCached(llm LLM, kvStorage kv.Storage) LLM {
	return &cached{
		LLM:       llm,
		kvStorage: kvStorage,
	}
}
// String returns the chat completion for messages, serving repeated requests
// from the KV cache. Entries are keyed by a 64-bit hash of the messages and
// expire after 65 minutes. Cache write failures are logged but do not fail
// the call.
func (c *cached) String(ctx context.Context, messages []string) (string, error) {
	cacheKey := strconv.FormatUint(hash.Sum64s(messages), 10)

	if cachedValue, err := c.kvStorage.Get(ctx, cacheKey); err == nil {
		// Cache hit.
		return cachedValue, nil
	} else if !errors.Is(err, kv.ErrNotFound) {
		return "", errors.Wrap(err, "get from kv storage")
	}

	// Cache miss: ask the underlying LLM, then store the result best-effort.
	fresh, err := c.LLM.String(ctx, messages)
	if err != nil {
		return "", err
	}
	if setErr := c.kvStorage.Set(ctx, cacheKey, fresh, 65*time.Minute); setErr != nil {
		log.Error(ctx, setErr, "set to kv storage")
	}

	return fresh, nil
}
// mockLLM is a testify-based mock implementation of LLM.
type mockLLM struct {
	component.Mock
}

// String returns the mocked completion result.
func (m *mockLLM) String(ctx context.Context, messages []string) (string, error) {
	args := m.Called(ctx, messages)

	return args.Get(0).(string), args.Error(1)
}

// EmbeddingLabels returns the mocked per-split embedding vectors.
func (m *mockLLM) EmbeddingLabels(ctx context.Context, labels model.Labels) ([][]float32, error) {
	args := m.Called(ctx, labels)

	return args.Get(0).([][]float32), args.Error(1)
}

// Embedding returns the mocked embedding vector.
func (m *mockLLM) Embedding(ctx context.Context, text string) ([]float32, error) {
	args := m.Called(ctx, text)

	return args.Get(0).([]float32), args.Error(1)
}

146
pkg/llm/openai.go Normal file
View File

@@ -0,0 +1,146 @@
// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"context"
"encoding/json"
"github.com/pkg/errors"
oai "github.com/sashabaranov/go-openai"
"github.com/glidea/zenfeed/pkg/component"
"github.com/glidea/zenfeed/pkg/model"
"github.com/glidea/zenfeed/pkg/telemetry"
telemetrymodel "github.com/glidea/zenfeed/pkg/telemetry/model"
runtimeutil "github.com/glidea/zenfeed/pkg/util/runtime"
)
// openai implements LLM against any OpenAI-API-compatible endpoint.
type openai struct {
	*component.Base[Config, struct{}]
	client *oai.Client
	// embeddingSpliter chunks label sets that exceed the embedding model's
	// token limit before embedding.
	embeddingSpliter embeddingSpliter
}
// newOpenAI builds an LLM backed by the OpenAI-compatible endpoint in c.
// Label sets for embedding are split at roughly 2048 tokens per chunk with a
// 64-token overlap between adjacent chunks.
func newOpenAI(c *Config) LLM {
	clientConfig := oai.DefaultConfig(c.APIKey)
	clientConfig.BaseURL = c.Endpoint

	return &openai{
		Base: component.New(&component.BaseConfig[Config, struct{}]{
			Name:     "LLM/openai",
			Instance: c.Name,
			Config:   c,
		}),
		client:           oai.NewClientWithConfig(clientConfig),
		embeddingSpliter: newEmbeddingSpliter(2048, 64),
	}
}
// String sends each message as a user-role chat message and returns the
// content of the first completion choice. Token usage is recorded in the
// Prometheus counters on success.
func (o *openai) String(ctx context.Context, messages []string) (value string, err error) {
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "String")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.Model == "" {
		return "", errors.New("model is not set")
	}

	chatMessages := make([]oai.ChatCompletionMessage, len(messages))
	for i, content := range messages {
		chatMessages[i] = oai.ChatCompletionMessage{
			Role:    oai.ChatMessageRoleUser,
			Content: content,
		}
	}

	resp, err := o.client.CreateChatCompletion(ctx, oai.ChatCompletionRequest{
		Model:       config.Model,
		Messages:    chatMessages,
		Temperature: config.Temperature,
	})
	if err != nil {
		return "", errors.Wrap(err, "create chat completion")
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("no completion choices returned")
	}

	// Record token usage metrics.
	labels := []string{o.Name(), o.Instance(), "String"}
	promptTokens.WithLabelValues(labels...).Add(float64(resp.Usage.PromptTokens))
	completionTokens.WithLabelValues(labels...).Add(float64(resp.Usage.CompletionTokens))
	totalTokens.WithLabelValues(labels...).Add(float64(resp.Usage.TotalTokens))

	return resp.Choices[0].Message.Content, nil
}
// EmbeddingLabels embeds a label set, first splitting it so every chunk fits
// within the embedding model's token limit, then embedding each chunk's JSON
// form. One vector is returned per chunk.
func (o *openai) EmbeddingLabels(ctx context.Context, labels model.Labels) (value [][]float32, err error) {
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "EmbeddingLabels")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.EmbeddingModel == "" {
		return nil, errors.New("embedding model is not set")
	}

	splits, err := o.embeddingSpliter.Split(labels)
	if err != nil {
		return nil, errors.Wrap(err, "split embedding")
	}

	vectors := make([][]float32, 0, len(splits))
	for _, chunk := range splits {
		// Must1: by the Must convention a marshal failure is treated as a
		// programmer bug rather than a recoverable error.
		payload := runtimeutil.Must1(json.Marshal(chunk))
		vector, embErr := o.Embedding(ctx, string(payload))
		if embErr != nil {
			return nil, errors.Wrap(embErr, "embedding")
		}
		vectors = append(vectors, vector)
	}

	return vectors, nil
}
// Embedding returns the embedding vector for s using the configured
// embedding model. Token usage is recorded in the Prometheus counters on
// success.
func (o *openai) Embedding(ctx context.Context, s string) (value []float32, err error) {
	ctx = telemetry.StartWith(ctx, append(o.TelemetryLabels(), telemetrymodel.KeyOperation, "Embedding")...)
	defer func() { telemetry.End(ctx, err) }()

	config := o.Config()
	if config.EmbeddingModel == "" {
		return nil, errors.New("embedding model is not set")
	}

	resp, err := o.client.CreateEmbeddings(ctx, oai.EmbeddingRequest{
		Input:          []string{s},
		Model:          oai.EmbeddingModel(config.EmbeddingModel),
		EncodingFormat: oai.EmbeddingEncodingFormatFloat,
	})
	if err != nil {
		return nil, errors.Wrap(err, "create embeddings")
	}
	if len(resp.Data) == 0 {
		return nil, errors.New("no embedding data returned")
	}

	// Record token usage metrics.
	labels := []string{o.Name(), o.Instance(), "Embedding"}
	promptTokens.WithLabelValues(labels...).Add(float64(resp.Usage.PromptTokens))
	completionTokens.WithLabelValues(labels...).Add(float64(resp.Usage.CompletionTokens))
	totalTokens.WithLabelValues(labels...).Add(float64(resp.Usage.TotalTokens))

	return resp.Data[0].Embedding, nil
}