重构 ChatGPT 和 DeepSeek 类的初始化逻辑,新增最大历史消息数配置,优化消息处理逻辑,移除冗余代码,提升代码可读性。同时更新消息记录逻辑,确保消息记录功能正常。

This commit is contained in:
Zylan
2025-04-24 13:09:56 +08:00
parent 61cc9c89cc
commit 8227802eb0
5 changed files with 15 additions and 81 deletions

View File

@@ -19,39 +19,30 @@ except ImportError:
class ChatGPT():
# ---- 修改 __init__ ----
def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
key = conf.get("key")
api = conf.get("api")
proxy = conf.get("proxy")
prompt = conf.get("prompt")
self.model = conf.get("model", "gpt-3.5-turbo")
# ---- 新增:读取最大历史消息数配置 ----
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置默认10条
# ---- 新增结束 ----
self.LOG = logging.getLogger("ChatGPT")
# ---- 存储传入的实例和wxid ----
# 存储传入的实例和wxid
self.message_summary = message_summary_instance
self.bot_wxid = bot_wxid
if not self.message_summary:
self.LOG.warning("MessageSummary 实例未提供给 ChatGPT上下文功能将不可用")
if not self.bot_wxid:
self.LOG.warning("bot_wxid 未提供给 ChatGPT可能无法正确识别机器人自身消息")
# ---- 存储结束 ----
if proxy:
self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
else:
self.client = OpenAI(api_key=key, base_url=api)
# ---- 移除 self.conversation_list ----
# self.conversation_list = {}
# ---- 移除结束 ----
self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."} # 提供默认值
self.support_vision = self.model == "gpt-4-vision-preview" or self.model == "gpt-4o" or "-vision" in self.model
# ---- __init__ 修改结束 ----
def __repr__(self):
return 'ChatGPT'
@@ -60,18 +51,13 @@ class ChatGPT():
def value_check(conf: dict) -> bool:
# 不再检查 prompt因为可以没有默认 prompt
if conf:
# ---- 修改:也检查 max_history_messages (虽然有默认值) ----
# 也检查 max_history_messages (虽然有默认值)
if conf.get("key") and conf.get("api"): # and conf.get("max_history_messages") is not None: # 如果需要强制配置
return True
return False
# ---- 修改 get_answer ----
def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
# ---- 移除 #清除对话 逻辑 ----
# ... (代码已移除) ...
# ---- 移除结束 ----
# ---- 获取并格式化数据库历史记录 ----
# 获取并格式化数据库历史记录
api_messages = []
# 1. 添加系统提示
@@ -89,12 +75,11 @@ class ChatGPT():
if self.message_summary and self.bot_wxid:
history = self.message_summary.get_messages(wxid)
# ---- 新增:限制历史消息数量 ----
# -限制历史消息数量
if self.max_history_messages is not None and self.max_history_messages > 0:
history = history[-self.max_history_messages:] # 取最新的 N 条
elif self.max_history_messages == 0: # 如果设置为0则不包含历史
history = []
# ---- 新增结束 ----
for msg in history:
role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
@@ -107,11 +92,10 @@ class ChatGPT():
# 3. 添加当前用户问题
if question: # 确保问题非空
api_messages.append({"role": "user", "content": question})
# ---- 获取和格式化结束 ----
rsp = ""
try:
# ---- 使用格式化后的 api_messages ----
# 使用格式化后的 api_messages
params = {
"model": self.model,
"messages": api_messages # 使用从数据库构建的消息列表
@@ -126,10 +110,6 @@ class ChatGPT():
rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
rsp = rsp.replace("\n\n", "\n")
# ---- 移除 updateMessage 调用 ----
# ... (代码已移除) ...
# ---- 移除结束 ----
except AuthenticationError:
self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
rsp = "API认证失败"
@@ -144,7 +124,6 @@ class ChatGPT():
rsp = "发生未知错误"
return rsp
# ---- get_answer 修改结束 ----
def encode_image_to_base64(self, image_path: str) -> str:
"""将图片文件转换为Base64编码
@@ -231,10 +210,6 @@ class ChatGPT():
self.LOG.error(f"分析图片时发生未知错误:{str(e0)}")
return f"处理图片时出错:{str(e0)}"
# ---- 移除 updateMessage ----
# ... (代码已移除) ...
# ---- 移除结束 ----
if __name__ == "__main__":
# --- 测试代码需要调整 ---

View File

@@ -16,42 +16,30 @@ except ImportError:
MessageSummary = object
class DeepSeek():
# ---- 修改 __init__ ----
def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
key = conf.get("key")
api = conf.get("api", "https://api.deepseek.com")
proxy = conf.get("proxy")
prompt = conf.get("prompt")
self.model = conf.get("model", "deepseek-chat")
# ---- 新增:读取最大历史消息数配置 ----
# 读取最大历史消息数配置
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置默认10条
# ---- 新增结束 ----
self.LOG = logging.getLogger("DeepSeek")
# ---- 存储传入的实例和wxid ----
# 存储传入的实例和wxid
self.message_summary = message_summary_instance
self.bot_wxid = bot_wxid
if not self.message_summary:
self.LOG.warning("MessageSummary 实例未提供给 DeepSeek上下文功能将不可用")
if not self.bot_wxid:
self.LOG.warning("bot_wxid 未提供给 DeepSeek可能无法正确识别机器人自身消息")
# ---- 存储结束 ----
# ---- 移除思维链相关逻辑 ----
# ... (代码已移除) ...
# ---- 移除结束 ----
if proxy:
self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
else:
self.client = OpenAI(api_key=key, base_url=api)
# ---- 移除 self.conversation_list ----
# ... (代码已移除) ...
# ---- 移除结束 ----
self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."} # 提供默认值
# ---- __init__ 修改结束 ----
def __repr__(self):
return 'DeepSeek'
@@ -59,18 +47,13 @@ class DeepSeek():
@staticmethod
def value_check(conf: dict) -> bool:
if conf:
# ---- 修改:也检查 max_history_messages (虽然有默认值) ----
# 也检查 max_history_messages (虽然有默认值)
if conf.get("key"): # and conf.get("max_history_messages") is not None: # 如果需要强制配置
return True
return False
# ---- 修改 get_answer ----
def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
# ---- 移除 #清除对话 和 思维链命令 ----
# ... (代码已移除) ...
# ---- 移除结束 ----
# ---- 获取并格式化数据库历史记录 ----
# 获取并格式化数据库历史记录
api_messages = []
# 1. 添加系统提示
@@ -88,12 +71,11 @@ class DeepSeek():
if self.message_summary and self.bot_wxid:
history = self.message_summary.get_messages(wxid)
# ---- 新增:限制历史消息数量 ----
# 限制历史消息数量
if self.max_history_messages is not None and self.max_history_messages > 0:
history = history[-self.max_history_messages:] # 取最新的 N 条
elif self.max_history_messages == 0: # 如果设置为0则不包含历史
history = []
# ---- 新增结束 ----
for msg in history:
role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
@@ -106,21 +88,16 @@ class DeepSeek():
# 3. 添加当前用户问题
if question:
api_messages.append({"role": "user", "content": question})
# ---- 获取和格式化结束 ----
try:
# ---- 使用格式化后的 api_messages ----
# 使用格式化后的 api_messages
response = self.client.chat.completions.create(
model=self.model,
messages=api_messages, # 使用构建的消息列表
stream=False
)
# ---- 移除思维链特殊处理和本地历史更新 ----
# ... (代码已移除) ...
final_response = response.choices[0].message.content
# ... (代码已移除) ...
# ---- 移除结束 ----
return final_response
@@ -130,7 +107,6 @@ class DeepSeek():
except Exception as e0:
self.LOG.error(f"发生未知错误:{str(e0)}")
return "抱歉,处理您的请求时出现了错误"
# ---- get_answer 修改结束 ----
if __name__ == "__main__":

View File

@@ -268,7 +268,7 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
ctx.send_text("抱歉,我现在无法进行对话。")
return False
# ---- 处理引用图片情况 ----
# 处理引用图片情况
if getattr(ctx, 'is_quoted_image', False):
ctx.logger.info("检测到引用图片消息,尝试处理图片内容...")
@@ -342,7 +342,6 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
ctx.logger.error(f"处理引用图片过程中出错: {e}")
ctx.send_text(f"处理图片时发生错误: {str(e)}")
return True # 已处理,即使出错也不执行后续普通文本处理
# ---- 引用图片处理结束 ----
# 获取消息内容
content = ctx.text
@@ -684,7 +683,7 @@ def handle_reminder(ctx: 'MessageContext', match: Optional[Match]) -> bool:
ctx.send_text("🤔 嗯... 我好像没太明白您想设置什么提醒,可以换种方式再说一次吗?", at_list)
return True
# ---- 批量处理提醒 ----
# 批量处理提醒
results = [] # 用于存储每个提醒的处理结果
roomid = ctx.msg.roomid if ctx.is_group else None
@@ -739,7 +738,7 @@ def handle_reminder(ctx: 'MessageContext', match: Optional[Match]) -> bool:
results.append({"label": reminder_label, "success": False, "error": validation_error, "data": data})
if ctx.logger: ctx.logger.warning(f"提醒数据验证失败 ({reminder_label}): {validation_error} - Data: {data}")
# ---- 构建汇总反馈消息 ----
# 构建汇总反馈消息
reply_parts = []
successful_count = sum(1 for res in results if res["success"])
failed_count = len(results) - successful_count

View File

@@ -39,7 +39,6 @@ class MessageSummary:
self.cursor = self.conn.cursor()
self.LOG.info(f"已连接到 SQLite 数据库: {self.db_path}")
# ---- 修改数据库表结构 ----
# 检查并添加 sender_wxid 列 (如果不存在)
self.cursor.execute("PRAGMA table_info(messages)")
columns = [col[1] for col in self.cursor.fetchall()]
@@ -67,7 +66,6 @@ class MessageSummary:
timestamp_str TEXT NOT NULL -- 存储完整时间格式 YYYY-MM-DD HH:MM:SS
)
""")
# ---- 数据库表结构修改结束 ----
self.cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_chat_time ON messages (chat_id, timestamp_float)
@@ -96,7 +94,6 @@ class MessageSummary:
except sqlite3.Error as e:
self.LOG.error(f"关闭数据库连接时出错: {e}")
# ---- 修改 record_message ----
def record_message(self, chat_id, sender_name, sender_wxid, content, timestamp=None):
"""记录单条消息到数据库
@@ -110,7 +107,6 @@ class MessageSummary:
try:
current_time_float = time.time()
# ---- 修改时间格式 ----
if not timestamp:
# 默认使用完整时间格式
timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(current_time_float))
@@ -132,7 +128,6 @@ class MessageSummary:
INSERT INTO messages (chat_id, sender, sender_wxid, content, timestamp_float, timestamp_str)
VALUES (?, ?, ?, ?, ?, ?)
""", (chat_id, sender_name, sender_wxid, content, current_time_float, timestamp_str))
# ---- 时间格式和插入修改结束 ----
# 删除超出 max_history 的旧消息
self.cursor.execute("""
@@ -154,7 +149,6 @@ class MessageSummary:
self.conn.rollback()
except:
pass
# ---- record_message 修改结束 ----
def clear_message_history(self, chat_id):
"""清除指定聊天的消息历史记录
@@ -194,7 +188,6 @@ class MessageSummary:
self.LOG.error(f"获取消息数量时出错 (chat_id={chat_id}): {e}")
return 0
# ---- 修改 get_messages ----
def get_messages(self, chat_id):
"""获取指定聊天的所有消息 (按时间升序)包含发送者wxid和完整时间戳
@@ -230,7 +223,6 @@ class MessageSummary:
self.LOG.error(f"获取消息列表时出错 (chat_id={chat_id}): {e}")
return messages
# ---- get_messages 修改结束 ----
def _basic_summarize(self, messages):
"""基本的消息总结逻辑不使用AI
@@ -322,7 +314,6 @@ class MessageSummary:
else:
return self._basic_summarize(messages)
# ---- 修改 process_message_from_wxmsg ----
def process_message_from_wxmsg(self, msg, wcf, all_contacts, bot_wxid=None):
"""从微信消息对象中处理并记录与总结相关的文本消息
记录所有群聊和私聊的文本(1)和App/卡片(49)消息。
@@ -342,13 +333,11 @@ class MessageSummary:
self.LOG.warning(f"无法确定消息的chat_id (msg.id={msg.id}), 跳过记录")
return
# ---- 获取 sender_wxid ----
sender_wxid = msg.sender
if not sender_wxid:
# 理论上不应发生,但做个防护
self.LOG.error(f"消息 (id={msg.id}) 缺少 sender wxid无法记录")
return
# ---- 获取 sender_wxid 结束 ----
# 确定发送者名称 (逻辑不变)
sender_name = ""
@@ -467,9 +456,6 @@ class MessageSummary:
# 获取当前时间字符串 (使用完整格式)
current_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# ---- 修改记录调用 ----
self.LOG.debug(f"记录消息 (来源: {source_info}, 类型: {'群聊' if msg.from_group() else '私聊'}): '[{current_time_str}]{sender_name}({sender_wxid}): {content_to_record}' (来自 msg.id={msg.id})")
# 调用 record_message 时传入 sender_wxid
self.record_message(chat_id, sender_name, sender_wxid, content_to_record, current_time_str)
# ---- 记录调用修改结束 ----
# ---- process_message_from_wxmsg 修改结束 ----

View File

@@ -288,7 +288,6 @@ class Robot(Job):
self.LOG.info(f"To {receiver}:\n{ats}\n{msg}")
self.wcf.send_text(full_msg_content, receiver, at_list)
# ---- 修改记录逻辑 ----
if self.message_summary: # 检查 message_summary 是否初始化成功
# 确定机器人的名字
robot_name = self.allContacts.get(self.wxid, "机器人")
@@ -303,7 +302,6 @@ class Robot(Job):
self.LOG.debug(f"已记录机器人发送的消息到 {receiver}")
else:
self.LOG.warning("MessageSummary 未初始化,无法记录发送的消息")
# ---- 记录逻辑修改结束 ----
except Exception as e:
self.LOG.error(f"发送消息失败: {e}")