mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-28 17:01:08 +08:00
617 lines
21 KiB
Python
617 lines
21 KiB
Python
"""
|
|
Conversation history persistence using SQLite.
|
|
|
|
Design:
|
|
- sessions table: per-session metadata (channel_type, last_active, msg_count)
|
|
- messages table: individual messages stored as JSON, append-only
|
|
- Pruning: age-based only (sessions not updated within N days are deleted)
|
|
- Thread-safe via a single in-process lock
|
|
|
|
Storage path: by default the shared long-term memory database, ~/cow/memory/long-term/index.db (see get_conversation_store)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from common.log import logger
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Schema script run by ConversationStore._init_db via executescript().
# Every statement uses IF NOT EXISTS, so re-running it on an existing
# database leaves already-created objects untouched.
#
# sessions : one row per session_id (metadata only).
# messages : one row per stored message; (session_id, seq) is unique and
#            content holds the JSON-encoded message content.
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
    session_id TEXT PRIMARY KEY,
    channel_type TEXT NOT NULL DEFAULT '',
    created_at INTEGER NOT NULL,
    last_active INTEGER NOT NULL,
    msg_count INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS messages (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    seq INTEGER NOT NULL,
    role TEXT NOT NULL,
    content TEXT NOT NULL,
    created_at INTEGER NOT NULL,
    UNIQUE (session_id, seq)
);

CREATE INDEX IF NOT EXISTS idx_messages_session
    ON messages (session_id, seq);

CREATE INDEX IF NOT EXISTS idx_sessions_last_active
    ON sessions (last_active);
"""

# Migration: add channel_type column to existing databases that predate it.
# Applied by ConversationStore._migrate only when PRAGMA table_info(sessions)
# does not already list the column.
_MIGRATION_ADD_CHANNEL_TYPE = """
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
"""

