diff --git a/bot.py b/bot.py index 0720ac4..715a24c 100644 --- a/bot.py +++ b/bot.py @@ -89,8 +89,13 @@ class BubblesBot: # 初始化 Agent Loop 系统 self.tool_registry = create_default_registry() self.agent_loop = AgentLoop(self.tool_registry, max_iterations=20) - self.session_manager = SessionManager(self.message_summary, self.bot_id) + self.session_manager = SessionManager( + message_summary=self.message_summary, + bot_id=self.bot_id, + db_path=db_path, + ) self.LOG.info(f"Agent Loop 系统已初始化,工具: {self.tool_registry.get_tool_names()}") + self.LOG.info(f"Session 管理器已初始化,已加载 {len(self.session_manager._cache)} 个 session") def _init_chat_models(self) -> None: """初始化所有 AI 模型""" @@ -243,9 +248,24 @@ class BubblesBot: chat_id = msg.get_chat_id() content = msg.content - # 获取会话 - session_key = f"{self.channel.name}:{chat_id}" - session = self.session_manager.get_or_create(session_key, max_history=30) + # 处理 session 命令 + if content.startswith("/session"): + await self._handle_session_command(msg, content) + return + + # 获取会话(通过别名解析,支持跨 Channel 统一会话) + session_alias = f"{self.channel.name}:{chat_id}" + session = self.session_manager.get_or_create(session_alias) + + # 从 session 配置获取设置 + session_config = session.config + max_history = session_config.max_history or 30 + + # 选择模型(优先使用 session 绑定的模型) + chat_model = self.chat + if session_config.model_id is not None and session_config.model_id in self.chat_models: + chat_model = self.chat_models[session_config.model_id] + self.LOG.debug(f"使用 session 绑定的模型: {session_config.model_id}") # 构建用户消息 current_time = time_mod.strftime("%H:%M", time_mod.localtime()) @@ -261,16 +281,19 @@ class BubblesBot: "- 日常闲聊、观点讨论、情感交流 → 直接回复,不需要调用任何工具\n" ) - # 获取人设 - persona_text = None - if self.persona_manager: + # 获取人设(优先使用 session 绑定的,其次从 persona_manager) + persona_text = session_config.persona + if not persona_text and self.persona_manager: try: persona_text = self.persona_manager.get_persona(chat_id) except Exception: pass - if persona_text: - system_prompt = build_persona_system_prompt(self.chat, persona_text) + tool_guidance + # 获取 system prompt(优先使用 session 绑定的) + if session_config.system_prompt: + system_prompt = session_config.system_prompt + tool_guidance + elif persona_text: + system_prompt = build_persona_system_prompt(chat_model, persona_text) + tool_guidance else: system_prompt = tool_guidance @@ -283,7 +306,7 @@ class BubblesBot: messages.append({"role": "system", "content": f"Current time is: {now_time}"}) # 添加历史消息 - history = session.get_history(30) + history = session.get_history(max_history) for hist_msg in history: role = hist_msg.get("role", "user") hist_content = hist_msg.get("content", "") @@ -320,14 +343,13 @@ class BubblesBot: robot=self, logger=self.LOG, config=self.config, - specific_max_history=30, + specific_max_history=max_history, persona=persona_text, _send_text_func=send_func, ) # 执行 Agent Loop try: - chat_model = self.chat if not chat_model: await self.channel.send_text("抱歉,没有可用的 AI 模型。", chat_id) return @@ -358,6 +380,103 @@ class BubblesBot: self.LOG.error(f"Agent Loop 执行失败: {e}", exc_info=True) await self.channel.send_text("抱歉,处理消息时出错了。", chat_id) + async def _handle_session_command(self, msg: Message, content: str) -> None: + """处理 session 相关命令 + + 命令格式: + /session bind - 绑定当前会话到指定 session + /session unbind - 解除当前会话的绑定 + /session info - 查看当前 session 信息 + /session list - 列出所有 session + /session model - 设置当前 session 的模型 + /session clear - 清空当前 session 的消息历史 + """ + chat_id = msg.get_chat_id() + alias = f"{self.channel.name}:{chat_id}" + parts = content.split() + + if len(parts) < 2: + await self.channel.send_text( + "用法:\n" + " /session bind - 绑定到 session\n" + " /session unbind - 解除绑定\n" + " /session info - 查看信息\n" + " /session list - 列出所有\n" + " /session model - 设置模型\n" + " /session clear - 清空历史", + chat_id, + ) + return + + cmd = parts[1].lower() + + if cmd == "bind" and len(parts) >= 3: + session_key = parts[2] + session = self.session_manager.bind(session_key, alias) + await self.channel.send_text( + f"已绑定到 session: {session_key}\n" + f"别名: {', '.join(session.aliases)}", + chat_id, + ) + + elif cmd == "unbind": + if self.session_manager.unbind(alias): + await self.channel.send_text("已解除绑定", chat_id) + else: + await self.channel.send_text("当前会话未绑定到任何 session", chat_id) + + elif cmd == "info": + session = self.session_manager.get(alias) + if session: + model_name = "默认" + if session.config.model_id is not None: + model = self.chat_models.get(session.config.model_id) + model_name = model.__class__.__name__ if model else f"ID:{session.config.model_id}" + + info = ( + f"Session Key: {session.key}\n" + f"别名: {', '.join(session.aliases) or '无'}\n" + f"模型: {model_name}\n" + f"历史限制: {session.config.max_history}\n" + f"消息数: {len(session.messages)}\n" + f"人设: {'已设置' if session.config.persona else '未设置'}" + ) + await self.channel.send_text(info, chat_id) + else: + await self.channel.send_text("当前会话未创建 session", chat_id) + + elif cmd == "list": + sessions = self.session_manager.list_sessions() + if sessions: + lines = ["所有 Session:"] + for s in sessions: + aliases = ", ".join(s["aliases"]) if s["aliases"] else "无别名" + lines.append(f" {s['key']} ({aliases}) - {s['message_count']} 条消息") + await self.channel.send_text("\n".join(lines), chat_id) + else: + await self.channel.send_text("暂无 session", chat_id) + + elif cmd == "model" and len(parts) >= 3: + try: + model_id = int(parts[2]) + if model_id in self.chat_models: + self.session_manager.set_config(alias, model_id=model_id) + model_name = self.chat_models[model_id].__class__.__name__ + await self.channel.send_text(f"已设置模型: {model_name}", chat_id) + else: + available = ", ".join(str(k) for k in self.chat_models.keys()) + await self.channel.send_text(f"无效的模型 ID,可用: {available}", chat_id) + except ValueError: + await self.channel.send_text("模型 ID 必须是数字", chat_id) + + elif cmd == "clear": + session = self.session_manager.get_or_create(alias) + session.clear() + await self.channel.send_text("已清空消息历史", chat_id) + + else: + await self.channel.send_text(f"未知命令: {cmd}", chat_id) + def cleanup(self) -> None: """清理资源""" self.LOG.info("正在清理 BubblesBot 资源...") diff --git a/robot.py b/robot.py index c3ed1a1..4bb7ba3 100644 --- a/robot.py +++ b/robot.py @@ -297,8 +297,13 @@ class Robot(Job): # 初始化 Agent Loop 系统 self.tool_registry = create_default_registry() self.agent_loop = AgentLoop(self.tool_registry, max_iterations=20) - self.session_manager = SessionManager(self.message_summary, self.wxid) + self.session_manager = SessionManager( + message_summary=self.message_summary, + bot_id=self.wxid, + db_path=db_path, + ) self.LOG.info(f"Agent Loop 系统已初始化,工具: {self.tool_registry.get_tool_names()}") + self.LOG.info(f"Session 管理器已初始化,已加载 {len(self.session_manager._cache)} 个 session") @staticmethod def value_check(args: dict) -> bool: @@ -659,8 +664,23 @@ class Robot(Job): reasoning_requested = bool(getattr(ctx, 'reasoning_requested', False)) or force_reasoning is_auto_random_reply = bool(getattr(ctx, 'auto_random_reply', False)) - # 选择模型 + # 获取或创建会话(通过别名解析,支持跨 Channel 统一会话) + chat_id = ctx.get_receiver() + session_alias = f"wechat:{chat_id}" + specific_max_history = getattr(ctx, 'specific_max_history', 30) or 30 + session = self.session_manager.get_or_create(session_alias, max_history=specific_max_history) + + # 从 session 配置获取设置 + session_config = session.config + if session_config.max_history: + specific_max_history = session_config.max_history + + # 选择模型(优先使用 session 绑定的模型,其次是上下文指定的) chat_model = getattr(ctx, 'chat', None) or self.chat + if session_config.model_id is not None and session_config.model_id in self.chat_models: + chat_model = self.chat_models[session_config.model_id] + self.LOG.debug(f"使用 session 绑定的模型: {session_config.model_id}") + if reasoning_requested: if force_reasoning: self.LOG.info("群配置了 force_reasoning,将使用推理模型。") @@ -678,16 +698,6 @@ class Robot(Job): await self.send_text_async("抱歉,我现在无法进行对话。", ctx.get_receiver()) return False - # 获取历史消息限制 - specific_max_history = getattr(ctx, 'specific_max_history', 30) - if specific_max_history is None: - specific_max_history = 30 - - # 获取或创建会话 - chat_id = ctx.get_receiver() - session_key = f"wechat:{chat_id}" - session = self.session_manager.get_or_create(session_key, max_history=specific_max_history) - # 构建用户消息 sender_name = ctx.sender_name content = ctx.text @@ -723,8 +733,8 @@ class Robot(Job): "请只针对该用户进行回复。" ) - # 构建系统提示 - persona_text = getattr(ctx, 'persona', None) + # 构建系统提示(优先使用 session 配置) + persona_text = session_config.persona or getattr(ctx, 'persona', None) tool_guidance = "" if not is_auto_random_reply: tool_guidance = ( @@ -737,7 +747,10 @@ class Robot(Job): "你可以在一次对话中多次调用工具。" ) - if persona_text: + # 优先使用 session 绑定的 system_prompt + if session_config.system_prompt: + system_prompt = session_config.system_prompt + tool_guidance + elif persona_text: try: base_prompt = build_persona_system_prompt(chat_model, persona_text) system_prompt = base_prompt + tool_guidance if base_prompt else tool_guidance or None diff --git a/session/__init__.py b/session/__init__.py index 7ae4a25..9851f6e 100644 --- a/session/__init__.py +++ b/session/__init__.py @@ -1,4 +1,4 @@ # session/__init__.py -from .manager import Session, SessionManager +from .manager import Session, SessionManager, SessionConfig -__all__ = ["Session", "SessionManager"] +__all__ = ["Session", "SessionManager", "SessionConfig"] diff --git a/session/manager.py b/session/manager.py index 6c958f5..b2d211f 100644 --- a/session/manager.py +++ b/session/manager.py @@ -1,23 +1,59 @@ # session/manager.py -"""会话管理器""" +"""增强版会话管理器 - 支持跨 Channel 统一会话""" -from dataclasses import dataclass, field +import json +import sqlite3 +import logging +from dataclasses import dataclass, field, asdict from datetime import datetime from typing import Any, TYPE_CHECKING if TYPE_CHECKING: from function.func_summary import MessageSummary +logger = logging.getLogger("SessionManager") + + +@dataclass +class SessionConfig: + """Session 配置 - 绑定到会话的设置""" + + model_id: int | None = None # 绑定的模型 ID + system_prompt: str | None = None # 自定义 system prompt + persona: str | None = None # 人设文本 + max_history: int = 30 # 历史消息限制 + extra: dict = field(default_factory=dict) # 扩展配置 + + def to_dict(self) -> dict: + return { + "model_id": self.model_id, + "system_prompt": self.system_prompt, + "persona": self.persona, + "max_history": self.max_history, + "extra": self.extra, + } + + @classmethod + def from_dict(cls, data: dict) -> "SessionConfig": + return cls( + model_id=data.get("model_id"), + system_prompt=data.get("system_prompt"), + persona=data.get("persona"), + max_history=data.get("max_history", 30), + extra=data.get("extra", {}), + ) + @dataclass class Session: - """会话对象 - 管理单个对话的状态""" + """增强版会话对象 - 管理单个对话的完整状态""" - key: str # "wechat:{chat_id}" + key: str # 统一会话 key (如 "user:john" 或 "group:test") + config: SessionConfig = field(default_factory=SessionConfig) messages: list[dict[str, Any]] = field(default_factory=list) + aliases: set[str] = field(default_factory=set) # 别名集合 (如 {"wechat:wxid_xxx", "local:john"}) created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) - metadata: dict[str, Any] = field(default_factory=dict) def add_message(self, role: str, content: str, **kwargs) -> None: """添加消息""" @@ -57,49 +93,158 @@ class Session: ) self.updated_at = datetime.now() - def get_history(self, max_messages: int = 30) -> list[dict]: + def get_history(self, max_messages: int | None = None) -> list[dict]: """获取最近的消息历史""" - return self.messages[-max_messages:] + limit = max_messages or self.config.max_history + return self.messages[-limit:] def clear(self) -> None: """清空消息""" self.messages.clear() self.updated_at = datetime.now() + def bind_alias(self, alias: str) -> None: + """绑定别名""" + self.aliases.add(alias) + self.updated_at = datetime.now() + + def unbind_alias(self, alias: str) -> None: + """解绑别名""" + self.aliases.discard(alias) + self.updated_at = datetime.now() + class SessionManager: - """会话管理器 - 与 MessageSummary 协作加载历史""" + """增强版会话管理器 - def __init__(self, message_summary: "MessageSummary", bot_wxid: str): + 支持: + - 跨 Channel 统一会话(通过别名映射) + - Session 配置持久化 + - 绑定模型/人设/system prompt + """ + + def __init__( + self, + message_summary: "MessageSummary | None" = None, + bot_id: str = "", + db_path: str = "data/message_history.db", + ): self.message_summary = message_summary - self.bot_wxid = bot_wxid - self._cache: dict[str, Session] = {} + self.bot_id = bot_id + self.db_path = db_path + self._cache: dict[str, Session] = {} # key -> Session + self._alias_map: dict[str, str] = {} # alias -> key + self._init_db() + self._load_sessions() - def get_or_create(self, key: str, max_history: int = 100) -> Session: + def _init_db(self) -> None: + """初始化数据库表""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + key TEXT PRIMARY KEY, + config TEXT NOT NULL DEFAULT '{}', + aliases TEXT NOT NULL DEFAULT '[]', + created_at TEXT, + updated_at TEXT + ) + """) + conn.commit() + except Exception as e: + logger.error(f"初始化 session 表失败: {e}") + + def _load_sessions(self) -> None: + """从数据库加载所有 session 配置""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute("SELECT * FROM sessions") + for row in cursor: + key = row["key"] + config_data = json.loads(row["config"] or "{}") + aliases_data = json.loads(row["aliases"] or "[]") + + session = Session( + key=key, + config=SessionConfig.from_dict(config_data), + aliases=set(aliases_data), + ) + self._cache[key] = session + + # 建立别名映射 + for alias in session.aliases: + self._alias_map[alias] = key + + logger.info(f"已加载 {len(self._cache)} 个 session 配置") + except Exception as e: + logger.error(f"加载 session 失败: {e}") + + def _save_session(self, session: Session) -> None: + """保存 session 到数据库""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT OR REPLACE INTO sessions (key, config, aliases, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + ( + session.key, + json.dumps(session.config.to_dict(), ensure_ascii=False), + json.dumps(list(session.aliases), ensure_ascii=False), + session.created_at.isoformat(), + session.updated_at.isoformat(), + ), + ) + conn.commit() + except Exception as e: + logger.error(f"保存 session 失败: {e}") + + def resolve_key(self, alias: str) -> str: + """解析别名到统一 key + + 如果别名已绑定到某个 session,返回该 session 的 key + 否则返回别名本身作为 key + """ + return self._alias_map.get(alias, alias) + + def get_or_create( + self, + key_or_alias: str, + max_history: int = 30, + load_history: bool = True, + ) -> Session: """获取或创建会话 Args: - key: 会话标识,格式为 "wechat:{chat_id}" + key_or_alias: 会话标识或别名(如 "wechat:wxid_xxx" 或 "user:john") max_history: 从数据库加载的最大历史消息数 + load_history: 是否从 MessageSummary 加载历史 Returns: Session 对象 """ + # 解析别名 + key = self.resolve_key(key_or_alias) + if key in self._cache: return self._cache[key] + # 创建新 session session = Session(key=key) + session.config.max_history = max_history - # 从 SQLite 加载历史 - chat_id = key.split(":", 1)[1] if ":" in key else key - if self.message_summary: + # 如果输入是别名且不等于 key,自动绑定 + if key_or_alias != key: + session.aliases.add(key_or_alias) + + # 从 MessageSummary 加载历史消息 + if load_history and self.message_summary: + chat_id = key.split(":", 1)[1] if ":" in key else key history = self.message_summary.get_messages(chat_id) for msg in history[-max_history:]: - role = ( - "assistant" - if msg.get("sender_wxid") == self.bot_wxid - else "user" - ) + role = "assistant" if msg.get("sender_wxid") == self.bot_id else "user" content = msg.get("content", "") if content: session.messages.append( @@ -114,14 +259,134 @@ class SessionManager: self._cache[key] = session return session - def get(self, key: str) -> Session | None: + def get(self, key_or_alias: str) -> Session | None: """获取已存在的会话""" + key = self.resolve_key(key_or_alias) return self._cache.get(key) - def remove(self, key: str) -> None: - """移除会话""" - self._cache.pop(key, None) + def bind(self, session_key: str, alias: str) -> Session: + """将别名绑定到指定 session + + Args: + session_key: 目标 session 的 key(如 "user:john") + alias: 要绑定的别名(如 "wechat:wxid_xxx") + + Returns: + 绑定后的 Session 对象 + """ + # 如果别名已绑定到其他 session,先解绑 + if alias in self._alias_map: + old_key = self._alias_map[alias] + if old_key != session_key and old_key in self._cache: + self._cache[old_key].unbind_alias(alias) + self._save_session(self._cache[old_key]) + + # 获取或创建目标 session + session = self.get_or_create(session_key, load_history=False) + session.bind_alias(alias) + self._alias_map[alias] = session_key + self._save_session(session) + + logger.info(f"已绑定 {alias} -> {session_key}") + return session + + def unbind(self, alias: str) -> bool: + """解除别名绑定 + + Returns: + 是否成功解绑 + """ + if alias not in self._alias_map: + return False + + key = self._alias_map.pop(alias) + if key in self._cache: + self._cache[key].unbind_alias(alias) + self._save_session(self._cache[key]) + + logger.info(f"已解绑 {alias}") + return True + + def set_config( + self, + key_or_alias: str, + model_id: int | None = None, + system_prompt: str | None = None, + persona: str | None = None, + max_history: int | None = None, + **extra, + ) -> Session: + """设置 session 配置 + + Args: + key_or_alias: 会话标识或别名 + model_id: 绑定的模型 ID + system_prompt: 自定义 system prompt + persona: 人设文本 + max_history: 历史消息限制 + **extra: 扩展配置 + + Returns: + 更新后的 Session 对象 + """ + session = self.get_or_create(key_or_alias, load_history=False) + + if model_id is not None: + session.config.model_id = model_id + if system_prompt is not None: + session.config.system_prompt = system_prompt + if persona is not None: + session.config.persona = persona + if max_history is not None: + session.config.max_history = max_history + if extra: + session.config.extra.update(extra) + + session.updated_at = datetime.now() + self._save_session(session) + return session + + def remove(self, key_or_alias: str) -> bool: + """移除会话 + + Returns: + 是否成功移除 + """ + key = self.resolve_key(key_or_alias) + if key not in self._cache: + return False + + session = self._cache.pop(key) + + # 清理别名映射 + for alias in session.aliases: + self._alias_map.pop(alias, None) + + # 从数据库删除 + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute("DELETE FROM sessions WHERE key = ?", (key,)) + conn.commit() + except Exception as e: + logger.error(f"删除 session 失败: {e}") + + return True + + def list_sessions(self) -> list[dict]: + """列出所有 session 信息""" + result = [] + for key, session in self._cache.items(): + result.append({ + "key": key, + "aliases": list(session.aliases), + "model_id": session.config.model_id, + "max_history": session.config.max_history, + "message_count": len(session.messages), + "updated_at": session.updated_at.isoformat(), + }) + return result def clear_all(self) -> None: - """清空所有会话缓存""" + """清空所有会话缓存(不删除数据库)""" self._cache.clear() + self._alias_map.clear()