# Mirror of https://github.com/zhayujie/chatgpt-on-wechat.git
# Synced 2026-02-08 01:02:22 +08:00
"""
Storage layer for memory using SQLite + FTS5

Provides vector and keyword search capabilities
"""
from __future__ import annotations
|
|
import sqlite3
|
|
import json
|
|
import hashlib
|
|
from typing import List, Dict, Optional, Any
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class MemoryChunk:
    """Represents a memory chunk with text and embedding.

    One chunk covers a line range of a source file, optionally carrying its
    embedding vector (for similarity search) and a content hash.
    """
    id: str  # unique chunk id (primary key of the chunks table)
    user_id: Optional[str]  # owning user id, if any
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory" | "session"
    path: str  # path of the originating file
    start_line: int  # first line of the chunk within the file
    end_line: int  # last line of the chunk within the file
    text: str  # raw chunk text
    embedding: Optional[List[float]]  # embedding vector; None if not computed
    hash: str  # content hash of the chunk
    metadata: Optional[Dict[str, Any]] = None  # optional extra data (stored as JSON)
@dataclass
class SearchResult:
    """Search result with score and snippet."""
    path: str  # path of the matched chunk's file
    start_line: int  # start line of the matched chunk
    end_line: int  # end line of the matched chunk
    score: float  # relevance score (higher is better)
    snippet: str  # truncated chunk text
    source: str  # "memory" | "session"
    user_id: Optional[str] = None  # owner of the matched chunk, if user-scoped