# Retention period (in days) used by cleanup_old_sessions when neither the
# caller nor the "conversation_max_age_days" config key supplies one.
DEFAULT_MAX_AGE_DAYS: int = 30
|
|
|
|
|
|
def _is_visible_user_message(content: Any) -> bool:
|
|
"""
|
|
Return True when a user-role message represents actual user input
|
|
(not an internal tool_result injected by the agent loop).
|
|
"""
|
|
if isinstance(content, str):
|
|
return bool(content.strip())
|
|
if isinstance(content, list):
|
|
return any(
|
|
isinstance(b, dict) and b.get("type") == "text"
|
|
for b in content
|
|
)
|
|
return False
|
|
|
|
|
|
def _extract_display_text(content: Any) -> str:
|
|
"""
|
|
Extract the human-readable text portion from a message content value.
|
|
Returns an empty string for tool_use / tool_result blocks.
|
|
"""
|
|
if isinstance(content, str):
|
|
return content.strip()
|
|
if isinstance(content, list):
|
|
parts = [
|
|
b.get("text", "")
|
|
for b in content
|
|
if isinstance(b, dict) and b.get("type") == "text"
|
|
]
|
|
return "\n".join(p for p in parts if p).strip()
|
|
return ""
|
|
|
|
|
|
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
|
|
"""
|
|
Extract tool_use blocks from an assistant message content.
|
|
Returns a list of {name, arguments} dicts (result filled in later).
|
|
"""
|
|
if not isinstance(content, list):
|
|
return []
|
|
return [
|
|
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
|
|
for b in content
|
|
if isinstance(b, dict) and b.get("type") == "tool_use"
|
|
]
|
|
|
|
|
|
def _extract_tool_results(content: Any) -> Dict[str, str]:
|
|
"""
|
|
Extract tool_result blocks from a user message, keyed by tool_use_id.
|
|
"""
|
|
if not isinstance(content, list):
|
|
return {}
|
|
results = {}
|
|
for b in content:
|
|
if not isinstance(b, dict) or b.get("type") != "tool_result":
|
|
continue
|
|
tool_id = b.get("tool_use_id", "")
|
|
result_content = b.get("content", "")
|
|
if isinstance(result_content, list):
|
|
result_content = "\n".join(
|
|
rb.get("text", "") for rb in result_content
|
|
if isinstance(rb, dict) and rb.get("type") == "text"
|
|
)
|
|
results[tool_id] = str(result_content)
|
|
return results
|
|
|
|
|
|
def _group_into_display_turns(
    rows: List[tuple],
) -> List[Dict[str, Any]]:
    """
    Convert raw (role, content_json, created_at) DB rows into display turns.

    One display turn = one visible user message + one merged assistant reply.
    All intermediate assistant messages (those carrying tool_use) and the final
    assistant text reply produced for the same user query are collapsed into a
    single assistant turn, so tools and the final answer end up in the same
    entry.

    Grouping rules:
    - A visible user message starts a new group.
    - Rows that appear before the first visible user message (or a history
      with no visible user message at all) form a leading group that has no
      user turn; it still yields an assistant turn when it contains assistant
      text or tool calls.
    - tool_result user messages are internal; their content is attached to the
      matching tool_use entry via tool_use_id and they never become own turns.
    - All assistant messages within a group are merged:
        * tool_use blocks → tool_calls list (result filled from tool_results)
        * text blocks     → last non-empty text becomes the display content

    Args:
        rows: (role, content, created_at) tuples in chronological order.
            content is normally a JSON string; values that fail to decode
            are used as-is.

    Returns:
        Chronologically ordered list of turn dicts.  User turns have
        role/content/created_at; assistant turns additionally carry
        tool_calls, each with id/name/arguments/result.
    """
    # ------------------------------------------------------------------ #
    # Pass 1: split rows into groups, each starting with a visible user msg
    # ------------------------------------------------------------------ #
    # group = (user_row | None, [subsequent_rows])
    #   user_row: (content, created_at)
    groups: List[tuple] = []
    cur_user: Optional[tuple] = None
    cur_rest: List[tuple] = []

    for role, raw_content, created_at in rows:
        try:
            content = json.loads(raw_content)
        except Exception:
            # Best effort: keep undecodable content verbatim.
            content = raw_content

        if role == "user" and _is_visible_user_message(content):
            # Flush the group in progress.  This includes a leading group
            # without a user row, so rows preceding the first visible user
            # message are no longer discarded.
            if cur_user is not None or cur_rest:
                groups.append((cur_user, cur_rest))
            cur_user = (content, created_at)
            cur_rest = []
        else:
            cur_rest.append((role, content, created_at))

    if cur_user is not None or cur_rest:
        groups.append((cur_user, cur_rest))

    # ------------------------------------------------------------------ #
    # Pass 2: build display turns from each group
    # ------------------------------------------------------------------ #
    turns: List[Dict[str, Any]] = []

    for user_row, rest in groups:
        # User turn (skipped when the extracted text is empty).
        if user_row:
            content, created_at = user_row
            text = _extract_display_text(content)
            if text:
                turns.append({"role": "user", "content": text, "created_at": created_at})

        # Collect all tool_calls and tool_results from the rest of the group.
        all_tool_calls: List[Dict[str, Any]] = []
        tool_results: Dict[str, str] = {}
        final_text = ""
        final_ts: Optional[int] = None

        for role, content, created_at in rest:
            if role == "user":
                tool_results.update(_extract_tool_results(content))
            elif role == "assistant":
                all_tool_calls.extend(_extract_tool_calls(content))
                reply_text = _extract_display_text(content)
                if reply_text:
                    # Later non-empty text overrides earlier text.
                    final_text = reply_text
                    final_ts = created_at

        # Attach tool results to their matching tool_call entries.
        for tc in all_tool_calls:
            tc["result"] = tool_results.get(tc.get("id", ""), "")

        if final_text or all_tool_calls:
            turns.append({
                "role": "assistant",
                "content": final_text,
                "tool_calls": all_tool_calls,
                # Fall back to the user message time, or 0 for a leading
                # group that has neither assistant text nor a user row.
                "created_at": final_ts or (user_row[1] if user_row else 0),
            })

    return turns
