Files
zenfeed/pkg/llm/embedding_spliter.go
glidea 8b33df8a05 init
2025-04-19 15:50:26 +08:00

132 lines
3.4 KiB
Go

// Copyright (C) 2025 wangyusong
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package llm
import (
"math"
"slices"
"github.com/glidea/zenfeed/pkg/model"
)
// embeddingSpliter turns one label set into one or more label sets that are
// suitable for embedding. The implementation in this file cuts overly long
// label values into overlapping chunks.
type embeddingSpliter interface {
	// Split returns the label sets that should be embedded in place of labels.
	Split(labels model.Labels) ([]model.Labels, error)
}
// newEmbeddingSpliter returns an embeddingSpliter that cuts label values whose
// estimated size exceeds maxLabelValueTokens tokens into chunks overlapping by
// roughly overlapTokens tokens.
//
// Non-positive arguments fall back to defaults: 1024 for maxLabelValueTokens
// and 64 for overlapTokens. The overlap must end up strictly smaller than
// maxLabelValueTokens. Otherwise it is reduced to a tenth of
// maxLabelValueTokens, which may be 0 for very small limits.
func newEmbeddingSpliter(maxLabelValueTokens, overlapTokens int) embeddingSpliter {
	if maxLabelValueTokens <= 0 {
		maxLabelValueTokens = 1024
	}
	if overlapTokens <= 0 {
		overlapTokens = 64
	}
	// Use >= rather than >: the splitter advances by (chunk size - overlap) per
	// chunk, so an overlap equal to the chunk size would make no forward progress.
	// For example, newEmbeddingSpliter(64, 0) would otherwise keep an overlap of 64.
	if overlapTokens >= maxLabelValueTokens {
		overlapTokens = maxLabelValueTokens / 10
	}

	return &embeddingSpliterImpl{maxLabelValueTokens: maxLabelValueTokens, overlapTokens: overlapTokens}
}
// embeddingSpliterImpl is the embeddingSpliter built by newEmbeddingSpliter.
//
// maxLabelValueTokens is the largest estimated token count a label value may
// have before Split cuts it into chunks. overlapTokens is the approximate
// number of tokens shared by consecutive chunks.
type embeddingSpliterImpl struct {
	maxLabelValueTokens, overlapTokens int
}
// Split partitions ls by the estimated token count of each label value.
//
// If every value fits within e.maxLabelValueTokens, ls itself is returned as
// the only label set. Otherwise each oversized label is cut into overlapping
// chunks (see split). Every chunk becomes its own label set: all the fitting
// labels, in their original order, followed by that single chunk under the
// oversized label's key. Two oversized labels therefore never share a returned
// set. The returned error is always nil.
func (e *embeddingSpliterImpl) Split(ls model.Labels) ([]model.Labels, error) {
	// oversized remembers a too-long label together with its token estimate so
	// the estimate is not recomputed before splitting.
	type oversized struct {
		label  model.Label
		tokens int
	}

	fitting := make(model.Labels, 0, len(ls))
	overflow := make([]oversized, 0, 1)
	for _, label := range ls {
		n := e.estimateTokens(label.Value)
		if n > e.maxLabelValueTokens {
			overflow = append(overflow, oversized{label: label, tokens: n})
			continue
		}
		fitting = append(fitting, label)
	}

	if len(overflow) == 0 {
		return []model.Labels{ls}, nil
	}

	out := make([]model.Labels, 0, 2*len(overflow))
	for _, o := range overflow {
		for _, chunk := range e.split(o.label.Value, o.tokens) {
			// Clone so each returned set owns its backing array.
			set := append(slices.Clone(fitting), model.Label{Key: o.label.Key, Value: chunk})
			out = append(out, set)
		}
	}

	return out, nil
}
// split cuts value into overlapping segments of roughly e.maxLabelValueTokens
// tokens each. Consecutive segments share roughly e.overlapTokens tokens.
//
// tokens is the caller's estimate of value's token count (Split passes the
// result of estimateTokens). It is only used to derive an average
// characters-per-token ratio, so segment boundaries are approximate. Cutting
// happens on rune boundaries, so multi-byte characters are never split.
//
// An empty value yields no segments. A value whose estimate already fits
// (including a non-positive estimate) yields a single segment holding the
// whole value.
func (e *embeddingSpliterImpl) split(value string, tokens int) []string {
	chars := []rune(value)
	if len(chars) == 0 {
		return []string{}
	}
	// Nothing to cut. This also keeps the division below away from tokens <= 0,
	// which would produce an infinite ratio and an invalid slice bound.
	if tokens <= e.maxLabelValueTokens {
		return []string{value}
	}

	// Average number of runes per estimated token.
	avgCharsPerToken := float64(len(chars)) / float64(tokens)
	// Each segment holds at least one rune, and the overlap is never negative.
	segmentChars := max(1, int(float64(e.maxLabelValueTokens)*avgCharsPerToken))
	overlapChars := max(0, int(float64(e.overlapTokens)*avgCharsPerToken))
	// Always advance by at least one rune. Integer truncation, or an overlap
	// that is not smaller than the segment, would otherwise give a step <= 0
	// and the loop would never terminate.
	step := max(1, segmentChars-overlapChars)

	results := make([]string, 0, len(chars)/step+1)
	for start := 0; ; start += step {
		end := min(start+segmentChars, len(chars))
		results = append(results, string(chars[start:end]))
		if end == len(chars) {
			return results
		}
	}
}
func (e *embeddingSpliterImpl) estimateTokens(text string) int {
latinChars := 0
otherChars := 0
for _, r := range text {
if r <= 127 {
latinChars++
} else {
otherChars++
}
}
// Rough estimate:
// - English and punctuation: about 0.25 tokens/char (4 characters ≈ 1 token).
// - Chinese and other non-Latin characters: about 1.5 tokens/char.
return int(math.Round(float64(latinChars)/4 + float64(otherChars)*3/2))
}