function call

This commit is contained in:
zihanjian
2025-10-13 14:42:39 +08:00
parent 16e62c8eec
commit 34055175bc
6 changed files with 467 additions and 66 deletions

View File

@@ -7,6 +7,7 @@ import base64
import os
from datetime import datetime
import time # 引入 time 模块
import json
import httpx
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI
@@ -25,7 +26,7 @@ class ChatGPT():
proxy = conf.get("proxy")
prompt = conf.get("prompt")
self.model = conf.get("model", "gpt-3.5-turbo")
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置默认10条
self.max_history_messages = conf.get("max_history_messages", 30) # 默认读取最近30条历史
self.LOG = logging.getLogger("ChatGPT")
# 存储传入的实例和wxid
@@ -56,7 +57,17 @@ class ChatGPT():
return True
return False
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
def get_answer(
self,
question: str,
wxid: str,
system_prompt_override=None,
specific_max_history=None,
tools=None,
tool_handler=None,
tool_choice=None,
tool_max_iterations: int = 10
) -> str:
# 获取并格式化数据库历史记录
api_messages = []
@@ -104,37 +115,123 @@ class ChatGPT():
if question: # 确保问题非空
api_messages.append({"role": "user", "content": question})
rsp = ""
if tools and not tool_handler:
# 如果提供了工具但没有处理器,则忽略工具以避免陷入死循环
self.LOG.warning("tools 提供但没有 tool_handler忽略工具定义。")
tools = None
try:
# 使用格式化后的 api_messages
params = {
"model": self.model,
"messages": api_messages # 使用从数据库构建的消息列表
}
# 只有非o系列模型才设置temperature
if not self.model.startswith("o"):
params["temperature"] = 0.2
ret = self.client.chat.completions.create(**params)
rsp = ret.choices[0].message.content
rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
rsp = rsp.replace("\n\n", "\n")
response_text = self._execute_with_tools(
api_messages=api_messages,
tools=tools,
tool_handler=tool_handler,
tool_choice=tool_choice,
tool_max_iterations=tool_max_iterations
)
return response_text
except AuthenticationError:
self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
rsp = "API认证失败"
return "API认证失败"
except APIConnectionError:
self.LOG.error("无法连接到 OpenAI API请检查网络连接")
rsp = "网络连接错误"
return "网络连接错误"
except APIError as e1:
self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
rsp = f"API错误: {str(e1)}"
return f"API错误: {str(e1)}"
except Exception as e0:
self.LOG.error(f"发生未知错误:{str(e0)}")
rsp = "发生未知错误"
self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True)
return "发生未知错误"
return rsp
def _execute_with_tools(
    self,
    api_messages,
    tools=None,
    tool_handler=None,
    tool_choice=None,
    tool_max_iterations: int = 10
) -> str:
    """Run the chat-completion loop, resolving tool calls until a final answer.

    Args:
        api_messages: Mutable list of OpenAI-format messages; assistant/tool
            traffic is appended to it in place.
        tools: Optional list of tool definitions forwarded to the API.
        tool_handler: Callable ``(tool_name, parsed_arguments) -> str`` that
            executes a tool and returns its output (non-str results are
            JSON-serialized).
        tool_choice: Optional ``tool_choice`` value for the first request.
        tool_max_iterations: Maximum number of tool-call rounds before the
            model is forced to answer without further tool use.

    Returns:
        The assistant's final text reply (lightly post-processed), or "" when
        the model returned no content.
    """
    iterations = 0
    params_base = {"model": self.model}
    # o-series models reject a custom temperature, so only set it for others.
    if not self.model.startswith("o"):
        params_base["temperature"] = 0.2
    # Normalize tool parameters; anything that is not a non-empty list is dropped.
    runtime_tools = tools if tools and isinstance(tools, list) else None
    runtime_tool_choice = tool_choice
    while True:
        params = dict(params_base)
        params["messages"] = api_messages
        if runtime_tools:
            params["tools"] = runtime_tools
        if runtime_tool_choice:
            params["tool_choice"] = runtime_tool_choice
        ret = self.client.chat.completions.create(**params)
        choice = ret.choices[0]
        message = choice.message
        finish_reason = choice.finish_reason
        if (
            runtime_tools
            and message
            and getattr(message, "tool_calls", None)
            and finish_reason == "tool_calls"
            and tool_handler
        ):
            iterations += 1
            api_messages.append({
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": message.tool_calls
            })
            if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0):
                # BUG FIX: every tool_call in the assistant message above must
                # be answered by a "tool"-role message, otherwise the next API
                # request is rejected. Answer each pending call with a limit
                # notice before telling the model to conclude.
                for tool_call in message.tool_calls:
                    api_messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": json.dumps(
                            {"error": "tool call limit reached"},
                            ensure_ascii=False
                        )
                    })
                api_messages.append({
                    "role": "system",
                    "content": "你已经达到可使用搜索历史工具的最大次数,请停止继续调用该工具,直接根据目前掌握的信息给出最终回答。"
                })
                # Forbid further tool use; the next turn must be a plain answer.
                runtime_tool_choice = "none"
                continue
            for tool_call in message.tool_calls:
                tool_name = tool_call.function.name
                raw_arguments = tool_call.function.arguments or "{}"
                try:
                    parsed_arguments = json.loads(raw_arguments)
                except json.JSONDecodeError:
                    # Hand the unparsable payload to the handler verbatim.
                    parsed_arguments = {"_raw": raw_arguments}
                try:
                    tool_output = tool_handler(tool_name, parsed_arguments)
                except Exception as handler_exc:
                    # A failing tool must not kill the conversation; report the
                    # failure back to the model instead.
                    self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True)
                    tool_output = json.dumps(
                        {"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"},
                        ensure_ascii=False
                    )
                if not isinstance(tool_output, str):
                    tool_output = json.dumps(tool_output, ensure_ascii=False)
                api_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_output
                })
            # After the first round, let the model decide freely about tools.
            runtime_tool_choice = None
            continue
        # Terminal turn: extract and lightly normalize the text reply.
        response_text = message.content if message and message.content else ""
        if response_text.startswith("\n\n"):
            response_text = response_text[2:]
        response_text = response_text.replace("\n\n", "\n")
        return response_text
