diff --git a/ai_providers/ai_chatgpt.py b/ai_providers/ai_chatgpt.py index 7184c3a..4e0c0b2 100644 --- a/ai_providers/ai_chatgpt.py +++ b/ai_providers/ai_chatgpt.py @@ -7,6 +7,7 @@ import base64 import os from datetime import datetime import time # 引入 time 模块 +import json import httpx from openai import APIConnectionError, APIError, AuthenticationError, OpenAI @@ -25,7 +26,7 @@ class ChatGPT(): proxy = conf.get("proxy") prompt = conf.get("prompt") self.model = conf.get("model", "gpt-3.5-turbo") - self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条 + self.max_history_messages = conf.get("max_history_messages", 30) # 默认读取最近30条历史 self.LOG = logging.getLogger("ChatGPT") # 存储传入的实例和wxid @@ -56,7 +57,17 @@ class ChatGPT(): return True return False - def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str: + def get_answer( + self, + question: str, + wxid: str, + system_prompt_override=None, + specific_max_history=None, + tools=None, + tool_handler=None, + tool_choice=None, + tool_max_iterations: int = 10 + ) -> str: # 获取并格式化数据库历史记录 api_messages = [] @@ -104,37 +115,123 @@ class ChatGPT(): if question: # 确保问题非空 api_messages.append({"role": "user", "content": question}) - rsp = "" + if tools and not tool_handler: + # 如果提供了工具但没有处理器,则忽略工具以避免陷入死循环 + self.LOG.warning("tools 提供但没有 tool_handler,忽略工具定义。") + tools = None + try: - # 使用格式化后的 api_messages - params = { - "model": self.model, - "messages": api_messages # 使用从数据库构建的消息列表 - } - - # 只有非o系列模型才设置temperature - if not self.model.startswith("o"): - params["temperature"] = 0.2 - - ret = self.client.chat.completions.create(**params) - rsp = ret.choices[0].message.content - rsp = rsp[2:] if rsp.startswith("\n\n") else rsp - rsp = rsp.replace("\n\n", "\n") + response_text = self._execute_with_tools( + api_messages=api_messages, + tools=tools, + tool_handler=tool_handler, + tool_choice=tool_choice, + tool_max_iterations=tool_max_iterations + ) + return 
response_text except AuthenticationError: self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确") - rsp = "API认证失败" + return "API认证失败" except APIConnectionError: self.LOG.error("无法连接到 OpenAI API,请检查网络连接") - rsp = "网络连接错误" + return "网络连接错误" except APIError as e1: self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}") - rsp = f"API错误: {str(e1)}" + return f"API错误: {str(e1)}" except Exception as e0: - self.LOG.error(f"发生未知错误:{str(e0)}") - rsp = "发生未知错误" + self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True) + return "发生未知错误" - return rsp + def _execute_with_tools( + self, + api_messages, + tools=None, + tool_handler=None, + tool_choice=None, + tool_max_iterations: int = 10 + ) -> str: + """执行带工具调用的对话逻辑""" + iterations = 0 + params_base = {"model": self.model} + + # 只有非o系列模型才设置temperature + if not self.model.startswith("o"): + params_base["temperature"] = 0.2 + + # 确保工具参数格式正确 + runtime_tools = tools if tools and isinstance(tools, list) else None + runtime_tool_choice = tool_choice + + while True: + params = dict(params_base) + params["messages"] = api_messages + if runtime_tools: + params["tools"] = runtime_tools + if runtime_tool_choice: + params["tool_choice"] = runtime_tool_choice + + ret = self.client.chat.completions.create(**params) + choice = ret.choices[0] + message = choice.message + finish_reason = choice.finish_reason + + if ( + runtime_tools + and message + and getattr(message, "tool_calls", None) + and finish_reason == "tool_calls" + and tool_handler + ): + iterations += 1 + api_messages.append({ + "role": "assistant", + "content": message.content or "", + "tool_calls": message.tool_calls + }) + + if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0): + api_messages.append({ + "role": "system", + "content": "你已经达到可使用搜索历史工具的最大次数,请停止继续调用该工具,直接根据目前掌握的信息给出最终回答。" + }) + runtime_tool_choice = "none" + continue + + for tool_call in message.tool_calls: + tool_name = tool_call.function.name + raw_arguments = tool_call.function.arguments or "{}" + try: + 
parsed_arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + parsed_arguments = {"_raw": raw_arguments} + + try: + tool_output = tool_handler(tool_name, parsed_arguments) + except Exception as handler_exc: + self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True) + tool_output = json.dumps( + {"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"}, + ensure_ascii=False + ) + + if not isinstance(tool_output, str): + tool_output = json.dumps(tool_output, ensure_ascii=False) + + api_messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_output + }) + + runtime_tool_choice = None + continue + + response_text = message.content if message and message.content else "" + if response_text.startswith("\n\n"): + response_text = response_text[2:] + response_text = response_text.replace("\n\n", "\n") + return response_text def encode_image_to_base64(self, image_path: str) -> str: """将图片文件转换为Base64编码 @@ -226,4 +323,4 @@ if __name__ == "__main__": # --- 测试代码需要调整 --- # 需要模拟 MessageSummary 和提供 bot_wxid 才能测试 print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。") - pass # 避免直接运行时出错 \ No newline at end of file + pass # 避免直接运行时出错 diff --git a/ai_providers/ai_deepseek.py b/ai_providers/ai_deepseek.py index 2e11e1c..502eb43 100644 --- a/ai_providers/ai_deepseek.py +++ b/ai_providers/ai_deepseek.py @@ -5,6 +5,7 @@ import logging from datetime import datetime import time # 引入 time 模块 +import json import httpx from openai import APIConnectionError, APIError, AuthenticationError, OpenAI @@ -23,7 +24,7 @@ class DeepSeek(): prompt = conf.get("prompt") self.model = conf.get("model", "deepseek-chat") # 读取最大历史消息数配置 - self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条 + self.max_history_messages = conf.get("max_history_messages", 30) # 默认使用最近30条历史 self.LOG = logging.getLogger("DeepSeek") # 存储传入的实例和wxid @@ -52,7 +53,17 @@ class DeepSeek(): return True return False - def get_answer(self, question: 
str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str: + def get_answer( + self, + question: str, + wxid: str, + system_prompt_override=None, + specific_max_history=None, + tools=None, + tool_handler=None, + tool_choice=None, + tool_max_iterations: int = 10 + ) -> str: # 获取并格式化数据库历史记录 api_messages = [] @@ -100,27 +111,109 @@ class DeepSeek(): if question: api_messages.append({"role": "user", "content": question}) + if tools and not tool_handler: + self.LOG.warning("tools 提供但未传入 tool_handler,忽略工具配置。") + tools = None + try: - # 使用格式化后的 api_messages - response = self.client.chat.completions.create( - model=self.model, - messages=api_messages, # 使用构建的消息列表 - stream=False + final_response = self._execute_with_tools( + api_messages=api_messages, + tools=tools, + tool_handler=tool_handler, + tool_choice=tool_choice, + tool_max_iterations=tool_max_iterations ) - final_response = response.choices[0].message.content - - return final_response except (APIConnectionError, APIError, AuthenticationError) as e1: self.LOG.error(f"DeepSeek API 返回了错误:{str(e1)}") return f"DeepSeek API 返回了错误:{str(e1)}" except Exception as e0: - self.LOG.error(f"发生未知错误:{str(e0)}") + self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True) return "抱歉,处理您的请求时出现了错误" + def _execute_with_tools( + self, + api_messages, + tools=None, + tool_handler=None, + tool_choice=None, + tool_max_iterations: int = 10 + ) -> str: + iterations = 0 + params_base = {"model": self.model, "stream": False} + + runtime_tools = tools if tools and isinstance(tools, list) else None + runtime_tool_choice = tool_choice + + while True: + params = dict(params_base) + params["messages"] = api_messages + if runtime_tools: + params["tools"] = runtime_tools + if runtime_tool_choice: + params["tool_choice"] = runtime_tool_choice + + response = self.client.chat.completions.create(**params) + choice = response.choices[0] + message = choice.message + finish_reason = choice.finish_reason + + if ( + runtime_tools + and message + 
and getattr(message, "tool_calls", None) + and finish_reason == "tool_calls" + and tool_handler + ): + iterations += 1 + api_messages.append({ + "role": "assistant", + "content": message.content or "", + "tool_calls": message.tool_calls + }) + + if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0): + api_messages.append({ + "role": "system", + "content": "你已经达到允许的最大搜索次数,请停止继续调用搜索工具,根据现有信息完成回答。" + }) + runtime_tool_choice = "none" + continue + + for tool_call in message.tool_calls: + tool_name = tool_call.function.name + raw_arguments = tool_call.function.arguments or "{}" + try: + parsed_arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + parsed_arguments = {"_raw": raw_arguments} + + try: + tool_output = tool_handler(tool_name, parsed_arguments) + except Exception as handler_exc: + self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True) + tool_output = json.dumps( + {"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"}, + ensure_ascii=False + ) + + if not isinstance(tool_output, str): + tool_output = json.dumps(tool_output, ensure_ascii=False) + + api_messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_output + }) + + runtime_tool_choice = None + continue + + return message.content if message and message.content else "" + if __name__ == "__main__": # --- 测试代码需要调整 --- print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。") - pass \ No newline at end of file + pass diff --git a/ai_providers/ai_gemini.py b/ai_providers/ai_gemini.py index a99a5eb..56e4850 100644 --- a/ai_providers/ai_gemini.py +++ b/ai_providers/ai_gemini.py @@ -21,7 +21,7 @@ except ImportError: class Gemini: DEFAULT_MODEL = "gemini-1.5-pro-latest" DEFAULT_PROMPT = "You are a helpful assistant." 
- DEFAULT_MAX_HISTORY = 15 + DEFAULT_MAX_HISTORY = 30 SAFETY_SETTINGS = { # 默认安全设置 - 可根据需要调整或从配置加载 safety_types.HarmCategory.HARM_CATEGORY_HARASSMENT: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, safety_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, @@ -206,7 +206,17 @@ class Gemini: return rsp_text.strip() - def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str: + def get_answer( + self, + question: str, + wxid: str, + system_prompt_override=None, + specific_max_history=None, + tools=None, + tool_handler=None, + tool_choice=None, + tool_max_iterations: int = 10 + ) -> str: if not self._model: return "Gemini 模型未成功初始化,请检查配置和网络。" @@ -214,6 +224,9 @@ class Gemini: self.LOG.warning(f"尝试为 wxid={wxid} 获取答案,但问题为空。") return "您没有提问哦。" + if tools: + self.LOG.debug("Gemini 提供的实现暂不支持工具调用,请忽略 tools 参数。") + # 1. 准备历史消息 contents = [] if self.message_summary and self.bot_wxid: @@ -395,4 +408,4 @@ if __name__ == "__main__": else: print("\n--- Gemini 初始化失败,跳过图片描述测试 ---") - print("\n--- Gemini 本地测试结束 ---") \ No newline at end of file + print("\n--- Gemini 本地测试结束 ---") diff --git a/commands/ai_router.py b/commands/ai_router.py index 7d9478b..cfce477 100644 --- a/commands/ai_router.py +++ b/commands/ai_router.py @@ -7,7 +7,7 @@ from .context import MessageContext logger = logging.getLogger(__name__) -ROUTING_HISTORY_LIMIT = 10 +ROUTING_HISTORY_LIMIT = 30 CHAT_HISTORY_MIN = 10 CHAT_HISTORY_MAX = 300 @@ -85,8 +85,7 @@ class AIRouter: 1. 
如果用户只是聊天或者不匹配任何功能,返回: { - "action_type": "chat", - "history_messages": 25 # 你认为闲聊需要的历史条数,介于10-300之间 + "action_type": "chat" } 2.如果用户需要使用上述功能之一,返回: @@ -99,14 +98,13 @@ class AIRouter: #### 示例: - 用户输入"提醒我下午3点开会" -> {"action_type": "function", "function_name": "reminder_set", "params": "下午3点开会"} - 用户输入"查看我的提醒" -> {"action_type": "function", "function_name": "reminder_list", "params": ""} - - 用户输入"你好" -> {"action_type": "chat", "history_messages": 15} + - 用户输入"你好" -> {"action_type": "chat"} - 用户输入"查一下Python教程" -> {"action_type": "function", "function_name": "perplexity_search", "params": "Python教程"} #### 格式注意事项: 1. action_type 只能是 "function" 或 "chat" 2. 只返回JSON,无需其他解释 3. function_name 必须完全匹配上述功能列表中的名称 - 4. 当 action_type 是 "chat" 时,必须提供整数 history_messages 字段,范围为10-300 """ return prompt @@ -172,17 +170,7 @@ class AIRouter: self.logger.warning(f"AI路由器:未知的功能名 - {function_name}") return False, None - if action_type == "chat": - history_value = decision.get("history_messages") - try: - history_value = int(history_value) - except (TypeError, ValueError): - history_value = CHAT_HISTORY_MIN - history_value = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_value)) - decision["history_messages"] = history_value - self.logger.info(f"AI路由决策: {decision}") - else: - self.logger.info(f"AI路由决策: {decision}") + self.logger.info(f"AI路由决策: {decision}") return True, decision except json.JSONDecodeError as e: @@ -245,14 +233,7 @@ class AIRouter: # 如果是聊天,返回False让后续处理器处理 if action_type == "chat": - history_limit = decision.get("history_messages", CHAT_HISTORY_MIN) - try: - history_limit = int(history_limit) - except (TypeError, ValueError): - history_limit = CHAT_HISTORY_MIN - history_limit = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_limit)) - setattr(ctx, 'specific_max_history', history_limit) - self.logger.info(f"AI路由器:识别为聊天意图,交给聊天处理器,使用历史条数 {history_limit}") + self.logger.info("AI路由器:识别为聊天意图,交给聊天处理器处理。") return False # 如果是功能调用 diff --git a/commands/handlers.py 
b/commands/handlers.py index d367e69..32b3a0f 100644 --- a/commands/handlers.py +++ b/commands/handlers.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from .context import MessageContext +DEFAULT_CHAT_HISTORY = 30 + def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: """ 处理闲聊,调用AI模型生成回复 @@ -39,8 +41,10 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: specific_max_history = 10 elif specific_max_history > 300: specific_max_history = 300 - setattr(ctx, 'specific_max_history', specific_max_history) - if ctx.logger and specific_max_history is not None: + if specific_max_history is None: + specific_max_history = DEFAULT_CHAT_HISTORY + setattr(ctx, 'specific_max_history', specific_max_history) + if ctx.logger: ctx.logger.debug(f"为 {ctx.get_receiver()} 使用特定历史限制: {specific_max_history}") # 处理引用图片情况 @@ -150,10 +154,128 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: ctx.logger.info(f"【发送内容】将以下消息发送给AI: \n{q_with_info}") # 调用AI模型,传递特定历史限制 + tools = None + tool_handler = None + + if ctx.robot and getattr(ctx.robot, 'message_summary', None): + chat_id = ctx.get_receiver() + message_summary = ctx.robot.message_summary + + search_history_tool = { + "type": "function", + "function": { + "name": "search_chat_history", + "description": ( + "Search recent conversation history for specific keywords. " + "Returns at most 20 recent segments, each including the matched message " + "with five surrounding messages (if available) and timestamps. " + "Use this tool when you need precise historical context." + ), + "parameters": { + "type": "object", + "properties": { + "keywords": { + "type": "array", + "description": "List of keywords to search for in message content.", + "items": {"type": "string"}, + "minItems": 1 + }, + "query": { + "type": "string", + "description": "Optional free-form string; will be split into keywords by whitespace." 
+ }, + "context_window": { + "type": "integer", + "description": "How many messages before and after each match to include (default 5, max 10).", + "minimum": 0, + "maximum": 10 + }, + "max_results": { + "type": "integer", + "description": "Maximum number of result segments to return (default 20).", + "minimum": 1, + "maximum": 20 + } + }, + "additionalProperties": False + } + } + } + + def handle_tool_call(tool_name: str, arguments: Dict[str, Any]) -> str: + if tool_name != "search_chat_history": + return json.dumps({"error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False) + + try: + keywords = arguments.get("keywords", []) + if isinstance(keywords, str): + keywords = [keywords] + elif not isinstance(keywords, list): + keywords = [] + + query = arguments.get("query") + if isinstance(query, str) and query.strip(): + query_keywords = [segment for segment in query.strip().split() if segment] + keywords.extend(query_keywords) + + cleaned_keywords = [] + for kw in keywords: + if kw is None: + continue + kw_str = str(kw).strip() + if kw_str: + cleaned_keywords.append(kw_str) + + # 去重同时保持顺序 + seen = set() + deduped_keywords = [] + for kw in cleaned_keywords: + lower_kw = kw.lower() + if lower_kw not in seen: + seen.add(lower_kw) + deduped_keywords.append(kw) + + if not deduped_keywords: + return json.dumps({"error": "No valid keywords provided.", "results": []}, ensure_ascii=False) + + context_window = arguments.get("context_window", 5) + max_results = arguments.get("max_results", 20) + + search_results = message_summary.search_messages_with_context( + chat_id=chat_id, + keywords=deduped_keywords, + context_window=context_window, + max_groups=max_results + ) + + response_payload = { + "results": search_results, + "returned_groups": len(search_results), + "keywords": deduped_keywords + } + + if not search_results: + response_payload["notice"] = "No messages matched the provided keywords." 
def search_messages_with_context(self, chat_id, keywords, context_window=5, max_groups=20):
    """Search a chat's stored messages for keywords, returning matches with context.

    Args:
        chat_id (str): Chat identifier (group or user id).
        keywords (str | list[str]): Keyword or list of keywords to look for
            (matched case-insensitively against message content).
        context_window (int): Messages to include before/after each match
            (coerced to int, clamped to 0-10).
        max_groups (int): Maximum number of result groups to return
            (coerced to int, clamped to 1-20; newest matches first).

    Returns:
        list[dict]: One entry per match, each carrying the matched keywords,
        anchor-message metadata, and the surrounding message segment.
    """
    if not keywords:
        return []
    if isinstance(keywords, str):
        keywords = [keywords]

    # Keep (original, lowercase) pairs so matching is case-insensitive
    # while results report the caller's original spelling.
    keyword_pairs = []
    for raw in keywords:
        if raw is None:
            continue
        text = str(raw).strip()
        if text:
            keyword_pairs.append((text, text.lower()))
    if not keyword_pairs:
        return []

    def _clamp_int(value, fallback, low, high):
        # Coerce to int, substituting the fallback on bad input, then clamp.
        try:
            value = int(value)
        except (TypeError, ValueError):
            value = fallback
        return max(low, min(value, high))

    context_window = _clamp_int(context_window, 5, 0, 10)
    max_groups = _clamp_int(max_groups, 20, 1, 20)

    history = self.get_messages(chat_id)
    if not history:
        return []

    groups = []
    count = len(history)
    consumed = set()  # indices already covered by an earlier (newer) group

    # Walk newest-to-oldest so the most recent matches are returned first.
    for anchor in reversed(range(count)):
        entry = history[anchor]
        body = entry.get("content", "")
        if not body:
            continue

        body_lower = body.lower()
        hits = [orig for orig, low in keyword_pairs if low in body_lower]
        if not hits or anchor in consumed:
            continue

        lo = max(0, anchor - context_window)
        hi = min(count, anchor + context_window + 1)
        segment = [
            {
                "time": history[pos].get("time"),
                "sender": history[pos].get("sender"),
                "sender_wxid": history[pos].get("sender_wxid"),
                "content": history[pos].get("content"),
                "relative_offset": pos - anchor,
                "is_match": pos == anchor,
            }
            for pos in range(lo, hi)
        ]

        groups.append({
            "matched_keywords": hits,
            "anchor_index": anchor,
            "anchor_time": entry.get("time"),
            "anchor_sender": entry.get("sender"),
            "anchor_sender_wxid": entry.get("sender_wxid"),
            "messages": segment,
        })

        # Suppress further anchors inside this window to avoid overlap.
        consumed.update(range(lo, hi))

        if len(groups) >= max_groups:
            break

    return groups