|
|
|
|
|
|
class ConversationStore:
    """
    SQLite-backed store for per-session conversation history.

    Every public method opens a short-lived connection, runs under a single
    instance-wide lock, and closes the connection again before returning.

    Usage:
        store = ConversationStore(db_path)
        store.append_messages("user_123", new_messages, channel_type="feishu")
        msgs = store.load_messages("user_123", max_turns=30)
    """

    def __init__(self, db_path: Path):
        """
        Create the store and make sure the schema exists.

        Args:
            db_path: Location of the SQLite file; parent directories are
                created if missing.
        """
        self._db_path = db_path
        self._lock = threading.Lock()
        self._init_db()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_messages(
        self,
        session_id: str,
        max_turns: int = 30,
    ) -> List[Dict[str, Any]]:
        """
        Load the most recent messages for a session, for injection into the LLM.

        ALL message types (user text, assistant tool_use, tool_result) are returned
        in their original JSON form so the LLM can reconstruct the full context.

        max_turns is a *visible-turn* count: we count only user messages whose
        content is actual user text (not tool_result blocks).  This prevents
        tool-heavy sessions from exhausting the turn budget prematurely.

        Args:
            session_id: Unique session identifier.
            max_turns:  Maximum number of visible user-assistant turns to keep.
                A value <= 0 keeps nothing and returns an empty list.

        Returns:
            Chronologically ordered list of message dicts (role, content).
        """
        # A non-positive budget means "no history".  Without this guard the
        # expression visible_turn_seqs[max_turns - 1] below would use a
        # negative index and return (almost) the whole session instead.
        if max_turns <= 0:
            return []

        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    """
                    SELECT seq, role, content
                    FROM messages
                    WHERE session_id = ?
                    ORDER BY seq DESC
                    """,
                    (session_id,),
                ).fetchall()
            finally:
                conn.close()

        if not rows:
            return []

        # Decode each row once; rows are newest first at this point.
        decoded = [
            (seq, role, self._decode_content(raw_content))
            for seq, role, raw_content in rows
        ]

        # Seqs of every *visible* user message (actual user text, not
        # tool_result injections), newest first.
        visible_turn_seqs: List[int] = [
            seq
            for seq, role, content in decoded
            if role == "user" and _is_visible_user_message(content)
        ]

        # Determine the seq of the oldest visible user message we want to keep.
        # If the total turns fit within max_turns, keep everything.
        if len(visible_turn_seqs) <= max_turns:
            cutoff_seq = None
        else:
            cutoff_seq = visible_turn_seqs[max_turns - 1]

        # Build the result in chronological order.  We start exactly at
        # cutoff_seq (a visible user message), never mid-group, so
        # tool_use / tool_result pairs are always complete.
        return [
            {"role": role, "content": content}
            for seq, role, content in reversed(decoded)
            if cutoff_seq is None or seq >= cutoff_seq
        ]

    def append_messages(
        self,
        session_id: str,
        messages: List[Dict[str, Any]],
        channel_type: str = "",
    ) -> None:
        """
        Append new messages to a session's history.

        Seq numbers continue from the session's current maximum, so
        concurrent callers on distinct sessions never collide.  The whole
        batch is written in one transaction.

        Args:
            session_id:   Unique session identifier.
            messages:     List of message dicts to append; each message's
                          "role" and JSON-encoded "content" are stored.
            channel_type: Source channel (e.g. "feishu", "web", "wechat").
                          Only written on session creation; ignored on update.
        """
        if not messages:
            return

        now = int(time.time())
        with self._lock:
            conn = self._connect()
            try:
                # "with conn" commits on success and rolls back on error.
                with conn:
                    # Upsert session row.
                    # channel_type is set only on INSERT (first time); subsequent
                    # appends just update last_active to avoid overwriting the value.
                    conn.execute(
                        """
                        INSERT INTO sessions
                            (session_id, channel_type, created_at, last_active, msg_count)
                        VALUES (?, ?, ?, ?, 0)
                        ON CONFLICT(session_id) DO UPDATE SET
                            last_active = excluded.last_active
                        """,
                        (session_id, channel_type, now, now),
                    )

                    # Determine starting seq for the new batch.
                    row = conn.execute(
                        "SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
                        (session_id,),
                    ).fetchone()
                    next_seq = row[0] + 1

                    for msg in messages:
                        role = msg.get("role", "")
                        content = json.dumps(
                            msg.get("content", ""), ensure_ascii=False
                        )
                        # OR IGNORE: a row that would violate
                        # UNIQUE (session_id, seq) is skipped, not raised.
                        conn.execute(
                            """
                            INSERT OR IGNORE INTO messages
                                (session_id, seq, role, content, created_at)
                            VALUES (?, ?, ?, ?, ?)
                            """,
                            (session_id, next_seq, role, content, now),
                        )
                        next_seq += 1

                    # Recompute msg_count from the table rather than
                    # incrementing, so it reflects rows actually stored.
                    conn.execute(
                        """
                        UPDATE sessions
                        SET msg_count = (
                            SELECT COUNT(*) FROM messages WHERE session_id = ?
                        )
                        WHERE session_id = ?
                        """,
                        (session_id, session_id),
                    )
            finally:
                conn.close()

    def clear_session(self, session_id: str) -> None:
        """Delete all messages and the session record for a given session_id."""
        with self._lock:
            conn = self._connect()
            try:
                with conn:
                    conn.execute(
                        "DELETE FROM messages WHERE session_id = ?", (session_id,)
                    )
                    conn.execute(
                        "DELETE FROM sessions WHERE session_id = ?", (session_id,)
                    )
            finally:
                conn.close()

    def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
        """
        Delete sessions that have not been active within max_age_days.

        Args:
            max_age_days: Override the default retention period.  A falsy
                value (None or 0) falls back to the "conversation_max_age_days"
                config entry, or DEFAULT_MAX_AGE_DAYS if config is unavailable.

        Returns:
            Number of sessions deleted.
        """
        try:
            from config import conf
            max_age = max_age_days or conf().get(
                "conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
            )
        except Exception:
            # Config module missing or unreadable: use the built-in default.
            max_age = max_age_days or DEFAULT_MAX_AGE_DAYS

        cutoff = int(time.time()) - max_age * 86400
        deleted = 0

        with self._lock:
            conn = self._connect()
            try:
                with conn:
                    stale = conn.execute(
                        "SELECT session_id FROM sessions WHERE last_active < ?",
                        (cutoff,),
                    ).fetchall()
                    for (sid,) in stale:
                        conn.execute(
                            "DELETE FROM messages WHERE session_id = ?", (sid,)
                        )
                        conn.execute(
                            "DELETE FROM sessions WHERE session_id = ?", (sid,)
                        )
                        deleted += 1
            finally:
                conn.close()

        if deleted:
            logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
        return deleted

    def load_history_page(
        self,
        session_id: str,
        page: int = 1,
        page_size: int = 20,
    ) -> Dict[str, Any]:
        """
        Load a page of conversation history for UI display, grouped into turns.

        Each "turn" maps to one of:
        - A user message  (role="user", content=str)
        - An assistant message  (role="assistant", content=str,
                                 tool_calls=[{name, arguments, result}] when tools were used)

        Internal tool_result user messages are merged into the preceding
        assistant entry's tool_calls list and never appear as standalone items.

        Pages are numbered from 1 (most recent).  Messages within a page are
        returned in chronological order.  page values below 1 are treated as
        1, and page_size values below 1 are treated as 1.

        Returns:
            {
                "messages": [
                    {
                        "role": "user" | "assistant",
                        "content": str,
                        "tool_calls": [...],   # assistant only, may be []
                        "created_at": int,
                    },
                    ...
                ],
                "total": <visible turn count>,
                "page": <current page>,
                "page_size": <page_size>,
                "has_more": bool,
            }
        """
        page = max(1, page)
        # A non-positive page_size would yield empty pages whose has_more
        # stays True forever (or nonsensical negative slices), so clamp it.
        page_size = max(1, page_size)

        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    """
                    SELECT role, content, created_at
                    FROM messages
                    WHERE session_id = ?
                    ORDER BY seq ASC
                    """,
                    (session_id,),
                ).fetchall()
            finally:
                conn.close()

        visible = _group_into_display_turns(rows)

        total = len(visible)
        offset = (page - 1) * page_size
        # Page 1 is the newest page_size turns, page 2 the ones before, etc.
        # Slicing the chronological list from the end keeps each page in
        # chronological order without reversing twice.
        end = max(0, total - offset)
        start = max(0, end - page_size)
        page_items = visible[start:end]

        return {
            "messages": page_items,
            "total": total,
            "page": page,
            "page_size": page_size,
            "has_more": offset + page_size < total,
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Return basic stats for monitoring.

        Returns:
            Dict with "total_sessions", "total_messages" and "by_channel"
            (session count per channel_type; an empty channel_type is
            reported as "unknown").
        """
        with self._lock:
            conn = self._connect()
            try:
                total_sessions = conn.execute(
                    "SELECT COUNT(*) FROM sessions"
                ).fetchone()[0]
                total_messages = conn.execute(
                    "SELECT COUNT(*) FROM messages"
                ).fetchone()[0]
                by_channel = conn.execute(
                    """
                    SELECT channel_type, COUNT(*) as cnt
                    FROM sessions
                    GROUP BY channel_type
                    ORDER BY cnt DESC
                    """
                ).fetchall()
                return {
                    "total_sessions": total_sessions,
                    "total_messages": total_messages,
                    "by_channel": {row[0] or "unknown": row[1] for row in by_channel},
                }
            finally:
                conn.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_content(raw_content: Any) -> Any:
        """Return the JSON-decoded content, or the raw value if decoding fails."""
        try:
            return json.loads(raw_content)
        except Exception:
            # Best effort: a row that is not valid JSON is passed through.
            return raw_content

    def _init_db(self) -> None:
        """Create the parent directory, run the schema script, then migrate."""
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = self._connect()
        try:
            conn.executescript(_DDL)
            conn.commit()
            self._migrate(conn)
        finally:
            conn.close()

    def _migrate(self, conn: sqlite3.Connection) -> None:
        """Apply incremental schema migrations on existing databases."""
        # PRAGMA table_info rows carry the column name at index 1.
        cols = {
            row[1]
            for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
        }
        if "channel_type" not in cols:
            try:
                conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
                conn.commit()
                logger.info("[ConversationStore] Migrated: added channel_type column")
            except Exception as e:
                # Deliberately non-fatal: log and continue with the old schema.
                logger.warning(f"[ConversationStore] Migration failed: {e}")

    def _connect(self) -> sqlite3.Connection:
        """
        Open a new connection with WAL journaling and synchronous=NORMAL.

        The caller owns the returned connection and must close it.  If
        applying the PRAGMAs fails, the connection is closed before the
        exception propagates so the handle is not leaked.
        """
        conn = sqlite3.connect(str(self._db_path), timeout=10)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
        except Exception:
            conn.close()
            raise
        return conn
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Singleton
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Process-wide ConversationStore, created on first use by
# get_conversation_store(); None until then.
_store_instance: Optional[ConversationStore] = None
# Serialises first-time construction of _store_instance across threads.
_store_lock = threading.Lock()
|
|
|
|
|
|
def get_conversation_store() -> ConversationStore:
    """
    Return the process-wide ConversationStore, creating it on first call.

    The store shares the long-term memory database so the project keeps a
    single SQLite file: ~/cow/memory/long-term/index.db
    Its tables (sessions / messages) are distinct from the memory tables
    (memory_chunks / file_metadata), so the two do not conflict.

    The database path comes from the default memory config; if that cannot
    be obtained, the path above under the expanded ~/cow directory is used.
    """
    global _store_instance

    # Fast path is lock-free; the instance is re-checked under the lock so
    # only one thread ever constructs the store.
    if _store_instance is None:
        with _store_lock:
            if _store_instance is None:
                try:
                    from agent.memory.config import get_default_memory_config
                    db_path = get_default_memory_config().get_db_path()
                except Exception:
                    from common.utils import expand_path
                    db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"

                _store_instance = ConversationStore(db_path)
                logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")

    return _store_instance
|