def encode_image_to_base64(self, image_path: str) -> str:
"""将图片文件转换为Base64编码
@@ -226,4 +323,4 @@ if __name__ == "__main__":
# --- 测试代码需要调整 ---
# 需要模拟 MessageSummary 和提供 bot_wxid 才能测试
print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
pass # 避免直接运行时出错
pass # 避免直接运行时出错

View File

@@ -5,6 +5,7 @@
import logging
from datetime import datetime
import time # 引入 time 模块
import json
import httpx
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI
@@ -23,7 +24,7 @@ class DeepSeek():
prompt = conf.get("prompt")
self.model = conf.get("model", "deepseek-chat")
# 读取最大历史消息数配置
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置默认10条
self.max_history_messages = conf.get("max_history_messages", 30) # 默认使用最近30条历史
self.LOG = logging.getLogger("DeepSeek")
# 存储传入的实例和wxid
@@ -52,7 +53,17 @@ class DeepSeek():
return True
return False
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
def get_answer(
self,
question: str,
wxid: str,
system_prompt_override=None,
specific_max_history=None,
tools=None,
tool_handler=None,
tool_choice=None,
tool_max_iterations: int = 10
) -> str:
# 获取并格式化数据库历史记录
api_messages = []
@@ -100,27 +111,109 @@ class DeepSeek():
if question:
api_messages.append({"role": "user", "content": question})
if tools and not tool_handler:
self.LOG.warning("tools 提供但未传入 tool_handler忽略工具配置。")
tools = None
try:
# 使用格式化后的 api_messages
response = self.client.chat.completions.create(
model=self.model,
messages=api_messages, # 使用构建的消息列表
stream=False
final_response = self._execute_with_tools(
api_messages=api_messages,
tools=tools,
tool_handler=tool_handler,
tool_choice=tool_choice,
tool_max_iterations=tool_max_iterations
)
final_response = response.choices[0].message.content
return final_response
except (APIConnectionError, APIError, AuthenticationError) as e1:
self.LOG.error(f"DeepSeek API 返回了错误:{str(e1)}")
return f"DeepSeek API 返回了错误:{str(e1)}"
except Exception as e0:
self.LOG.error(f"发生未知错误:{str(e0)}")
self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True)
return "抱歉,处理您的请求时出现了错误"
def _execute_with_tools(
    self,
    api_messages,
    tools=None,
    tool_handler=None,
    tool_choice=None,
    tool_max_iterations: int = 10
) -> str:
    """Run the DeepSeek chat loop, resolving tool calls until a final answer.

    Args:
        api_messages: Mutable list of OpenAI-format messages; assistant/tool
            traffic is appended to it in place.
        tools: Optional list of tool definitions forwarded to the API.
        tool_handler: Callable ``(tool_name, parsed_arguments) -> str`` that
            executes a tool (non-str results are JSON-serialized).
        tool_choice: Optional ``tool_choice`` value for the first request.
        tool_max_iterations: Maximum number of tool-call rounds before the
            model is forced to answer without further tool use.

    Returns:
        The assistant's final text reply, or "" when no content was returned.
    """
    iterations = 0
    params_base = {"model": self.model, "stream": False}
    # Normalize tool parameters; anything that is not a non-empty list is dropped.
    runtime_tools = tools if tools and isinstance(tools, list) else None
    runtime_tool_choice = tool_choice
    while True:
        params = dict(params_base)
        params["messages"] = api_messages
        if runtime_tools:
            params["tools"] = runtime_tools
        if runtime_tool_choice:
            params["tool_choice"] = runtime_tool_choice
        response = self.client.chat.completions.create(**params)
        choice = response.choices[0]
        message = choice.message
        finish_reason = choice.finish_reason
        if (
            runtime_tools
            and message
            and getattr(message, "tool_calls", None)
            and finish_reason == "tool_calls"
            and tool_handler
        ):
            iterations += 1
            api_messages.append({
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": message.tool_calls
            })
            if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0):
                # BUG FIX: every tool_call in the assistant message above must
                # be answered by a "tool"-role message, otherwise the next API
                # request is rejected. Answer each pending call with a limit
                # notice before telling the model to conclude.
                for tool_call in message.tool_calls:
                    api_messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": json.dumps(
                            {"error": "tool call limit reached"},
                            ensure_ascii=False
                        )
                    })
                api_messages.append({
                    "role": "system",
                    "content": "你已经达到允许的最大搜索次数,请停止继续调用搜索工具,根据现有信息完成回答。"
                })
                # Forbid further tool use; the next turn must be a plain answer.
                runtime_tool_choice = "none"
                continue
            for tool_call in message.tool_calls:
                tool_name = tool_call.function.name
                raw_arguments = tool_call.function.arguments or "{}"
                try:
                    parsed_arguments = json.loads(raw_arguments)
                except json.JSONDecodeError:
                    # Hand the unparsable payload to the handler verbatim.
                    parsed_arguments = {"_raw": raw_arguments}
                try:
                    tool_output = tool_handler(tool_name, parsed_arguments)
                except Exception as handler_exc:
                    # A failing tool must not kill the conversation; report the
                    # failure back to the model instead.
                    self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True)
                    tool_output = json.dumps(
                        {"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"},
                        ensure_ascii=False
                    )
                if not isinstance(tool_output, str):
                    tool_output = json.dumps(tool_output, ensure_ascii=False)
                api_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_output
                })
            # After the first round, let the model decide freely about tools.
            runtime_tool_choice = None
            continue
        return message.content if message and message.content else ""
if __name__ == "__main__":
# --- 测试代码需要调整 ---
print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
pass
pass

View File

@@ -21,7 +21,7 @@ except ImportError:
class Gemini:
DEFAULT_MODEL = "gemini-1.5-pro-latest"
DEFAULT_PROMPT = "You are a helpful assistant."
DEFAULT_MAX_HISTORY = 15
DEFAULT_MAX_HISTORY = 30
SAFETY_SETTINGS = { # 默认安全设置 - 可根据需要调整或从配置加载
safety_types.HarmCategory.HARM_CATEGORY_HARASSMENT: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
safety_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
@@ -206,7 +206,17 @@ class Gemini:
return rsp_text.strip()
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
def get_answer(
self,
question: str,
wxid: str,
system_prompt_override=None,
specific_max_history=None,
tools=None,
tool_handler=None,
tool_choice=None,
tool_max_iterations: int = 10
) -> str:
if not self._model:
return "Gemini 模型未成功初始化,请检查配置和网络。"
@@ -214,6 +224,9 @@ class Gemini:
self.LOG.warning(f"尝试为 wxid={wxid} 获取答案,但问题为空。")
return "您没有提问哦。"
if tools:
self.LOG.debug("Gemini 提供的实现暂不支持工具调用,请忽略 tools 参数。")
# 1. 准备历史消息
contents = []
if self.message_summary and self.bot_wxid:
@@ -395,4 +408,4 @@ if __name__ == "__main__":
else:
print("\n--- Gemini 初始化失败,跳过图片描述测试 ---")
print("\n--- Gemini 本地测试结束 ---")
print("\n--- Gemini 本地测试结束 ---")

View File

@@ -7,7 +7,7 @@ from .context import MessageContext
logger = logging.getLogger(__name__)
ROUTING_HISTORY_LIMIT = 10
ROUTING_HISTORY_LIMIT = 30
CHAT_HISTORY_MIN = 10
CHAT_HISTORY_MAX = 300
@@ -85,8 +85,7 @@ class AIRouter:
1. 如果用户只是聊天或者不匹配任何功能,返回:
{
"action_type": "chat",
"history_messages": 25 # 你认为闲聊需要的历史条数介于10-300之间
"action_type": "chat"
}
2.如果用户需要使用上述功能之一,返回:
@@ -99,14 +98,13 @@ class AIRouter:
#### 示例:
- 用户输入"提醒我下午3点开会" -> {"action_type": "function", "function_name": "reminder_set", "params": "下午3点开会"}
- 用户输入"查看我的提醒" -> {"action_type": "function", "function_name": "reminder_list", "params": ""}
- 用户输入"你好" -> {"action_type": "chat", "history_messages": 15}
- 用户输入"你好" -> {"action_type": "chat"}
- 用户输入"查一下Python教程" -> {"action_type": "function", "function_name": "perplexity_search", "params": "Python教程"}
#### 格式注意事项:
1. action_type 只能是 "function""chat"
2. 只返回JSON无需其他解释
3. function_name 必须完全匹配上述功能列表中的名称
4. 当 action_type 是 "chat" 时,必须提供整数 history_messages 字段范围为10-300
"""
return prompt
@@ -172,17 +170,7 @@ class AIRouter:
self.logger.warning(f"AI路由器未知的功能名 - {function_name}")
return False, None
if action_type == "chat":
history_value = decision.get("history_messages")
try:
history_value = int(history_value)
except (TypeError, ValueError):
history_value = CHAT_HISTORY_MIN
history_value = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_value))
decision["history_messages"] = history_value
self.logger.info(f"AI路由决策: {decision}")
else:
self.logger.info(f"AI路由决策: {decision}")
self.logger.info(f"AI路由决策: {decision}")
return True, decision
except json.JSONDecodeError as e:
@@ -245,14 +233,7 @@ class AIRouter:
# 如果是聊天返回False让后续处理器处理
if action_type == "chat":
history_limit = decision.get("history_messages", CHAT_HISTORY_MIN)
try:
history_limit = int(history_limit)
except (TypeError, ValueError):
history_limit = CHAT_HISTORY_MIN
history_limit = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_limit))
setattr(ctx, 'specific_max_history', history_limit)
self.logger.info(f"AI路由器识别为聊天意图交给聊天处理器使用历史条数 {history_limit}")
self.logger.info("AI路由器识别为聊天意图交给聊天处理器处理。")
return False
# 如果是功能调用

