Files
chatgpt-on-wechat/agent/memory/storage.py
2026-02-03 12:19:36 +08:00

590 lines
20 KiB
Python

"""
Storage layer for memory using SQLite + FTS5
Provides vector and keyword search capabilities
"""
from __future__ import annotations
import sqlite3
import json
import hashlib
from typing import List, Dict, Optional, Any
from pathlib import Path
from dataclasses import dataclass
@dataclass
class MemoryChunk:
    """Represents a memory chunk with text and embedding"""
    id: str                            # primary key in the `chunks` table
    user_id: Optional[str]             # owning user; None for shared chunks
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory" | "session"
    path: str                          # originating file path
    start_line: int                    # first line of the chunk within `path`
    end_line: int                      # last line of the chunk within `path`
    text: str                          # raw chunk text (indexed for search)
    embedding: Optional[List[float]]   # vector embedding; None when not embedded
    hash: str                          # content hash used for change detection
    metadata: Optional[Dict[str, Any]] = None  # extra attributes, stored as JSON
@dataclass
class SearchResult:
    """Search result with score and snippet"""
    path: str            # file the matching chunk came from
    start_line: int      # chunk start line within `path`
    end_line: int        # chunk end line within `path`
    score: float         # relevance score; higher is better
    snippet: str         # (possibly truncated) chunk text
    source: str          # "memory" | "session"
    user_id: Optional[str] = None  # owner of the matched chunk, if user-scoped
class MemoryStorage:
    """SQLite-based storage with FTS5 for keyword search"""

    def __init__(self, db_path: Path):
        """Open (or create) the SQLite database at ``db_path`` and set up the schema.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self.fts5_available = False  # Track FTS5 availability
        self._init_db()
def _check_fts5_support(self) -> bool:
"""Check if SQLite has FTS5 support"""
try:
self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
self.conn.execute("DROP TABLE IF EXISTS fts5_test")
return True
except sqlite3.OperationalError as e:
if "no such module: fts5" in str(e):
return False
raise
def _init_db(self):
"""Initialize database with schema"""
try:
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Check FTS5 support
self.fts5_available = self._check_fts5_support()
if not self.fts5_available:
from common.log import logger
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
# Check database integrity
try:
result = self.conn.execute("PRAGMA integrity_check").fetchone()
if result[0] != 'ok':
print(f"⚠️ Database integrity check failed: {result[0]}")
print(f" Recreating database...")
self.conn.close()
self.conn = None
# Remove corrupted database
self.db_path.unlink(missing_ok=True)
# Remove WAL files
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
# Reconnect to create new database
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
except sqlite3.DatabaseError:
# Database is corrupted, recreate it
print(f"⚠️ Database is corrupted, recreating...")
if self.conn:
self.conn.close()
self.conn = None
self.db_path.unlink(missing_ok=True)
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Enable WAL mode for better concurrency
self.conn.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to avoid "database is locked" errors
self.conn.execute("PRAGMA busy_timeout=5000")
except Exception as e:
print(f"⚠️ Unexpected error during database initialization: {e}")
raise
# Create chunks table with embeddings
self.conn.execute("""
CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
user_id TEXT,
scope TEXT NOT NULL DEFAULT 'shared',
source TEXT NOT NULL DEFAULT 'memory',
path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
text TEXT NOT NULL,
embedding TEXT,
hash TEXT NOT NULL,
metadata TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
# Create indexes
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user
ON chunks(user_id)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_scope
ON chunks(scope)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_hash
ON chunks(path, hash)
""")
# Create FTS5 virtual table for keyword search (only if supported)
if self.fts5_available:
# Use default unicode61 tokenizer (stable and compatible)
# For CJK support, we'll use LIKE queries as fallback
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid'
)
""")
# Create triggers to keep FTS in sync
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
DELETE FROM chunks_fts WHERE rowid = old.rowid;
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts SET text = new.text, id = new.id,
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
WHERE rowid = new.rowid;
END
""")
# Create files metadata table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS files (
path TEXT PRIMARY KEY,
source TEXT NOT NULL DEFAULT 'memory',
hash TEXT NOT NULL,
mtime INTEGER NOT NULL,
size INTEGER NOT NULL,
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
self.conn.commit()
def save_chunk(self, chunk: MemoryChunk):
    """Insert or replace a single memory chunk.

    Serializes ``embedding`` and ``metadata`` to JSON text columns.  Uses
    ``is not None`` checks so an empty embedding list or empty metadata dict
    is preserved as '[]' / '{}' instead of being silently stored as NULL
    (the previous truthiness test dropped them).
    """
    self.conn.execute("""
    INSERT OR REPLACE INTO chunks
    (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, (
        chunk.id,
        chunk.user_id,
        chunk.scope,
        chunk.source,
        chunk.path,
        chunk.start_line,
        chunk.end_line,
        chunk.text,
        json.dumps(chunk.embedding) if chunk.embedding is not None else None,
        chunk.hash,
        json.dumps(chunk.metadata) if chunk.metadata is not None else None
    ))
    self.conn.commit()
