mirror of
https://github.com/Zippland/Bubbles.git
synced 2026-01-19 01:21:15 +08:00
重构 ChatGPT 和 DeepSeek 类的初始化逻辑,新增最大历史消息数配置,优化消息处理逻辑,移除冗余代码,提升代码可读性。同时更新消息记录逻辑,确保消息记录功能正常。
This commit is contained in:
@@ -19,39 +19,30 @@ except ImportError:
|
||||
|
||||
|
||||
class ChatGPT():
|
||||
# ---- 修改 __init__ ----
|
||||
def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
|
||||
key = conf.get("key")
|
||||
api = conf.get("api")
|
||||
proxy = conf.get("proxy")
|
||||
prompt = conf.get("prompt")
|
||||
self.model = conf.get("model", "gpt-3.5-turbo")
|
||||
# ---- 新增:读取最大历史消息数配置 ----
|
||||
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条
|
||||
# ---- 新增结束 ----
|
||||
self.LOG = logging.getLogger("ChatGPT")
|
||||
|
||||
# ---- 存储传入的实例和wxid ----
|
||||
# 存储传入的实例和wxid
|
||||
self.message_summary = message_summary_instance
|
||||
self.bot_wxid = bot_wxid
|
||||
if not self.message_summary:
|
||||
self.LOG.warning("MessageSummary 实例未提供给 ChatGPT,上下文功能将不可用!")
|
||||
if not self.bot_wxid:
|
||||
self.LOG.warning("bot_wxid 未提供给 ChatGPT,可能无法正确识别机器人自身消息!")
|
||||
# ---- 存储结束 ----
|
||||
|
||||
if proxy:
|
||||
self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
|
||||
else:
|
||||
self.client = OpenAI(api_key=key, base_url=api)
|
||||
|
||||
# ---- 移除 self.conversation_list ----
|
||||
# self.conversation_list = {}
|
||||
# ---- 移除结束 ----
|
||||
|
||||
self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."} # 提供默认值
|
||||
self.support_vision = self.model == "gpt-4-vision-preview" or self.model == "gpt-4o" or "-vision" in self.model
|
||||
# ---- __init__ 修改结束 ----
|
||||
|
||||
def __repr__(self):
|
||||
return 'ChatGPT'
|
||||
@@ -60,18 +51,13 @@ class ChatGPT():
|
||||
def value_check(conf: dict) -> bool:
|
||||
# 不再检查 prompt,因为可以没有默认 prompt
|
||||
if conf:
|
||||
# ---- 修改:也检查 max_history_messages (虽然有默认值) ----
|
||||
# 也检查 max_history_messages (虽然有默认值)
|
||||
if conf.get("key") and conf.get("api"): # and conf.get("max_history_messages") is not None: # 如果需要强制配置
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---- 修改 get_answer ----
|
||||
def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
|
||||
# ---- 移除 #清除对话 逻辑 ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
# ---- 获取并格式化数据库历史记录 ----
|
||||
# 获取并格式化数据库历史记录
|
||||
api_messages = []
|
||||
|
||||
# 1. 添加系统提示
|
||||
@@ -89,12 +75,11 @@ class ChatGPT():
|
||||
if self.message_summary and self.bot_wxid:
|
||||
history = self.message_summary.get_messages(wxid)
|
||||
|
||||
# ---- 新增:限制历史消息数量 ----
|
||||
# -限制历史消息数量
|
||||
if self.max_history_messages is not None and self.max_history_messages > 0:
|
||||
history = history[-self.max_history_messages:] # 取最新的 N 条
|
||||
elif self.max_history_messages == 0: # 如果设置为0,则不包含历史
|
||||
history = []
|
||||
# ---- 新增结束 ----
|
||||
|
||||
for msg in history:
|
||||
role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
|
||||
@@ -107,11 +92,10 @@ class ChatGPT():
|
||||
# 3. 添加当前用户问题
|
||||
if question: # 确保问题非空
|
||||
api_messages.append({"role": "user", "content": question})
|
||||
# ---- 获取和格式化结束 ----
|
||||
|
||||
rsp = ""
|
||||
try:
|
||||
# ---- 使用格式化后的 api_messages ----
|
||||
# 使用格式化后的 api_messages
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": api_messages # 使用从数据库构建的消息列表
|
||||
@@ -126,10 +110,6 @@ class ChatGPT():
|
||||
rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
|
||||
rsp = rsp.replace("\n\n", "\n")
|
||||
|
||||
# ---- 移除 updateMessage 调用 ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
except AuthenticationError:
|
||||
self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
|
||||
rsp = "API认证失败"
|
||||
@@ -144,7 +124,6 @@ class ChatGPT():
|
||||
rsp = "发生未知错误"
|
||||
|
||||
return rsp
|
||||
# ---- get_answer 修改结束 ----
|
||||
|
||||
def encode_image_to_base64(self, image_path: str) -> str:
|
||||
"""将图片文件转换为Base64编码
|
||||
@@ -231,10 +210,6 @@ class ChatGPT():
|
||||
self.LOG.error(f"分析图片时发生未知错误:{str(e0)}")
|
||||
return f"处理图片时出错:{str(e0)}"
|
||||
|
||||
# ---- 移除 updateMessage ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# --- 测试代码需要调整 ---
|
||||
|
||||
@@ -16,42 +16,30 @@ except ImportError:
|
||||
MessageSummary = object
|
||||
|
||||
class DeepSeek():
|
||||
# ---- 修改 __init__ ----
|
||||
def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
|
||||
key = conf.get("key")
|
||||
api = conf.get("api", "https://api.deepseek.com")
|
||||
proxy = conf.get("proxy")
|
||||
prompt = conf.get("prompt")
|
||||
self.model = conf.get("model", "deepseek-chat")
|
||||
# ---- 新增:读取最大历史消息数配置 ----
|
||||
# 读取最大历史消息数配置
|
||||
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条
|
||||
# ---- 新增结束 ----
|
||||
self.LOG = logging.getLogger("DeepSeek")
|
||||
|
||||
# ---- 存储传入的实例和wxid ----
|
||||
# 存储传入的实例和wxid
|
||||
self.message_summary = message_summary_instance
|
||||
self.bot_wxid = bot_wxid
|
||||
if not self.message_summary:
|
||||
self.LOG.warning("MessageSummary 实例未提供给 DeepSeek,上下文功能将不可用!")
|
||||
if not self.bot_wxid:
|
||||
self.LOG.warning("bot_wxid 未提供给 DeepSeek,可能无法正确识别机器人自身消息!")
|
||||
# ---- 存储结束 ----
|
||||
|
||||
# ---- 移除思维链相关逻辑 ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
if proxy:
|
||||
self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
|
||||
else:
|
||||
self.client = OpenAI(api_key=key, base_url=api)
|
||||
|
||||
# ---- 移除 self.conversation_list ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."} # 提供默认值
|
||||
# ---- __init__ 修改结束 ----
|
||||
|
||||
def __repr__(self):
|
||||
return 'DeepSeek'
|
||||
@@ -59,18 +47,13 @@ class DeepSeek():
|
||||
@staticmethod
|
||||
def value_check(conf: dict) -> bool:
|
||||
if conf:
|
||||
# ---- 修改:也检查 max_history_messages (虽然有默认值) ----
|
||||
# 也检查 max_history_messages (虽然有默认值)
|
||||
if conf.get("key"): # and conf.get("max_history_messages") is not None: # 如果需要强制配置
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---- 修改 get_answer ----
|
||||
def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
|
||||
# ---- 移除 #清除对话 和 思维链命令 ----
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
# ---- 获取并格式化数据库历史记录 ----
|
||||
# 获取并格式化数据库历史记录
|
||||
api_messages = []
|
||||
|
||||
# 1. 添加系统提示
|
||||
@@ -88,12 +71,11 @@ class DeepSeek():
|
||||
if self.message_summary and self.bot_wxid:
|
||||
history = self.message_summary.get_messages(wxid)
|
||||
|
||||
# ---- 新增:限制历史消息数量 ----
|
||||
# 限制历史消息数量
|
||||
if self.max_history_messages is not None and self.max_history_messages > 0:
|
||||
history = history[-self.max_history_messages:] # 取最新的 N 条
|
||||
elif self.max_history_messages == 0: # 如果设置为0,则不包含历史
|
||||
history = []
|
||||
# ---- 新增结束 ----
|
||||
|
||||
for msg in history:
|
||||
role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
|
||||
@@ -106,21 +88,16 @@ class DeepSeek():
|
||||
# 3. 添加当前用户问题
|
||||
if question:
|
||||
api_messages.append({"role": "user", "content": question})
|
||||
# ---- 获取和格式化结束 ----
|
||||
|
||||
try:
|
||||
# ---- 使用格式化后的 api_messages ----
|
||||
# 使用格式化后的 api_messages
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages, # 使用构建的消息列表
|
||||
stream=False
|
||||
)
|
||||
|
||||
# ---- 移除思维链特殊处理和本地历史更新 ----
|
||||
# ... (代码已移除) ...
|
||||
final_response = response.choices[0].message.content
|
||||
# ... (代码已移除) ...
|
||||
# ---- 移除结束 ----
|
||||
|
||||
|
||||
return final_response
|
||||
|
||||
@@ -130,7 +107,6 @@ class DeepSeek():
|
||||
except Exception as e0:
|
||||
self.LOG.error(f"发生未知错误:{str(e0)}")
|
||||
return "抱歉,处理您的请求时出现了错误"
|
||||
# ---- get_answer 修改结束 ----
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -268,7 +268,7 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
ctx.send_text("抱歉,我现在无法进行对话。")
|
||||
return False
|
||||
|
||||
# ---- 处理引用图片情况 ----
|
||||
# 处理引用图片情况
|
||||
if getattr(ctx, 'is_quoted_image', False):
|
||||
ctx.logger.info("检测到引用图片消息,尝试处理图片内容...")
|
||||
|
||||
@@ -342,7 +342,6 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
ctx.logger.error(f"处理引用图片过程中出错: {e}")
|
||||
ctx.send_text(f"处理图片时发生错误: {str(e)}")
|
||||
return True # 已处理,即使出错也不执行后续普通文本处理
|
||||
# ---- 引用图片处理结束 ----
|
||||
|
||||
# 获取消息内容
|
||||
content = ctx.text
|
||||
@@ -684,7 +683,7 @@ def handle_reminder(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
ctx.send_text("🤔 嗯... 我好像没太明白您想设置什么提醒,可以换种方式再说一次吗?", at_list)
|
||||
return True
|
||||
|
||||
# ---- 批量处理提醒 ----
|
||||
# 批量处理提醒
|
||||
results = [] # 用于存储每个提醒的处理结果
|
||||
roomid = ctx.msg.roomid if ctx.is_group else None
|
||||
|
||||
@@ -739,7 +738,7 @@ def handle_reminder(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
results.append({"label": reminder_label, "success": False, "error": validation_error, "data": data})
|
||||
if ctx.logger: ctx.logger.warning(f"提醒数据验证失败 ({reminder_label}): {validation_error} - Data: {data}")
|
||||
|
||||
# ---- 构建汇总反馈消息 ----
|
||||
# 构建汇总反馈消息
|
||||
reply_parts = []
|
||||
successful_count = sum(1 for res in results if res["success"])
|
||||
failed_count = len(results) - successful_count
|
||||
|
||||
@@ -39,7 +39,6 @@ class MessageSummary:
|
||||
self.cursor = self.conn.cursor()
|
||||
self.LOG.info(f"已连接到 SQLite 数据库: {self.db_path}")
|
||||
|
||||
# ---- 修改数据库表结构 ----
|
||||
# 检查并添加 sender_wxid 列 (如果不存在)
|
||||
self.cursor.execute("PRAGMA table_info(messages)")
|
||||
columns = [col[1] for col in self.cursor.fetchall()]
|
||||
@@ -67,7 +66,6 @@ class MessageSummary:
|
||||
timestamp_str TEXT NOT NULL -- 存储完整时间格式 YYYY-MM-DD HH:MM:SS
|
||||
)
|
||||
""")
|
||||
# ---- 数据库表结构修改结束 ----
|
||||
|
||||
self.cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_time ON messages (chat_id, timestamp_float)
|
||||
@@ -96,7 +94,6 @@ class MessageSummary:
|
||||
except sqlite3.Error as e:
|
||||
self.LOG.error(f"关闭数据库连接时出错: {e}")
|
||||
|
||||
# ---- 修改 record_message ----
|
||||
def record_message(self, chat_id, sender_name, sender_wxid, content, timestamp=None):
|
||||
"""记录单条消息到数据库
|
||||
|
||||
@@ -110,7 +107,6 @@ class MessageSummary:
|
||||
try:
|
||||
current_time_float = time.time()
|
||||
|
||||
# ---- 修改时间格式 ----
|
||||
if not timestamp:
|
||||
# 默认使用完整时间格式
|
||||
timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(current_time_float))
|
||||
@@ -132,7 +128,6 @@ class MessageSummary:
|
||||
INSERT INTO messages (chat_id, sender, sender_wxid, content, timestamp_float, timestamp_str)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (chat_id, sender_name, sender_wxid, content, current_time_float, timestamp_str))
|
||||
# ---- 时间格式和插入修改结束 ----
|
||||
|
||||
# 删除超出 max_history 的旧消息
|
||||
self.cursor.execute("""
|
||||
@@ -154,7 +149,6 @@ class MessageSummary:
|
||||
self.conn.rollback()
|
||||
except:
|
||||
pass
|
||||
# ---- record_message 修改结束 ----
|
||||
|
||||
def clear_message_history(self, chat_id):
|
||||
"""清除指定聊天的消息历史记录
|
||||
@@ -194,7 +188,6 @@ class MessageSummary:
|
||||
self.LOG.error(f"获取消息数量时出错 (chat_id={chat_id}): {e}")
|
||||
return 0
|
||||
|
||||
# ---- 修改 get_messages ----
|
||||
def get_messages(self, chat_id):
|
||||
"""获取指定聊天的所有消息 (按时间升序),包含发送者wxid和完整时间戳
|
||||
|
||||
@@ -230,7 +223,6 @@ class MessageSummary:
|
||||
self.LOG.error(f"获取消息列表时出错 (chat_id={chat_id}): {e}")
|
||||
|
||||
return messages
|
||||
# ---- get_messages 修改结束 ----
|
||||
|
||||
def _basic_summarize(self, messages):
|
||||
"""基本的消息总结逻辑,不使用AI
|
||||
@@ -322,7 +314,6 @@ class MessageSummary:
|
||||
else:
|
||||
return self._basic_summarize(messages)
|
||||
|
||||
# ---- 修改 process_message_from_wxmsg ----
|
||||
def process_message_from_wxmsg(self, msg, wcf, all_contacts, bot_wxid=None):
|
||||
"""从微信消息对象中处理并记录与总结相关的文本消息
|
||||
记录所有群聊和私聊的文本(1)和App/卡片(49)消息。
|
||||
@@ -342,13 +333,11 @@ class MessageSummary:
|
||||
self.LOG.warning(f"无法确定消息的chat_id (msg.id={msg.id}), 跳过记录")
|
||||
return
|
||||
|
||||
# ---- 获取 sender_wxid ----
|
||||
sender_wxid = msg.sender
|
||||
if not sender_wxid:
|
||||
# 理论上不应发生,但做个防护
|
||||
self.LOG.error(f"消息 (id={msg.id}) 缺少 sender wxid,无法记录!")
|
||||
return
|
||||
# ---- 获取 sender_wxid 结束 ----
|
||||
|
||||
# 确定发送者名称 (逻辑不变)
|
||||
sender_name = ""
|
||||
@@ -467,9 +456,6 @@ class MessageSummary:
|
||||
# 获取当前时间字符串 (使用完整格式)
|
||||
current_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
# ---- 修改记录调用 ----
|
||||
self.LOG.debug(f"记录消息 (来源: {source_info}, 类型: {'群聊' if msg.from_group() else '私聊'}): '[{current_time_str}]{sender_name}({sender_wxid}): {content_to_record}' (来自 msg.id={msg.id})")
|
||||
# 调用 record_message 时传入 sender_wxid
|
||||
self.record_message(chat_id, sender_name, sender_wxid, content_to_record, current_time_str)
|
||||
# ---- 记录调用修改结束 ----
|
||||
# ---- process_message_from_wxmsg 修改结束 ----
|
||||
|
||||
2
robot.py
2
robot.py
@@ -288,7 +288,6 @@ class Robot(Job):
|
||||
self.LOG.info(f"To {receiver}:\n{ats}\n{msg}")
|
||||
self.wcf.send_text(full_msg_content, receiver, at_list)
|
||||
|
||||
# ---- 修改记录逻辑 ----
|
||||
if self.message_summary: # 检查 message_summary 是否初始化成功
|
||||
# 确定机器人的名字
|
||||
robot_name = self.allContacts.get(self.wxid, "机器人")
|
||||
@@ -303,7 +302,6 @@ class Robot(Job):
|
||||
self.LOG.debug(f"已记录机器人发送的消息到 {receiver}")
|
||||
else:
|
||||
self.LOG.warning("MessageSummary 未初始化,无法记录发送的消息")
|
||||
# ---- 记录逻辑修改结束 ----
|
||||
|
||||
except Exception as e:
|
||||
self.LOG.error(f"发送消息失败: {e}")
|
||||
|
||||
Reference in New Issue
Block a user