mirror of
https://github.com/Zippland/Bubbles.git
synced 2026-03-06 07:59:33 +08:00
feat(session): 增强会话管理器功能,支持跨渠道统一会话
- 新增 SessionConfig 类用于管理会话配置,包括模型绑定和人设设置 - 实现会话别名功能,支持跨渠道统一会话管理 - 增加 SQLite 持久化存储会话配置和历史消息 - 添加 /session 命令集,支持查看和修改会话配置 - 优化机器人初始化流程,支持会话配置优先于上下文设置
This commit is contained in:
143
bot.py
143
bot.py
@@ -89,8 +89,13 @@ class BubblesBot:
|
||||
# 初始化 Agent Loop 系统
|
||||
self.tool_registry = create_default_registry()
|
||||
self.agent_loop = AgentLoop(self.tool_registry, max_iterations=20)
|
||||
self.session_manager = SessionManager(self.message_summary, self.bot_id)
|
||||
self.session_manager = SessionManager(
|
||||
message_summary=self.message_summary,
|
||||
bot_id=self.bot_id,
|
||||
db_path=db_path,
|
||||
)
|
||||
self.LOG.info(f"Agent Loop 系统已初始化,工具: {self.tool_registry.get_tool_names()}")
|
||||
self.LOG.info(f"Session 管理器已初始化,已加载 {len(self.session_manager._cache)} 个 session")
|
||||
|
||||
def _init_chat_models(self) -> None:
|
||||
"""初始化所有 AI 模型"""
|
||||
@@ -243,9 +248,24 @@ class BubblesBot:
|
||||
chat_id = msg.get_chat_id()
|
||||
content = msg.content
|
||||
|
||||
# 获取会话
|
||||
session_key = f"{self.channel.name}:{chat_id}"
|
||||
session = self.session_manager.get_or_create(session_key, max_history=30)
|
||||
# 处理 session 命令
|
||||
if content.startswith("/session"):
|
||||
await self._handle_session_command(msg, content)
|
||||
return
|
||||
|
||||
# 获取会话(通过别名解析,支持跨 Channel 统一会话)
|
||||
session_alias = f"{self.channel.name}:{chat_id}"
|
||||
session = self.session_manager.get_or_create(session_alias)
|
||||
|
||||
# 从 session 配置获取设置
|
||||
session_config = session.config
|
||||
max_history = session_config.max_history or 30
|
||||
|
||||
# 选择模型(优先使用 session 绑定的模型)
|
||||
chat_model = self.chat
|
||||
if session_config.model_id is not None and session_config.model_id in self.chat_models:
|
||||
chat_model = self.chat_models[session_config.model_id]
|
||||
self.LOG.debug(f"使用 session 绑定的模型: {session_config.model_id}")
|
||||
|
||||
# 构建用户消息
|
||||
current_time = time_mod.strftime("%H:%M", time_mod.localtime())
|
||||
@@ -261,16 +281,19 @@ class BubblesBot:
|
||||
"- 日常闲聊、观点讨论、情感交流 → 直接回复,不需要调用任何工具\n"
|
||||
)
|
||||
|
||||
# 获取人设
|
||||
persona_text = None
|
||||
if self.persona_manager:
|
||||
# 获取人设(优先使用 session 绑定的,其次从 persona_manager)
|
||||
persona_text = session_config.persona
|
||||
if not persona_text and self.persona_manager:
|
||||
try:
|
||||
persona_text = self.persona_manager.get_persona(chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if persona_text:
|
||||
system_prompt = build_persona_system_prompt(self.chat, persona_text) + tool_guidance
|
||||
# 获取 system prompt(优先使用 session 绑定的)
|
||||
if session_config.system_prompt:
|
||||
system_prompt = session_config.system_prompt + tool_guidance
|
||||
elif persona_text:
|
||||
system_prompt = build_persona_system_prompt(chat_model, persona_text) + tool_guidance
|
||||
else:
|
||||
system_prompt = tool_guidance
|
||||
|
||||
@@ -283,7 +306,7 @@ class BubblesBot:
|
||||
messages.append({"role": "system", "content": f"Current time is: {now_time}"})
|
||||
|
||||
# 添加历史消息
|
||||
history = session.get_history(30)
|
||||
history = session.get_history(max_history)
|
||||
for hist_msg in history:
|
||||
role = hist_msg.get("role", "user")
|
||||
hist_content = hist_msg.get("content", "")
|
||||
@@ -320,14 +343,13 @@ class BubblesBot:
|
||||
robot=self,
|
||||
logger=self.LOG,
|
||||
config=self.config,
|
||||
specific_max_history=30,
|
||||
specific_max_history=max_history,
|
||||
persona=persona_text,
|
||||
_send_text_func=send_func,
|
||||
)
|
||||
|
||||
# 执行 Agent Loop
|
||||
try:
|
||||
chat_model = self.chat
|
||||
if not chat_model:
|
||||
await self.channel.send_text("抱歉,没有可用的 AI 模型。", chat_id)
|
||||
return
|
||||
@@ -358,6 +380,103 @@ class BubblesBot:
|
||||
self.LOG.error(f"Agent Loop 执行失败: {e}", exc_info=True)
|
||||
await self.channel.send_text("抱歉,处理消息时出错了。", chat_id)
|
||||
|
||||
async def _handle_session_command(self, msg: Message, content: str) -> None:
|
||||
"""处理 session 相关命令
|
||||
|
||||
命令格式:
|
||||
/session bind <session_key> - 绑定当前会话到指定 session
|
||||
/session unbind - 解除当前会话的绑定
|
||||
/session info - 查看当前 session 信息
|
||||
/session list - 列出所有 session
|
||||
/session model <model_id> - 设置当前 session 的模型
|
||||
/session clear - 清空当前 session 的消息历史
|
||||
"""
|
||||
chat_id = msg.get_chat_id()
|
||||
alias = f"{self.channel.name}:{chat_id}"
|
||||
parts = content.split()
|
||||
|
||||
if len(parts) < 2:
|
||||
await self.channel.send_text(
|
||||
"用法:\n"
|
||||
" /session bind <key> - 绑定到 session\n"
|
||||
" /session unbind - 解除绑定\n"
|
||||
" /session info - 查看信息\n"
|
||||
" /session list - 列出所有\n"
|
||||
" /session model <id> - 设置模型\n"
|
||||
" /session clear - 清空历史",
|
||||
chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
cmd = parts[1].lower()
|
||||
|
||||
if cmd == "bind" and len(parts) >= 3:
|
||||
session_key = parts[2]
|
||||
session = self.session_manager.bind(session_key, alias)
|
||||
await self.channel.send_text(
|
||||
f"已绑定到 session: {session_key}\n"
|
||||
f"别名: {', '.join(session.aliases)}",
|
||||
chat_id,
|
||||
)
|
||||
|
||||
elif cmd == "unbind":
|
||||
if self.session_manager.unbind(alias):
|
||||
await self.channel.send_text("已解除绑定", chat_id)
|
||||
else:
|
||||
await self.channel.send_text("当前会话未绑定到任何 session", chat_id)
|
||||
|
||||
elif cmd == "info":
|
||||
session = self.session_manager.get(alias)
|
||||
if session:
|
||||
model_name = "默认"
|
||||
if session.config.model_id is not None:
|
||||
model = self.chat_models.get(session.config.model_id)
|
||||
model_name = model.__class__.__name__ if model else f"ID:{session.config.model_id}"
|
||||
|
||||
info = (
|
||||
f"Session Key: {session.key}\n"
|
||||
f"别名: {', '.join(session.aliases) or '无'}\n"
|
||||
f"模型: {model_name}\n"
|
||||
f"历史限制: {session.config.max_history}\n"
|
||||
f"消息数: {len(session.messages)}\n"
|
||||
f"人设: {'已设置' if session.config.persona else '未设置'}"
|
||||
)
|
||||
await self.channel.send_text(info, chat_id)
|
||||
else:
|
||||
await self.channel.send_text("当前会话未创建 session", chat_id)
|
||||
|
||||
elif cmd == "list":
|
||||
sessions = self.session_manager.list_sessions()
|
||||
if sessions:
|
||||
lines = ["所有 Session:"]
|
||||
for s in sessions:
|
||||
aliases = ", ".join(s["aliases"]) if s["aliases"] else "无别名"
|
||||
lines.append(f" {s['key']} ({aliases}) - {s['message_count']} 条消息")
|
||||
await self.channel.send_text("\n".join(lines), chat_id)
|
||||
else:
|
||||
await self.channel.send_text("暂无 session", chat_id)
|
||||
|
||||
elif cmd == "model" and len(parts) >= 3:
|
||||
try:
|
||||
model_id = int(parts[2])
|
||||
if model_id in self.chat_models:
|
||||
self.session_manager.set_config(alias, model_id=model_id)
|
||||
model_name = self.chat_models[model_id].__class__.__name__
|
||||
await self.channel.send_text(f"已设置模型: {model_name}", chat_id)
|
||||
else:
|
||||
available = ", ".join(str(k) for k in self.chat_models.keys())
|
||||
await self.channel.send_text(f"无效的模型 ID,可用: {available}", chat_id)
|
||||
except ValueError:
|
||||
await self.channel.send_text("模型 ID 必须是数字", chat_id)
|
||||
|
||||
elif cmd == "clear":
|
||||
session = self.session_manager.get_or_create(alias)
|
||||
session.clear()
|
||||
await self.channel.send_text("已清空消息历史", chat_id)
|
||||
|
||||
else:
|
||||
await self.channel.send_text(f"未知命令: {cmd}", chat_id)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""清理资源"""
|
||||
self.LOG.info("正在清理 BubblesBot 资源...")
|
||||
|
||||
43
robot.py
43
robot.py
@@ -297,8 +297,13 @@ class Robot(Job):
|
||||
# 初始化 Agent Loop 系统
|
||||
self.tool_registry = create_default_registry()
|
||||
self.agent_loop = AgentLoop(self.tool_registry, max_iterations=20)
|
||||
self.session_manager = SessionManager(self.message_summary, self.wxid)
|
||||
self.session_manager = SessionManager(
|
||||
message_summary=self.message_summary,
|
||||
bot_id=self.wxid,
|
||||
db_path=db_path,
|
||||
)
|
||||
self.LOG.info(f"Agent Loop 系统已初始化,工具: {self.tool_registry.get_tool_names()}")
|
||||
self.LOG.info(f"Session 管理器已初始化,已加载 {len(self.session_manager._cache)} 个 session")
|
||||
|
||||
@staticmethod
|
||||
def value_check(args: dict) -> bool:
|
||||
@@ -659,8 +664,23 @@ class Robot(Job):
|
||||
reasoning_requested = bool(getattr(ctx, 'reasoning_requested', False)) or force_reasoning
|
||||
is_auto_random_reply = bool(getattr(ctx, 'auto_random_reply', False))
|
||||
|
||||
# 选择模型
|
||||
# 获取或创建会话(通过别名解析,支持跨 Channel 统一会话)
|
||||
chat_id = ctx.get_receiver()
|
||||
session_alias = f"wechat:{chat_id}"
|
||||
specific_max_history = getattr(ctx, 'specific_max_history', 30) or 30
|
||||
session = self.session_manager.get_or_create(session_alias, max_history=specific_max_history)
|
||||
|
||||
# 从 session 配置获取设置
|
||||
session_config = session.config
|
||||
if session_config.max_history:
|
||||
specific_max_history = session_config.max_history
|
||||
|
||||
# 选择模型(优先使用 session 绑定的模型,其次是上下文指定的)
|
||||
chat_model = getattr(ctx, 'chat', None) or self.chat
|
||||
if session_config.model_id is not None and session_config.model_id in self.chat_models:
|
||||
chat_model = self.chat_models[session_config.model_id]
|
||||
self.LOG.debug(f"使用 session 绑定的模型: {session_config.model_id}")
|
||||
|
||||
if reasoning_requested:
|
||||
if force_reasoning:
|
||||
self.LOG.info("群配置了 force_reasoning,将使用推理模型。")
|
||||
@@ -678,16 +698,6 @@ class Robot(Job):
|
||||
await self.send_text_async("抱歉,我现在无法进行对话。", ctx.get_receiver())
|
||||
return False
|
||||
|
||||
# 获取历史消息限制
|
||||
specific_max_history = getattr(ctx, 'specific_max_history', 30)
|
||||
if specific_max_history is None:
|
||||
specific_max_history = 30
|
||||
|
||||
# 获取或创建会话
|
||||
chat_id = ctx.get_receiver()
|
||||
session_key = f"wechat:{chat_id}"
|
||||
session = self.session_manager.get_or_create(session_key, max_history=specific_max_history)
|
||||
|
||||
# 构建用户消息
|
||||
sender_name = ctx.sender_name
|
||||
content = ctx.text
|
||||
@@ -723,8 +733,8 @@ class Robot(Job):
|
||||
"请只针对该用户进行回复。"
|
||||
)
|
||||
|
||||
# 构建系统提示
|
||||
persona_text = getattr(ctx, 'persona', None)
|
||||
# 构建系统提示(优先使用 session 配置)
|
||||
persona_text = session_config.persona or getattr(ctx, 'persona', None)
|
||||
tool_guidance = ""
|
||||
if not is_auto_random_reply:
|
||||
tool_guidance = (
|
||||
@@ -737,7 +747,10 @@ class Robot(Job):
|
||||
"你可以在一次对话中多次调用工具。"
|
||||
)
|
||||
|
||||
if persona_text:
|
||||
# 优先使用 session 绑定的 system_prompt
|
||||
if session_config.system_prompt:
|
||||
system_prompt = session_config.system_prompt + tool_guidance
|
||||
elif persona_text:
|
||||
try:
|
||||
base_prompt = build_persona_system_prompt(chat_model, persona_text)
|
||||
system_prompt = base_prompt + tool_guidance if base_prompt else tool_guidance or None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# session/__init__.py
|
||||
from .manager import Session, SessionManager
|
||||
from .manager import Session, SessionManager, SessionConfig
|
||||
|
||||
__all__ = ["Session", "SessionManager"]
|
||||
__all__ = ["Session", "SessionManager", "SessionConfig"]
|
||||
|
||||
@@ -1,23 +1,59 @@
|
||||
# session/manager.py
|
||||
"""会话管理器"""
|
||||
"""增强版会话管理器 - 支持跨 Channel 统一会话"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import sqlite3
|
||||
import logging
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from function.func_summary import MessageSummary
|
||||
|
||||
logger = logging.getLogger("SessionManager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConfig:
|
||||
"""Session 配置 - 绑定到会话的设置"""
|
||||
|
||||
model_id: int | None = None # 绑定的模型 ID
|
||||
system_prompt: str | None = None # 自定义 system prompt
|
||||
persona: str | None = None # 人设文本
|
||||
max_history: int = 30 # 历史消息限制
|
||||
extra: dict = field(default_factory=dict) # 扩展配置
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"model_id": self.model_id,
|
||||
"system_prompt": self.system_prompt,
|
||||
"persona": self.persona,
|
||||
"max_history": self.max_history,
|
||||
"extra": self.extra,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "SessionConfig":
|
||||
return cls(
|
||||
model_id=data.get("model_id"),
|
||||
system_prompt=data.get("system_prompt"),
|
||||
persona=data.get("persona"),
|
||||
max_history=data.get("max_history", 30),
|
||||
extra=data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""会话对象 - 管理单个对话的状态"""
|
||||
"""增强版会话对象 - 管理单个对话的完整状态"""
|
||||
|
||||
key: str # "wechat:{chat_id}"
|
||||
key: str # 统一会话 key (如 "user:john" 或 "group:test")
|
||||
config: SessionConfig = field(default_factory=SessionConfig)
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
aliases: set[str] = field(default_factory=set) # 别名集合 (如 {"wechat:wxid_xxx", "local:john"})
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs) -> None:
|
||||
"""添加消息"""
|
||||
@@ -57,49 +93,158 @@ class Session:
|
||||
)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def get_history(self, max_messages: int = 30) -> list[dict]:
|
||||
def get_history(self, max_messages: int | None = None) -> list[dict]:
|
||||
"""获取最近的消息历史"""
|
||||
return self.messages[-max_messages:]
|
||||
limit = max_messages or self.config.max_history
|
||||
return self.messages[-limit:]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空消息"""
|
||||
self.messages.clear()
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def bind_alias(self, alias: str) -> None:
|
||||
"""绑定别名"""
|
||||
self.aliases.add(alias)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def unbind_alias(self, alias: str) -> None:
|
||||
"""解绑别名"""
|
||||
self.aliases.discard(alias)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器 - 与 MessageSummary 协作加载历史"""
|
||||
"""增强版会话管理器
|
||||
|
||||
def __init__(self, message_summary: "MessageSummary", bot_wxid: str):
|
||||
支持:
|
||||
- 跨 Channel 统一会话(通过别名映射)
|
||||
- Session 配置持久化
|
||||
- 绑定模型/人设/system prompt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_summary: "MessageSummary | None" = None,
|
||||
bot_id: str = "",
|
||||
db_path: str = "data/message_history.db",
|
||||
):
|
||||
self.message_summary = message_summary
|
||||
self.bot_wxid = bot_wxid
|
||||
self._cache: dict[str, Session] = {}
|
||||
self.bot_id = bot_id
|
||||
self.db_path = db_path
|
||||
self._cache: dict[str, Session] = {} # key -> Session
|
||||
self._alias_map: dict[str, str] = {} # alias -> key
|
||||
self._init_db()
|
||||
self._load_sessions()
|
||||
|
||||
def get_or_create(self, key: str, max_history: int = 100) -> Session:
|
||||
def _init_db(self) -> None:
|
||||
"""初始化数据库表"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
key TEXT PRIMARY KEY,
|
||||
config TEXT NOT NULL DEFAULT '{}',
|
||||
aliases TEXT NOT NULL DEFAULT '[]',
|
||||
created_at TEXT,
|
||||
updated_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 session 表失败: {e}")
|
||||
|
||||
def _load_sessions(self) -> None:
|
||||
"""从数据库加载所有 session 配置"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute("SELECT * FROM sessions")
|
||||
for row in cursor:
|
||||
key = row["key"]
|
||||
config_data = json.loads(row["config"] or "{}")
|
||||
aliases_data = json.loads(row["aliases"] or "[]")
|
||||
|
||||
session = Session(
|
||||
key=key,
|
||||
config=SessionConfig.from_dict(config_data),
|
||||
aliases=set(aliases_data),
|
||||
)
|
||||
self._cache[key] = session
|
||||
|
||||
# 建立别名映射
|
||||
for alias in session.aliases:
|
||||
self._alias_map[alias] = key
|
||||
|
||||
logger.info(f"已加载 {len(self._cache)} 个 session 配置")
|
||||
except Exception as e:
|
||||
logger.error(f"加载 session 失败: {e}")
|
||||
|
||||
def _save_session(self, session: Session) -> None:
|
||||
"""保存 session 到数据库"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO sessions (key, config, aliases, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
session.key,
|
||||
json.dumps(session.config.to_dict(), ensure_ascii=False),
|
||||
json.dumps(list(session.aliases), ensure_ascii=False),
|
||||
session.created_at.isoformat(),
|
||||
session.updated_at.isoformat(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存 session 失败: {e}")
|
||||
|
||||
def resolve_key(self, alias: str) -> str:
|
||||
"""解析别名到统一 key
|
||||
|
||||
如果别名已绑定到某个 session,返回该 session 的 key
|
||||
否则返回别名本身作为 key
|
||||
"""
|
||||
return self._alias_map.get(alias, alias)
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
key_or_alias: str,
|
||||
max_history: int = 30,
|
||||
load_history: bool = True,
|
||||
) -> Session:
|
||||
"""获取或创建会话
|
||||
|
||||
Args:
|
||||
key: 会话标识,格式为 "wechat:{chat_id}"
|
||||
key_or_alias: 会话标识或别名(如 "wechat:wxid_xxx" 或 "user:john")
|
||||
max_history: 从数据库加载的最大历史消息数
|
||||
load_history: 是否从 MessageSummary 加载历史
|
||||
|
||||
Returns:
|
||||
Session 对象
|
||||
"""
|
||||
# 解析别名
|
||||
key = self.resolve_key(key_or_alias)
|
||||
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
|
||||
# 创建新 session
|
||||
session = Session(key=key)
|
||||
session.config.max_history = max_history
|
||||
|
||||
# 从 SQLite 加载历史
|
||||
chat_id = key.split(":", 1)[1] if ":" in key else key
|
||||
if self.message_summary:
|
||||
# 如果输入是别名且不等于 key,自动绑定
|
||||
if key_or_alias != key:
|
||||
session.aliases.add(key_or_alias)
|
||||
|
||||
# 从 MessageSummary 加载历史消息
|
||||
if load_history and self.message_summary:
|
||||
chat_id = key.split(":", 1)[1] if ":" in key else key
|
||||
history = self.message_summary.get_messages(chat_id)
|
||||
for msg in history[-max_history:]:
|
||||
role = (
|
||||
"assistant"
|
||||
if msg.get("sender_wxid") == self.bot_wxid
|
||||
else "user"
|
||||
)
|
||||
role = "assistant" if msg.get("sender_wxid") == self.bot_id else "user"
|
||||
content = msg.get("content", "")
|
||||
if content:
|
||||
session.messages.append(
|
||||
@@ -114,14 +259,134 @@ class SessionManager:
|
||||
self._cache[key] = session
|
||||
return session
|
||||
|
||||
def get(self, key: str) -> Session | None:
|
||||
def get(self, key_or_alias: str) -> Session | None:
|
||||
"""获取已存在的会话"""
|
||||
key = self.resolve_key(key_or_alias)
|
||||
return self._cache.get(key)
|
||||
|
||||
def remove(self, key: str) -> None:
|
||||
"""移除会话"""
|
||||
self._cache.pop(key, None)
|
||||
def bind(self, session_key: str, alias: str) -> Session:
|
||||
"""将别名绑定到指定 session
|
||||
|
||||
Args:
|
||||
session_key: 目标 session 的 key(如 "user:john")
|
||||
alias: 要绑定的别名(如 "wechat:wxid_xxx")
|
||||
|
||||
Returns:
|
||||
绑定后的 Session 对象
|
||||
"""
|
||||
# 如果别名已绑定到其他 session,先解绑
|
||||
if alias in self._alias_map:
|
||||
old_key = self._alias_map[alias]
|
||||
if old_key != session_key and old_key in self._cache:
|
||||
self._cache[old_key].unbind_alias(alias)
|
||||
self._save_session(self._cache[old_key])
|
||||
|
||||
# 获取或创建目标 session
|
||||
session = self.get_or_create(session_key, load_history=False)
|
||||
session.bind_alias(alias)
|
||||
self._alias_map[alias] = session_key
|
||||
self._save_session(session)
|
||||
|
||||
logger.info(f"已绑定 {alias} -> {session_key}")
|
||||
return session
|
||||
|
||||
def unbind(self, alias: str) -> bool:
|
||||
"""解除别名绑定
|
||||
|
||||
Returns:
|
||||
是否成功解绑
|
||||
"""
|
||||
if alias not in self._alias_map:
|
||||
return False
|
||||
|
||||
key = self._alias_map.pop(alias)
|
||||
if key in self._cache:
|
||||
self._cache[key].unbind_alias(alias)
|
||||
self._save_session(self._cache[key])
|
||||
|
||||
logger.info(f"已解绑 {alias}")
|
||||
return True
|
||||
|
||||
def set_config(
|
||||
self,
|
||||
key_or_alias: str,
|
||||
model_id: int | None = None,
|
||||
system_prompt: str | None = None,
|
||||
persona: str | None = None,
|
||||
max_history: int | None = None,
|
||||
**extra,
|
||||
) -> Session:
|
||||
"""设置 session 配置
|
||||
|
||||
Args:
|
||||
key_or_alias: 会话标识或别名
|
||||
model_id: 绑定的模型 ID
|
||||
system_prompt: 自定义 system prompt
|
||||
persona: 人设文本
|
||||
max_history: 历史消息限制
|
||||
**extra: 扩展配置
|
||||
|
||||
Returns:
|
||||
更新后的 Session 对象
|
||||
"""
|
||||
session = self.get_or_create(key_or_alias, load_history=False)
|
||||
|
||||
if model_id is not None:
|
||||
session.config.model_id = model_id
|
||||
if system_prompt is not None:
|
||||
session.config.system_prompt = system_prompt
|
||||
if persona is not None:
|
||||
session.config.persona = persona
|
||||
if max_history is not None:
|
||||
session.config.max_history = max_history
|
||||
if extra:
|
||||
session.config.extra.update(extra)
|
||||
|
||||
session.updated_at = datetime.now()
|
||||
self._save_session(session)
|
||||
return session
|
||||
|
||||
def remove(self, key_or_alias: str) -> bool:
|
||||
"""移除会话
|
||||
|
||||
Returns:
|
||||
是否成功移除
|
||||
"""
|
||||
key = self.resolve_key(key_or_alias)
|
||||
if key not in self._cache:
|
||||
return False
|
||||
|
||||
session = self._cache.pop(key)
|
||||
|
||||
# 清理别名映射
|
||||
for alias in session.aliases:
|
||||
self._alias_map.pop(alias, None)
|
||||
|
||||
# 从数据库删除
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("DELETE FROM sessions WHERE key = ?", (key,))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"删除 session 失败: {e}")
|
||||
|
||||
return True
|
||||
|
||||
def list_sessions(self) -> list[dict]:
|
||||
"""列出所有 session 信息"""
|
||||
result = []
|
||||
for key, session in self._cache.items():
|
||||
result.append({
|
||||
"key": key,
|
||||
"aliases": list(session.aliases),
|
||||
"model_id": session.config.model_id,
|
||||
"max_history": session.config.max_history,
|
||||
"message_count": len(session.messages),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
})
|
||||
return result
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""清空所有会话缓存"""
|
||||
"""清空所有会话缓存(不删除数据库)"""
|
||||
self._cache.clear()
|
||||
self._alias_map.clear()
|
||||
|
||||
Reference in New Issue
Block a user