class MemoryStorage:
    """SQLite-based storage with FTS5 for keyword search"""

    def __init__(self, db_path: Path):
        """Open (or create) the SQLite database at db_path and prepare the schema.

        Args:
            db_path: filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        # Live connection; opened by _init_db and set back to None by close()
        self.conn: Optional[sqlite3.Connection] = None
        self.fts5_available = False  # Track FTS5 availability
        self._init_db()
def _check_fts5_support(self) -> bool:
|
|
"""Check if SQLite has FTS5 support"""
|
|
try:
|
|
self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
|
|
self.conn.execute("DROP TABLE IF EXISTS fts5_test")
|
|
return True
|
|
except sqlite3.OperationalError as e:
|
|
if "no such module: fts5" in str(e):
|
|
return False
|
|
raise
|
|
|
|
def _init_db(self):
|
|
"""Initialize database with schema"""
|
|
try:
|
|
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
|
self.conn.row_factory = sqlite3.Row
|
|
|
|
# Check FTS5 support
|
|
self.fts5_available = self._check_fts5_support()
|
|
if not self.fts5_available:
|
|
from common.log import logger
|
|
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
|
|
|
|
# Check database integrity
|
|
try:
|
|
result = self.conn.execute("PRAGMA integrity_check").fetchone()
|
|
if result[0] != 'ok':
|
|
print(f"⚠️ Database integrity check failed: {result[0]}")
|
|
print(f" Recreating database...")
|
|
self.conn.close()
|
|
self.conn = None
|
|
# Remove corrupted database
|
|
self.db_path.unlink(missing_ok=True)
|
|
# Remove WAL files
|
|
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
|
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
|
# Reconnect to create new database
|
|
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
|
self.conn.row_factory = sqlite3.Row
|
|
except sqlite3.DatabaseError:
|
|
# Database is corrupted, recreate it
|
|
print(f"⚠️ Database is corrupted, recreating...")
|
|
if self.conn:
|
|
self.conn.close()
|
|
self.conn = None
|
|
self.db_path.unlink(missing_ok=True)
|
|
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
|
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
|
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
|
self.conn.row_factory = sqlite3.Row
|
|
|
|
# Enable WAL mode for better concurrency
|
|
self.conn.execute("PRAGMA journal_mode=WAL")
|
|
# Set busy timeout to avoid "database is locked" errors
|
|
self.conn.execute("PRAGMA busy_timeout=5000")
|
|
except Exception as e:
|
|
print(f"⚠️ Unexpected error during database initialization: {e}")
|
|
raise
|
|
|
|
# Create chunks table with embeddings
|
|
self.conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS chunks (
|
|
id TEXT PRIMARY KEY,
|
|
user_id TEXT,
|
|
scope TEXT NOT NULL DEFAULT 'shared',
|
|
source TEXT NOT NULL DEFAULT 'memory',
|
|
path TEXT NOT NULL,
|
|
start_line INTEGER NOT NULL,
|
|
end_line INTEGER NOT NULL,
|
|
text TEXT NOT NULL,
|
|
embedding TEXT,
|
|
hash TEXT NOT NULL,
|
|
metadata TEXT,
|
|
created_at INTEGER DEFAULT (strftime('%s', 'now')),
|
|
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
|
)
|
|
""")
|
|
|
|
# Create indexes
|
|
self.conn.execute("""
|
|
CREATE INDEX IF NOT EXISTS idx_chunks_user
|
|
ON chunks(user_id)
|
|
""")
|
|
|
|
self.conn.execute("""
|
|
CREATE INDEX IF NOT EXISTS idx_chunks_scope
|
|
ON chunks(scope)
|
|
""")
|
|
|
|
self.conn.execute("""
|
|
CREATE INDEX IF NOT EXISTS idx_chunks_hash
|
|
ON chunks(path, hash)
|
|
""")
|
|
|
|
# Create FTS5 virtual table for keyword search (only if supported)
|
|
if self.fts5_available:
|
|
# Use default unicode61 tokenizer (stable and compatible)
|
|
# For CJK support, we'll use LIKE queries as fallback
|
|
self.conn.execute("""
|
|
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
|
text,
|
|
id UNINDEXED,
|
|
user_id UNINDEXED,
|
|
path UNINDEXED,
|
|
source UNINDEXED,
|
|
scope UNINDEXED,
|
|
content='chunks',
|
|
content_rowid='rowid'
|
|
)
|
|
""")
|
|
|
|
# Create triggers to keep FTS in sync
|
|
self.conn.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
|
|
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
|
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
|
END
|
|
""")
|
|
|
|
self.conn.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
|
|
DELETE FROM chunks_fts WHERE rowid = old.rowid;
|
|
END
|
|
""")
|
|
|
|
self.conn.execute("""
|
|
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
|
|
UPDATE chunks_fts SET text = new.text, id = new.id,
|
|
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
|
|
WHERE rowid = new.rowid;
|
|
END
|
|
""")
|
|
|
|
# Create files metadata table
|
|
self.conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS files (
|
|
path TEXT PRIMARY KEY,
|
|
source TEXT NOT NULL DEFAULT 'memory',
|
|
hash TEXT NOT NULL,
|
|
mtime INTEGER NOT NULL,
|
|
size INTEGER NOT NULL,
|
|
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
|
)
|
|
""")
|
|
|
|
self.conn.commit()
|
|
|
|
def save_chunk(self, chunk: MemoryChunk):
|
|
"""Save a memory chunk"""
|
|
self.conn.execute("""
|
|
INSERT OR REPLACE INTO chunks
|
|
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
|
""", (
|
|
chunk.id,
|
|
chunk.user_id,
|
|
chunk.scope,
|
|
chunk.source,
|
|
chunk.path,
|
|
chunk.start_line,
|
|
chunk.end_line,
|
|
chunk.text,
|
|
json.dumps(chunk.embedding) if chunk.embedding else None,
|
|
chunk.hash,
|
|
json.dumps(chunk.metadata) if chunk.metadata else None
|
|
))
|
|
self.conn.commit()
|
|
|
|
def save_chunks_batch(self, chunks: List[MemoryChunk]):
|
|
"""Save multiple chunks in a batch"""
|
|
self.conn.executemany("""
|
|
INSERT OR REPLACE INTO chunks
|
|
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
|
""", [
|
|
(
|
|
c.id, c.user_id, c.scope, c.source, c.path,
|
|
c.start_line, c.end_line, c.text,
|
|
json.dumps(c.embedding) if c.embedding else None,
|
|
c.hash,
|
|
json.dumps(c.metadata) if c.metadata else None
|
|
)
|
|
for c in chunks
|
|
])
|
|
self.conn.commit()
|
|
|
|
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
|
|
"""Get a chunk by ID"""
|
|
row = self.conn.execute("""
|
|
SELECT * FROM chunks WHERE id = ?
|
|
""", (chunk_id,)).fetchone()
|
|
|
|
if not row:
|
|
return None
|
|
|
|
return self._row_to_chunk(row)
|
|
|
|
def search_vector(
|
|
self,
|
|
query_embedding: List[float],
|
|
user_id: Optional[str] = None,
|
|
scopes: List[str] = None,
|
|
limit: int = 10
|
|
) -> List[SearchResult]:
|
|
"""
|
|
Vector similarity search using in-memory cosine similarity
|
|
(sqlite-vec can be added later for better performance)
|
|
"""
|
|
if scopes is None:
|
|
scopes = ["shared"]
|
|
if user_id:
|
|
scopes.append("user")
|
|
|
|
# Build query
|
|
scope_placeholders = ','.join('?' * len(scopes))
|
|
params = scopes
|
|
|
|
if user_id:
|
|
query = f"""
|
|
SELECT * FROM chunks
|
|
WHERE scope IN ({scope_placeholders})
|
|
AND (scope = 'shared' OR user_id = ?)
|
|
AND embedding IS NOT NULL
|
|
"""
|
|
params.append(user_id)
|
|
else:
|
|
query = f"""
|
|
SELECT * FROM chunks
|
|
WHERE scope IN ({scope_placeholders})
|
|
AND embedding IS NOT NULL
|
|
"""
|
|
|
|
rows = self.conn.execute(query, params).fetchall()
|
|
|
|
# Calculate cosine similarity
|
|
results = []
|
|
for row in rows:
|
|
embedding = json.loads(row['embedding'])
|
|
similarity = self._cosine_similarity(query_embedding, embedding)
|
|
|
|
if similarity > 0:
|
|
results.append((similarity, row))
|
|
|
|
# Sort by similarity and limit
|
|
results.sort(key=lambda x: x[0], reverse=True)
|
|
results = results[:limit]
|
|
|
|
return [
|
|
SearchResult(
|
|
path=row['path'],
|
|
start_line=row['start_line'],
|
|
end_line=row['end_line'],
|
|
score=score,
|
|
snippet=self._truncate_text(row['text'], 500),
|
|
source=row['source'],
|
|
user_id=row['user_id']
|
|
)
|
|
for score, row in results
|
|
]
|
|
|
|
def search_keyword(
|
|
self,
|
|
query: str,
|
|
user_id: Optional[str] = None,
|
|
scopes: List[str] = None,
|
|
limit: int = 10
|
|
) -> List[SearchResult]:
|
|
"""
|
|
Keyword search using FTS5 + LIKE fallback
|
|
|
|
Strategy:
|
|
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
|
|
2. If no FTS5 or no results and query contains CJK: Use LIKE search
|
|
"""
|
|
if scopes is None:
|
|
scopes = ["shared"]
|
|
if user_id:
|
|
scopes.append("user")
|
|
|
|
# Try FTS5 search first (if available)
|
|
if self.fts5_available:
|
|
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
|
if fts_results:
|
|
return fts_results
|
|
|
|
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
|
|
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
|
return self._search_like(query, user_id, scopes, limit)
|
|
|
|
return []
|
|
|
|
def _search_fts5(
|
|
self,
|
|
query: str,
|
|
user_id: Optional[str],
|
|
scopes: List[str],
|
|
limit: int
|
|
) -> List[SearchResult]:
|
|
"""FTS5 full-text search"""
|
|
fts_query = self._build_fts_query(query)
|
|
if not fts_query:
|
|
return []
|
|
|
|
scope_placeholders = ','.join('?' * len(scopes))
|
|
params = [fts_query] + scopes
|
|
|
|
if user_id:
|
|
sql_query = f"""
|
|
SELECT chunks.*, bm25(chunks_fts) as rank
|
|
FROM chunks_fts
|
|
JOIN chunks ON chunks.id = chunks_fts.id
|
|
WHERE chunks_fts MATCH ?
|
|
AND chunks.scope IN ({scope_placeholders})
|
|
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
|
|
ORDER BY rank
|
|
LIMIT ?
|
|
"""
|
|
params.extend([user_id, limit])
|
|
else:
|
|
sql_query = f"""
|
|
SELECT chunks.*, bm25(chunks_fts) as rank
|
|
FROM chunks_fts
|
|
JOIN chunks ON chunks.id = chunks_fts.id
|
|
WHERE chunks_fts MATCH ?
|
|
AND chunks.scope IN ({scope_placeholders})
|
|
ORDER BY rank
|
|
LIMIT ?
|
|
"""
|
|
params.append(limit)
|
|
|
|
try:
|
|
rows = self.conn.execute(sql_query, params).fetchall()
|
|
return [
|
|
SearchResult(
|
|
path=row['path'],
|
|
start_line=row['start_line'],
|
|
end_line=row['end_line'],
|
|
score=self._bm25_rank_to_score(row['rank']),
|
|
snippet=self._truncate_text(row['text'], 500),
|
|
source=row['source'],
|
|
user_id=row['user_id']
|
|
)
|
|
for row in rows
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
def _search_like(
|
|
self,
|
|
query: str,
|
|
user_id: Optional[str],
|
|
scopes: List[str],
|
|
limit: int
|
|
) -> List[SearchResult]:
|
|
"""LIKE-based search for CJK characters"""
|
|
import re
|
|
# Extract CJK words (2+ characters)
|
|
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
|
if not cjk_words:
|
|
return []
|
|
|
|
scope_placeholders = ','.join('?' * len(scopes))
|
|
|
|
# Build LIKE conditions for each word
|
|
like_conditions = []
|
|
params = []
|
|
for word in cjk_words:
|
|
like_conditions.append("text LIKE ?")
|
|
params.append(f'%{word}%')
|
|
|
|
where_clause = ' OR '.join(like_conditions)
|
|
params.extend(scopes)
|
|
|
|
if user_id:
|
|
sql_query = f"""
|
|
SELECT * FROM chunks
|
|
WHERE ({where_clause})
|
|
AND scope IN ({scope_placeholders})
|
|
AND (scope = 'shared' OR user_id = ?)
|
|
LIMIT ?
|
|
"""
|
|
params.extend([user_id, limit])
|
|
else:
|
|
sql_query = f"""
|
|
SELECT * FROM chunks
|
|
WHERE ({where_clause})
|
|
AND scope IN ({scope_placeholders})
|
|
LIMIT ?
|
|
"""
|
|
params.append(limit)
|
|
|
|
try:
|
|
rows = self.conn.execute(sql_query, params).fetchall()
|
|
return [
|
|
SearchResult(
|
|
path=row['path'],
|
|
start_line=row['start_line'],
|
|
end_line=row['end_line'],
|
|
score=0.5, # Fixed score for LIKE search
|
|
snippet=self._truncate_text(row['text'], 500),
|
|
source=row['source'],
|
|
user_id=row['user_id']
|
|
)
|
|
for row in rows
|
|
]
|
|
except Exception:
|
|
return []
|
|
|
|
def delete_by_path(self, path: str):
|
|
"""Delete all chunks from a file"""
|
|
self.conn.execute("""
|
|
DELETE FROM chunks WHERE path = ?
|
|
""", (path,))
|
|
self.conn.commit()
|
|
|
|
def get_file_hash(self, path: str) -> Optional[str]:
|
|
"""Get stored file hash"""
|
|
row = self.conn.execute("""
|
|
SELECT hash FROM files WHERE path = ?
|
|
""", (path,)).fetchone()
|
|
return row['hash'] if row else None
|
|
|
|
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
|
|
"""Update file metadata"""
|
|
self.conn.execute("""
|
|
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
|
|
""", (path, source, file_hash, mtime, size))
|
|
self.conn.commit()
|
|
|
|
def get_stats(self) -> Dict[str, int]:
|
|
"""Get storage statistics"""
|
|
chunks_count = self.conn.execute("""
|
|
SELECT COUNT(*) as cnt FROM chunks
|
|
""").fetchone()['cnt']
|
|
|
|
files_count = self.conn.execute("""
|
|
SELECT COUNT(*) as cnt FROM files
|
|
""").fetchone()['cnt']
|
|
|
|
return {
|
|
'chunks': chunks_count,
|
|
'files': files_count
|
|
}
|
|
|
|
def close(self):
|
|
"""Close database connection"""
|
|
if self.conn:
|
|
try:
|
|
self.conn.commit() # Ensure all changes are committed
|
|
self.conn.close()
|
|
self.conn = None # Mark as closed
|
|
except Exception as e:
|
|
print(f"⚠️ Error closing database connection: {e}")
|
|
|
|
def __del__(self):
|
|
"""Destructor to ensure connection is closed"""
|
|
try:
|
|
self.close()
|
|
except:
|
|
pass # Ignore errors during cleanup
|
|
|
|
# Helper methods
|
|
|
|
def _row_to_chunk(self, row) -> MemoryChunk:
|
|
"""Convert database row to MemoryChunk"""
|
|
return MemoryChunk(
|
|
id=row['id'],
|
|
user_id=row['user_id'],
|
|
scope=row['scope'],
|
|
source=row['source'],
|
|
path=row['path'],
|
|
start_line=row['start_line'],
|
|
end_line=row['end_line'],
|
|
text=row['text'],
|
|
embedding=json.loads(row['embedding']) if row['embedding'] else None,
|
|
hash=row['hash'],
|
|
metadata=json.loads(row['metadata']) if row['metadata'] else None
|
|
)
|
|
|
|
@staticmethod
|
|
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
|
"""Calculate cosine similarity between two vectors"""
|
|
if len(vec1) != len(vec2):
|
|
return 0.0
|
|
|
|
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
|
norm1 = sum(a * a for a in vec1) ** 0.5
|
|
norm2 = sum(b * b for b in vec2) ** 0.5
|
|
|
|
if norm1 == 0 or norm2 == 0:
|
|
return 0.0
|
|
|
|
return dot_product / (norm1 * norm2)
|
|
|
|
@staticmethod
|
|
def _contains_cjk(text: str) -> bool:
|
|
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
|
|
import re
|
|
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
|
|
|
@staticmethod
|
|
def _build_fts_query(raw_query: str) -> Optional[str]:
|
|
"""
|
|
Build FTS5 query from raw text
|
|
|
|
Works best for English and word-based languages.
|
|
For CJK characters, LIKE search will be used as fallback.
|
|
"""
|
|
import re
|
|
# Extract words (primarily English words and numbers)
|
|
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
|
|
if not tokens:
|
|
return None
|
|
|
|
# Quote tokens for exact matching
|
|
quoted = [f'"{t}"' for t in tokens]
|
|
# Use OR for more flexible matching
|
|
return ' OR '.join(quoted)
|
|
|
|
@staticmethod
|
|
def _bm25_rank_to_score(rank: float) -> float:
|
|
"""Convert BM25 rank to 0-1 score"""
|
|
normalized = max(0, rank) if rank is not None else 999
|
|
return 1 / (1 + normalized)
|
|
|
|
@staticmethod
|
|
def _truncate_text(text: str, max_chars: int) -> str:
|
|
"""Truncate text to max characters"""
|
|
if len(text) <= max_chars:
|
|
return text
|
|
return text[:max_chars] + "..."
|
|
|
|
@staticmethod
|
|
def compute_hash(content: str) -> str:
|
|
"""Compute SHA256 hash of content"""
|
|
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|