def save_chunks_batch(self, chunks: List[MemoryChunk]):
    """Insert or replace many chunks in a single executemany call.

    Row shape matches save_chunk.  ``is not None`` checks preserve empty
    embeddings/metadata ([] / {}) that a truthiness test would drop as NULL.
    An empty `chunks` list is a no-op (executemany over an empty sequence).
    """
    self.conn.executemany("""
    INSERT OR REPLACE INTO chunks
    (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, [
        (
            c.id, c.user_id, c.scope, c.source, c.path,
            c.start_line, c.end_line, c.text,
            json.dumps(c.embedding) if c.embedding is not None else None,
            c.hash,
            json.dumps(c.metadata) if c.metadata is not None else None
        )
        for c in chunks
    ])
    self.conn.commit()
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
    """Fetch a single chunk by its primary key, or None when absent."""
    row = self.conn.execute("""
    SELECT * FROM chunks WHERE id = ?
    """, (chunk_id,)).fetchone()
    return self._row_to_chunk(row) if row else None
def search_vector(
    self,
    query_embedding: List[float],
    user_id: Optional[str] = None,
    scopes: List[str] = None,
    limit: int = 10
) -> List[SearchResult]:
    """
    Vector similarity search using in-memory cosine similarity
    (sqlite-vec can be added later for better performance)

    Args:
        query_embedding: embedding vector of the query.
        user_id: restricts "user"-scoped rows to this user; shared rows always match.
        scopes: scopes to search; defaults to ["shared"] (+ "user" when user_id given).
        limit: maximum number of results returned.

    Returns:
        SearchResults sorted by descending cosine similarity (only scores > 0).
    """
    # BUGFIX: work on a private copy — the previous version aliased the
    # caller's list as `params` and appended to it, mutating the argument
    # across calls.
    effective_scopes = list(scopes) if scopes is not None else ["shared"]
    if user_id and "user" not in effective_scopes:
        effective_scopes.append("user")
    # Build query
    scope_placeholders = ','.join('?' * len(effective_scopes))
    params = list(effective_scopes)
    if user_id:
        query = f"""
        SELECT * FROM chunks
        WHERE scope IN ({scope_placeholders})
        AND (scope = 'shared' OR user_id = ?)
        AND embedding IS NOT NULL
        """
        params.append(user_id)
    else:
        query = f"""
        SELECT * FROM chunks
        WHERE scope IN ({scope_placeholders})
        AND embedding IS NOT NULL
        """
    rows = self.conn.execute(query, params).fetchall()
    # Calculate cosine similarity for every candidate row
    scored = []
    for row in rows:
        embedding = json.loads(row['embedding'])
        similarity = self._cosine_similarity(query_embedding, embedding)
        if similarity > 0:
            scored.append((similarity, row))
    # Sort by similarity and keep the top `limit`
    scored.sort(key=lambda item: item[0], reverse=True)
    top = scored[:limit]
    return [
        SearchResult(
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            score=score,
            snippet=self._truncate_text(row['text'], 500),
            source=row['source'],
            user_id=row['user_id']
        )
        for score, row in top
    ]
def search_keyword(
    self,
    query: str,
    user_id: Optional[str] = None,
    scopes: List[str] = None,
    limit: int = 10
) -> List[SearchResult]:
    """
    Keyword search using FTS5 + LIKE fallback
    Strategy:
    1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
    2. If no FTS5 or no results and query contains CJK: Use LIKE search

    BUGFIX: operates on a copy of `scopes` so the caller's list is never
    mutated across calls (the previous version appended "user" in place).
    """
    effective_scopes = list(scopes) if scopes is not None else ["shared"]
    if user_id and "user" not in effective_scopes:
        effective_scopes.append("user")
    # Try FTS5 search first (if available)
    if self.fts5_available:
        fts_results = self._search_fts5(query, user_id, effective_scopes, limit)
        if fts_results:
            return fts_results
    # Fallback to LIKE search (always for CJK, or if FTS5 not available)
    if not self.fts5_available or self._contains_cjk(query):
        return self._search_like(query, user_id, effective_scopes, limit)
    return []
def _search_fts5(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""FTS5 full-text search"""
fts_query = self._build_fts_query(query)
if not fts_query:
return []
scope_placeholders = ','.join('?' * len(scopes))
params = [fts_query] + scopes
if user_id:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
ORDER BY rank
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
ORDER BY rank
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=self._bm25_rank_to_score(row['rank']),
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def _search_like(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""LIKE-based search for CJK characters"""
import re
# Extract CJK words (2+ characters)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
if not cjk_words:
return []
scope_placeholders = ','.join('?' * len(scopes))
# Build LIKE conditions for each word
like_conditions = []
params = []
for word in cjk_words:
like_conditions.append("text LIKE ?")
params.append(f'%{word}%')
where_clause = ' OR '.join(like_conditions)
params.extend(scopes)
if user_id:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=0.5, # Fixed score for LIKE search
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def delete_by_path(self, path: str):
    """Remove every chunk that originated from the given file path."""
    connection = self.conn
    connection.execute("""
    DELETE FROM chunks WHERE path = ?
    """, (path,))
    connection.commit()
def get_file_hash(self, path: str) -> Optional[str]:
    """Return the stored content hash for `path`, or None when unknown."""
    found = self.conn.execute("""
    SELECT hash FROM files WHERE path = ?
    """, (path,)).fetchone()
    if found is None:
        return None
    return found['hash']
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
    """Upsert the indexing metadata (hash/mtime/size) recorded for a file."""
    values = (path, source, file_hash, mtime, size)
    self.conn.execute("""
    INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
    VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, values)
    self.conn.commit()
def get_stats(self) -> Dict[str, int]:
    """Return row counts for the chunks and files tables."""
    stats = {}
    for key, sql in (
        ('chunks', """
    SELECT COUNT(*) as cnt FROM chunks
    """),
        ('files', """
    SELECT COUNT(*) as cnt FROM files
    """),
    ):
        stats[key] = self.conn.execute(sql).fetchone()['cnt']
    return stats
def close(self):
    """Commit outstanding work and close the connection; safe to call repeatedly."""
    if not self.conn:
        return
    try:
        self.conn.commit()  # flush any pending writes before closing
        self.conn.close()
        self.conn = None  # mark closed so later calls become no-ops
    except Exception as e:
        print(f"⚠️ Error closing database connection: {e}")
def __del__(self):
"""Destructor to ensure connection is closed"""
try:
self.close()
except:
pass # Ignore errors during cleanup
# Helper methods
def _row_to_chunk(self, row) -> MemoryChunk:
    """Hydrate a MemoryChunk from a chunks-table row, JSON-decoding the optional columns."""
    raw_embedding = row['embedding']
    raw_metadata = row['metadata']
    return MemoryChunk(
        id=row['id'],
        user_id=row['user_id'],
        scope=row['scope'],
        source=row['source'],
        path=row['path'],
        start_line=row['start_line'],
        end_line=row['end_line'],
        text=row['text'],
        embedding=json.loads(raw_embedding) if raw_embedding else None,
        hash=row['hash'],
        metadata=json.loads(raw_metadata) if raw_metadata else None
    )
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors"""
if len(vec1) != len(vec2):
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
@staticmethod
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
import re
return bool(re.search(r'[\u4e00-\u9fff]', text))
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""
Build FTS5 query from raw text
Works best for English and word-based languages.
For CJK characters, LIKE search will be used as fallback.
"""
import re
# Extract words (primarily English words and numbers)
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
if not tokens:
return None
# Quote tokens for exact matching
quoted = [f'"{t}"' for t in tokens]
# Use OR for more flexible matching
return ' OR '.join(quoted)
@staticmethod
def _bm25_rank_to_score(rank: float) -> float:
"""Convert BM25 rank to 0-1 score"""
normalized = max(0, rank) if rank is not None else 999
return 1 / (1 + normalized)
@staticmethod
def _truncate_text(text: str, max_chars: int) -> str:
"""Truncate text to max characters"""
if len(text) <= max_chars:
return text
return text[:max_chars] + "..."
@staticmethod
def compute_hash(content: str) -> str:
"""Compute SHA256 hash of content"""
return hashlib.sha256(content.encode('utf-8')).hexdigest()