View File

@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .context import MessageContext
DEFAULT_CHAT_HISTORY = 30
def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
"""
处理闲聊调用AI模型生成回复
@@ -39,8 +41,10 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
specific_max_history = 10
elif specific_max_history > 300:
specific_max_history = 300
setattr(ctx, 'specific_max_history', specific_max_history)
if ctx.logger and specific_max_history is not None:
if specific_max_history is None:
specific_max_history = DEFAULT_CHAT_HISTORY
setattr(ctx, 'specific_max_history', specific_max_history)
if ctx.logger:
ctx.logger.debug(f"{ctx.get_receiver()} 使用特定历史限制: {specific_max_history}")
# 处理引用图片情况
@@ -150,10 +154,128 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
ctx.logger.info(f"【发送内容】将以下消息发送给AI: \n{q_with_info}")
# 调用AI模型传递特定历史限制
tools = None
tool_handler = None
if ctx.robot and getattr(ctx.robot, 'message_summary', None):
chat_id = ctx.get_receiver()
message_summary = ctx.robot.message_summary
search_history_tool = {
"type": "function",
"function": {
"name": "search_chat_history",
"description": (
"Search recent conversation history for specific keywords. "
"Returns at most 20 recent segments, each including the matched message "
"with five surrounding messages (if available) and timestamps. "
"Use this tool when you need precise historical context."
),
"parameters": {
"type": "object",
"properties": {
"keywords": {
"type": "array",
"description": "List of keywords to search for in message content.",
"items": {"type": "string"},
"minItems": 1
},
"query": {
"type": "string",
"description": "Optional free-form string; will be split into keywords by whitespace."
},
"context_window": {
"type": "integer",
"description": "How many messages before and after each match to include (default 5, max 10).",
"minimum": 0,
"maximum": 10
},
"max_results": {
"type": "integer",
"description": "Maximum number of result segments to return (default 20).",
"minimum": 1,
"maximum": 20
}
},
"additionalProperties": False
}
}
}
def handle_tool_call(tool_name: str, arguments: Dict[str, Any]) -> str:
    """Dispatch a model-issued tool call; only "search_chat_history" is known.

    Returns a JSON string: either the search results payload or an
    ``{"error": ...}`` object describing what went wrong.
    """
    if tool_name != "search_chat_history":
        return json.dumps({"error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False)
    try:
        # Collect keyword candidates from both the structured list and the
        # optional free-form query string.
        raw_keywords = arguments.get("keywords", [])
        if isinstance(raw_keywords, str):
            raw_keywords = [raw_keywords]
        elif not isinstance(raw_keywords, list):
            raw_keywords = []
        else:
            raw_keywords = list(raw_keywords)
        query = arguments.get("query")
        if isinstance(query, str) and query.strip():
            raw_keywords.extend(query.strip().split())
        # Clean and case-insensitively de-duplicate while preserving order.
        seen = set()
        deduped_keywords = []
        for candidate in raw_keywords:
            if candidate is None:
                continue
            text = str(candidate).strip()
            if not text:
                continue
            folded = text.lower()
            if folded in seen:
                continue
            seen.add(folded)
            deduped_keywords.append(text)
        if not deduped_keywords:
            return json.dumps({"error": "No valid keywords provided.", "results": []}, ensure_ascii=False)
        search_results = message_summary.search_messages_with_context(
            chat_id=chat_id,
            keywords=deduped_keywords,
            context_window=arguments.get("context_window", 5),
            max_groups=arguments.get("max_results", 20)
        )
        payload = {
            "results": search_results,
            "returned_groups": len(search_results),
            "keywords": deduped_keywords
        }
        if not search_results:
            payload["notice"] = "No messages matched the provided keywords."
        return json.dumps(payload, ensure_ascii=False)
    except Exception as tool_exc:
        if ctx.logger:
            ctx.logger.error(f"搜索历史工具调用失败: {tool_exc}", exc_info=True)
        return json.dumps(
            {"error": f"Search failed: {tool_exc.__class__.__name__}"},
            ensure_ascii=False
        )
tools = [search_history_tool]
tool_handler = handle_tool_call
rsp = chat_model.get_answer(
question=q_with_info,
wxid=ctx.get_receiver(),
specific_max_history=specific_max_history
specific_max_history=specific_max_history,
tools=tools,
tool_handler=tool_handler,
tool_max_iterations=10
)
if rsp:

View File

@@ -223,6 +223,101 @@ class MessageSummary:
return messages
def search_messages_with_context(self, chat_id, keywords, context_window=5, max_groups=20):
    """Search stored messages for keywords, returning matches plus context.

    Args:
        chat_id (str): Chat identifier (group or user id).
        keywords (Union[str, list[str]]): Keyword or list of keywords to match
            case-insensitively against message content.
        context_window (int): Extra messages included before and after each
            match (clamped to 0-10).
        max_groups (int): Maximum result groups, newest matches first
            (clamped to 1-20).

    Returns:
        list[dict]: One entry per match group with the matched keywords, the
        anchor message metadata, and the surrounding messages.
    """
    if not keywords:
        return []
    if isinstance(keywords, str):
        keywords = [keywords]
    # Keep (original, lowercased) pairs so results echo the caller's casing.
    keyword_pairs = []
    for candidate in keywords:
        if candidate is None:
            continue
        text = str(candidate).strip()
        if text:
            keyword_pairs.append((text, text.lower()))
    if not keyword_pairs:
        return []
    # Clamp the context window to keep each segment reasonably small.
    try:
        window = int(context_window)
    except (TypeError, ValueError):
        window = 5
    window = min(max(window, 0), 10)
    try:
        limit = int(max_groups)
    except (TypeError, ValueError):
        limit = 20
    limit = min(max(limit, 1), 20)
    messages = self.get_messages(chat_id)
    if not messages:
        return []
    total = len(messages)
    consumed = set()  # indices already emitted inside some segment
    groups = []
    # Walk newest-to-oldest so the most recent matches are returned first.
    for idx in range(total - 1, -1, -1):
        record = messages[idx]
        content = record.get("content", "")
        if not content:
            continue
        haystack = content.lower()
        hits = [original for original, folded in keyword_pairs if folded in haystack]
        # Skip non-matches and matches already covered by an earlier segment.
        if not hits or idx in consumed:
            continue
        start = max(0, idx - window)
        end = min(total, idx + window + 1)
        segment = [
            {
                "time": messages[pos].get("time"),
                "sender": messages[pos].get("sender"),
                "sender_wxid": messages[pos].get("sender_wxid"),
                "content": messages[pos].get("content"),
                "relative_offset": pos - idx,
                "is_match": pos == idx
            }
            for pos in range(start, end)
        ]
        groups.append({
            "matched_keywords": hits,
            "anchor_index": idx,
            "anchor_time": record.get("time"),
            "anchor_sender": record.get("sender"),
            "anchor_sender_wxid": record.get("sender_wxid"),
            "messages": segment
        })
        consumed.update(range(start, end))
        if len(groups) >= limit:
            break
    return groups
def _basic_summarize(self, messages):
"""基本的消息总结逻辑不使用AI