mirror of
https://github.com/Zippland/Bubbles.git
synced 2026-01-19 01:21:15 +08:00
function call
This commit is contained in:
@@ -7,6 +7,7 @@ import base64
|
||||
import os
|
||||
from datetime import datetime
|
||||
import time # 引入 time 模块
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI
|
||||
@@ -25,7 +26,7 @@ class ChatGPT():
|
||||
proxy = conf.get("proxy")
|
||||
prompt = conf.get("prompt")
|
||||
self.model = conf.get("model", "gpt-3.5-turbo")
|
||||
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条
|
||||
self.max_history_messages = conf.get("max_history_messages", 30) # 默认读取最近30条历史
|
||||
self.LOG = logging.getLogger("ChatGPT")
|
||||
|
||||
# 存储传入的实例和wxid
|
||||
@@ -56,7 +57,17 @@ class ChatGPT():
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
|
||||
def get_answer(
|
||||
self,
|
||||
question: str,
|
||||
wxid: str,
|
||||
system_prompt_override=None,
|
||||
specific_max_history=None,
|
||||
tools=None,
|
||||
tool_handler=None,
|
||||
tool_choice=None,
|
||||
tool_max_iterations: int = 10
|
||||
) -> str:
|
||||
# 获取并格式化数据库历史记录
|
||||
api_messages = []
|
||||
|
||||
@@ -104,37 +115,123 @@ class ChatGPT():
|
||||
if question: # 确保问题非空
|
||||
api_messages.append({"role": "user", "content": question})
|
||||
|
||||
rsp = ""
|
||||
if tools and not tool_handler:
|
||||
# 如果提供了工具但没有处理器,则忽略工具以避免陷入死循环
|
||||
self.LOG.warning("tools 提供但没有 tool_handler,忽略工具定义。")
|
||||
tools = None
|
||||
|
||||
try:
|
||||
# 使用格式化后的 api_messages
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": api_messages # 使用从数据库构建的消息列表
|
||||
}
|
||||
|
||||
# 只有非o系列模型才设置temperature
|
||||
if not self.model.startswith("o"):
|
||||
params["temperature"] = 0.2
|
||||
|
||||
ret = self.client.chat.completions.create(**params)
|
||||
rsp = ret.choices[0].message.content
|
||||
rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
|
||||
rsp = rsp.replace("\n\n", "\n")
|
||||
response_text = self._execute_with_tools(
|
||||
api_messages=api_messages,
|
||||
tools=tools,
|
||||
tool_handler=tool_handler,
|
||||
tool_choice=tool_choice,
|
||||
tool_max_iterations=tool_max_iterations
|
||||
)
|
||||
return response_text
|
||||
|
||||
except AuthenticationError:
|
||||
self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
|
||||
rsp = "API认证失败"
|
||||
return "API认证失败"
|
||||
except APIConnectionError:
|
||||
self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
|
||||
rsp = "网络连接错误"
|
||||
return "网络连接错误"
|
||||
except APIError as e1:
|
||||
self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
|
||||
rsp = f"API错误: {str(e1)}"
|
||||
return f"API错误: {str(e1)}"
|
||||
except Exception as e0:
|
||||
self.LOG.error(f"发生未知错误:{str(e0)}")
|
||||
rsp = "发生未知错误"
|
||||
self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True)
|
||||
return "发生未知错误"
|
||||
|
||||
return rsp
|
||||
def _execute_with_tools(
|
||||
self,
|
||||
api_messages,
|
||||
tools=None,
|
||||
tool_handler=None,
|
||||
tool_choice=None,
|
||||
tool_max_iterations: int = 10
|
||||
) -> str:
|
||||
"""执行带工具调用的对话逻辑"""
|
||||
iterations = 0
|
||||
params_base = {"model": self.model}
|
||||
|
||||
# 只有非o系列模型才设置temperature
|
||||
if not self.model.startswith("o"):
|
||||
params_base["temperature"] = 0.2
|
||||
|
||||
# 确保工具参数格式正确
|
||||
runtime_tools = tools if tools and isinstance(tools, list) else None
|
||||
runtime_tool_choice = tool_choice
|
||||
|
||||
while True:
|
||||
params = dict(params_base)
|
||||
params["messages"] = api_messages
|
||||
if runtime_tools:
|
||||
params["tools"] = runtime_tools
|
||||
if runtime_tool_choice:
|
||||
params["tool_choice"] = runtime_tool_choice
|
||||
|
||||
ret = self.client.chat.completions.create(**params)
|
||||
choice = ret.choices[0]
|
||||
message = choice.message
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if (
|
||||
runtime_tools
|
||||
and message
|
||||
and getattr(message, "tool_calls", None)
|
||||
and finish_reason == "tool_calls"
|
||||
and tool_handler
|
||||
):
|
||||
iterations += 1
|
||||
api_messages.append({
|
||||
"role": "assistant",
|
||||
"content": message.content or "",
|
||||
"tool_calls": message.tool_calls
|
||||
})
|
||||
|
||||
if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0):
|
||||
api_messages.append({
|
||||
"role": "system",
|
||||
"content": "你已经达到可使用搜索历史工具的最大次数,请停止继续调用该工具,直接根据目前掌握的信息给出最终回答。"
|
||||
})
|
||||
runtime_tool_choice = "none"
|
||||
continue
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
raw_arguments = tool_call.function.arguments or "{}"
|
||||
try:
|
||||
parsed_arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
parsed_arguments = {"_raw": raw_arguments}
|
||||
|
||||
try:
|
||||
tool_output = tool_handler(tool_name, parsed_arguments)
|
||||
except Exception as handler_exc:
|
||||
self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True)
|
||||
tool_output = json.dumps(
|
||||
{"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"},
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
if not isinstance(tool_output, str):
|
||||
tool_output = json.dumps(tool_output, ensure_ascii=False)
|
||||
|
||||
api_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_output
|
||||
})
|
||||
|
||||
runtime_tool_choice = None
|
||||
continue
|
||||
|
||||
response_text = message.content if message and message.content else ""
|
||||
if response_text.startswith("\n\n"):
|
||||
response_text = response_text[2:]
|
||||
response_text = response_text.replace("\n\n", "\n")
|
||||
return response_text
|
||||
|
||||
def encode_image_to_base64(self, image_path: str) -> str:
|
||||
"""将图片文件转换为Base64编码
|
||||
@@ -226,4 +323,4 @@ if __name__ == "__main__":
|
||||
# --- 测试代码需要调整 ---
|
||||
# 需要模拟 MessageSummary 和提供 bot_wxid 才能测试
|
||||
print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
|
||||
pass # 避免直接运行时出错
|
||||
pass # 避免直接运行时出错
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import time # 引入 time 模块
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI
|
||||
@@ -23,7 +24,7 @@ class DeepSeek():
|
||||
prompt = conf.get("prompt")
|
||||
self.model = conf.get("model", "deepseek-chat")
|
||||
# 读取最大历史消息数配置
|
||||
self.max_history_messages = conf.get("max_history_messages", 10) # 读取配置,默认10条
|
||||
self.max_history_messages = conf.get("max_history_messages", 30) # 默认使用最近30条历史
|
||||
self.LOG = logging.getLogger("DeepSeek")
|
||||
|
||||
# 存储传入的实例和wxid
|
||||
@@ -52,7 +53,17 @@ class DeepSeek():
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
|
||||
def get_answer(
|
||||
self,
|
||||
question: str,
|
||||
wxid: str,
|
||||
system_prompt_override=None,
|
||||
specific_max_history=None,
|
||||
tools=None,
|
||||
tool_handler=None,
|
||||
tool_choice=None,
|
||||
tool_max_iterations: int = 10
|
||||
) -> str:
|
||||
# 获取并格式化数据库历史记录
|
||||
api_messages = []
|
||||
|
||||
@@ -100,27 +111,109 @@ class DeepSeek():
|
||||
if question:
|
||||
api_messages.append({"role": "user", "content": question})
|
||||
|
||||
if tools and not tool_handler:
|
||||
self.LOG.warning("tools 提供但未传入 tool_handler,忽略工具配置。")
|
||||
tools = None
|
||||
|
||||
try:
|
||||
# 使用格式化后的 api_messages
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages, # 使用构建的消息列表
|
||||
stream=False
|
||||
final_response = self._execute_with_tools(
|
||||
api_messages=api_messages,
|
||||
tools=tools,
|
||||
tool_handler=tool_handler,
|
||||
tool_choice=tool_choice,
|
||||
tool_max_iterations=tool_max_iterations
|
||||
)
|
||||
final_response = response.choices[0].message.content
|
||||
|
||||
|
||||
return final_response
|
||||
|
||||
except (APIConnectionError, APIError, AuthenticationError) as e1:
|
||||
self.LOG.error(f"DeepSeek API 返回了错误:{str(e1)}")
|
||||
return f"DeepSeek API 返回了错误:{str(e1)}"
|
||||
except Exception as e0:
|
||||
self.LOG.error(f"发生未知错误:{str(e0)}")
|
||||
self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True)
|
||||
return "抱歉,处理您的请求时出现了错误"
|
||||
|
||||
def _execute_with_tools(
|
||||
self,
|
||||
api_messages,
|
||||
tools=None,
|
||||
tool_handler=None,
|
||||
tool_choice=None,
|
||||
tool_max_iterations: int = 10
|
||||
) -> str:
|
||||
iterations = 0
|
||||
params_base = {"model": self.model, "stream": False}
|
||||
|
||||
runtime_tools = tools if tools and isinstance(tools, list) else None
|
||||
runtime_tool_choice = tool_choice
|
||||
|
||||
while True:
|
||||
params = dict(params_base)
|
||||
params["messages"] = api_messages
|
||||
if runtime_tools:
|
||||
params["tools"] = runtime_tools
|
||||
if runtime_tool_choice:
|
||||
params["tool_choice"] = runtime_tool_choice
|
||||
|
||||
response = self.client.chat.completions.create(**params)
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if (
|
||||
runtime_tools
|
||||
and message
|
||||
and getattr(message, "tool_calls", None)
|
||||
and finish_reason == "tool_calls"
|
||||
and tool_handler
|
||||
):
|
||||
iterations += 1
|
||||
api_messages.append({
|
||||
"role": "assistant",
|
||||
"content": message.content or "",
|
||||
"tool_calls": message.tool_calls
|
||||
})
|
||||
|
||||
if tool_max_iterations is not None and iterations > max(tool_max_iterations, 0):
|
||||
api_messages.append({
|
||||
"role": "system",
|
||||
"content": "你已经达到允许的最大搜索次数,请停止继续调用搜索工具,根据现有信息完成回答。"
|
||||
})
|
||||
runtime_tool_choice = "none"
|
||||
continue
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
raw_arguments = tool_call.function.arguments or "{}"
|
||||
try:
|
||||
parsed_arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
parsed_arguments = {"_raw": raw_arguments}
|
||||
|
||||
try:
|
||||
tool_output = tool_handler(tool_name, parsed_arguments)
|
||||
except Exception as handler_exc:
|
||||
self.LOG.error(f"工具 {tool_name} 执行失败: {handler_exc}", exc_info=True)
|
||||
tool_output = json.dumps(
|
||||
{"error": f"{tool_name} failed: {handler_exc.__class__.__name__}"},
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
if not isinstance(tool_output, str):
|
||||
tool_output = json.dumps(tool_output, ensure_ascii=False)
|
||||
|
||||
api_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_output
|
||||
})
|
||||
|
||||
runtime_tool_choice = None
|
||||
continue
|
||||
|
||||
return message.content if message and message.content else ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# --- 测试代码需要调整 ---
|
||||
print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -21,7 +21,7 @@ except ImportError:
|
||||
class Gemini:
|
||||
DEFAULT_MODEL = "gemini-1.5-pro-latest"
|
||||
DEFAULT_PROMPT = "You are a helpful assistant."
|
||||
DEFAULT_MAX_HISTORY = 15
|
||||
DEFAULT_MAX_HISTORY = 30
|
||||
SAFETY_SETTINGS = { # 默认安全设置 - 可根据需要调整或从配置加载
|
||||
safety_types.HarmCategory.HARM_CATEGORY_HARASSMENT: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
safety_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: safety_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
@@ -206,7 +206,17 @@ class Gemini:
|
||||
|
||||
return rsp_text.strip()
|
||||
|
||||
def get_answer(self, question: str, wxid: str, system_prompt_override=None, specific_max_history=None) -> str:
|
||||
def get_answer(
|
||||
self,
|
||||
question: str,
|
||||
wxid: str,
|
||||
system_prompt_override=None,
|
||||
specific_max_history=None,
|
||||
tools=None,
|
||||
tool_handler=None,
|
||||
tool_choice=None,
|
||||
tool_max_iterations: int = 10
|
||||
) -> str:
|
||||
if not self._model:
|
||||
return "Gemini 模型未成功初始化,请检查配置和网络。"
|
||||
|
||||
@@ -214,6 +224,9 @@ class Gemini:
|
||||
self.LOG.warning(f"尝试为 wxid={wxid} 获取答案,但问题为空。")
|
||||
return "您没有提问哦。"
|
||||
|
||||
if tools:
|
||||
self.LOG.debug("Gemini 提供的实现暂不支持工具调用,请忽略 tools 参数。")
|
||||
|
||||
# 1. 准备历史消息
|
||||
contents = []
|
||||
if self.message_summary and self.bot_wxid:
|
||||
@@ -395,4 +408,4 @@ if __name__ == "__main__":
|
||||
else:
|
||||
print("\n--- Gemini 初始化失败,跳过图片描述测试 ---")
|
||||
|
||||
print("\n--- Gemini 本地测试结束 ---")
|
||||
print("\n--- Gemini 本地测试结束 ---")
|
||||
|
||||
@@ -7,7 +7,7 @@ from .context import MessageContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ROUTING_HISTORY_LIMIT = 10
|
||||
ROUTING_HISTORY_LIMIT = 30
|
||||
CHAT_HISTORY_MIN = 10
|
||||
CHAT_HISTORY_MAX = 300
|
||||
|
||||
@@ -85,8 +85,7 @@ class AIRouter:
|
||||
|
||||
1. 如果用户只是聊天或者不匹配任何功能,返回:
|
||||
{
|
||||
"action_type": "chat",
|
||||
"history_messages": 25 # 你认为闲聊需要的历史条数,介于10-300之间
|
||||
"action_type": "chat"
|
||||
}
|
||||
|
||||
2.如果用户需要使用上述功能之一,返回:
|
||||
@@ -99,14 +98,13 @@ class AIRouter:
|
||||
#### 示例:
|
||||
- 用户输入"提醒我下午3点开会" -> {"action_type": "function", "function_name": "reminder_set", "params": "下午3点开会"}
|
||||
- 用户输入"查看我的提醒" -> {"action_type": "function", "function_name": "reminder_list", "params": ""}
|
||||
- 用户输入"你好" -> {"action_type": "chat", "history_messages": 15}
|
||||
- 用户输入"你好" -> {"action_type": "chat"}
|
||||
- 用户输入"查一下Python教程" -> {"action_type": "function", "function_name": "perplexity_search", "params": "Python教程"}
|
||||
|
||||
#### 格式注意事项:
|
||||
1. action_type 只能是 "function" 或 "chat"
|
||||
2. 只返回JSON,无需其他解释
|
||||
3. function_name 必须完全匹配上述功能列表中的名称
|
||||
4. 当 action_type 是 "chat" 时,必须提供整数 history_messages 字段,范围为10-300
|
||||
"""
|
||||
return prompt
|
||||
|
||||
@@ -172,17 +170,7 @@ class AIRouter:
|
||||
self.logger.warning(f"AI路由器:未知的功能名 - {function_name}")
|
||||
return False, None
|
||||
|
||||
if action_type == "chat":
|
||||
history_value = decision.get("history_messages")
|
||||
try:
|
||||
history_value = int(history_value)
|
||||
except (TypeError, ValueError):
|
||||
history_value = CHAT_HISTORY_MIN
|
||||
history_value = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_value))
|
||||
decision["history_messages"] = history_value
|
||||
self.logger.info(f"AI路由决策: {decision}")
|
||||
else:
|
||||
self.logger.info(f"AI路由决策: {decision}")
|
||||
self.logger.info(f"AI路由决策: {decision}")
|
||||
return True, decision
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -245,14 +233,7 @@ class AIRouter:
|
||||
|
||||
# 如果是聊天,返回False让后续处理器处理
|
||||
if action_type == "chat":
|
||||
history_limit = decision.get("history_messages", CHAT_HISTORY_MIN)
|
||||
try:
|
||||
history_limit = int(history_limit)
|
||||
except (TypeError, ValueError):
|
||||
history_limit = CHAT_HISTORY_MIN
|
||||
history_limit = max(CHAT_HISTORY_MIN, min(CHAT_HISTORY_MAX, history_limit))
|
||||
setattr(ctx, 'specific_max_history', history_limit)
|
||||
self.logger.info(f"AI路由器:识别为聊天意图,交给聊天处理器,使用历史条数 {history_limit}")
|
||||
self.logger.info("AI路由器:识别为聊天意图,交给聊天处理器处理。")
|
||||
return False
|
||||
|
||||
# 如果是功能调用
|
||||
|
||||
@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from .context import MessageContext
|
||||
|
||||
DEFAULT_CHAT_HISTORY = 30
|
||||
|
||||
def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
"""
|
||||
处理闲聊,调用AI模型生成回复
|
||||
@@ -39,8 +41,10 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
specific_max_history = 10
|
||||
elif specific_max_history > 300:
|
||||
specific_max_history = 300
|
||||
setattr(ctx, 'specific_max_history', specific_max_history)
|
||||
if ctx.logger and specific_max_history is not None:
|
||||
if specific_max_history is None:
|
||||
specific_max_history = DEFAULT_CHAT_HISTORY
|
||||
setattr(ctx, 'specific_max_history', specific_max_history)
|
||||
if ctx.logger:
|
||||
ctx.logger.debug(f"为 {ctx.get_receiver()} 使用特定历史限制: {specific_max_history}")
|
||||
|
||||
# 处理引用图片情况
|
||||
@@ -150,10 +154,128 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool:
|
||||
ctx.logger.info(f"【发送内容】将以下消息发送给AI: \n{q_with_info}")
|
||||
|
||||
# 调用AI模型,传递特定历史限制
|
||||
tools = None
|
||||
tool_handler = None
|
||||
|
||||
if ctx.robot and getattr(ctx.robot, 'message_summary', None):
|
||||
chat_id = ctx.get_receiver()
|
||||
message_summary = ctx.robot.message_summary
|
||||
|
||||
search_history_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_chat_history",
|
||||
"description": (
|
||||
"Search recent conversation history for specific keywords. "
|
||||
"Returns at most 20 recent segments, each including the matched message "
|
||||
"with five surrounding messages (if available) and timestamps. "
|
||||
"Use this tool when you need precise historical context."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keywords": {
|
||||
"type": "array",
|
||||
"description": "List of keywords to search for in message content.",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Optional free-form string; will be split into keywords by whitespace."
|
||||
},
|
||||
"context_window": {
|
||||
"type": "integer",
|
||||
"description": "How many messages before and after each match to include (default 5, max 10).",
|
||||
"minimum": 0,
|
||||
"maximum": 10
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of result segments to return (default 20).",
|
||||
"minimum": 1,
|
||||
"maximum": 20
|
||||
}
|
||||
},
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def handle_tool_call(tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
if tool_name != "search_chat_history":
|
||||
return json.dumps({"error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
keywords = arguments.get("keywords", [])
|
||||
if isinstance(keywords, str):
|
||||
keywords = [keywords]
|
||||
elif not isinstance(keywords, list):
|
||||
keywords = []
|
||||
|
||||
query = arguments.get("query")
|
||||
if isinstance(query, str) and query.strip():
|
||||
query_keywords = [segment for segment in query.strip().split() if segment]
|
||||
keywords.extend(query_keywords)
|
||||
|
||||
cleaned_keywords = []
|
||||
for kw in keywords:
|
||||
if kw is None:
|
||||
continue
|
||||
kw_str = str(kw).strip()
|
||||
if kw_str:
|
||||
cleaned_keywords.append(kw_str)
|
||||
|
||||
# 去重同时保持顺序
|
||||
seen = set()
|
||||
deduped_keywords = []
|
||||
for kw in cleaned_keywords:
|
||||
lower_kw = kw.lower()
|
||||
if lower_kw not in seen:
|
||||
seen.add(lower_kw)
|
||||
deduped_keywords.append(kw)
|
||||
|
||||
if not deduped_keywords:
|
||||
return json.dumps({"error": "No valid keywords provided.", "results": []}, ensure_ascii=False)
|
||||
|
||||
context_window = arguments.get("context_window", 5)
|
||||
max_results = arguments.get("max_results", 20)
|
||||
|
||||
search_results = message_summary.search_messages_with_context(
|
||||
chat_id=chat_id,
|
||||
keywords=deduped_keywords,
|
||||
context_window=context_window,
|
||||
max_groups=max_results
|
||||
)
|
||||
|
||||
response_payload = {
|
||||
"results": search_results,
|
||||
"returned_groups": len(search_results),
|
||||
"keywords": deduped_keywords
|
||||
}
|
||||
|
||||
if not search_results:
|
||||
response_payload["notice"] = "No messages matched the provided keywords."
|
||||
|
||||
return json.dumps(response_payload, ensure_ascii=False)
|
||||
except Exception as tool_exc:
|
||||
if ctx.logger:
|
||||
ctx.logger.error(f"搜索历史工具调用失败: {tool_exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"error": f"Search failed: {tool_exc.__class__.__name__}"},
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
tools = [search_history_tool]
|
||||
tool_handler = handle_tool_call
|
||||
|
||||
rsp = chat_model.get_answer(
|
||||
question=q_with_info,
|
||||
wxid=ctx.get_receiver(),
|
||||
specific_max_history=specific_max_history
|
||||
specific_max_history=specific_max_history,
|
||||
tools=tools,
|
||||
tool_handler=tool_handler,
|
||||
tool_max_iterations=10
|
||||
)
|
||||
|
||||
if rsp:
|
||||
|
||||
@@ -223,6 +223,101 @@ class MessageSummary:
|
||||
|
||||
return messages
|
||||
|
||||
def search_messages_with_context(self, chat_id, keywords, context_window=5, max_groups=20):
|
||||
"""根据关键词搜索消息,返回包含前后上下文的结果
|
||||
|
||||
Args:
|
||||
chat_id (str): 聊天ID(群ID或用户ID)
|
||||
keywords (Union[str, list[str]]): 需要搜索的关键词或关键词列表
|
||||
context_window (int): 每条匹配消息前后额外提供的消息数量
|
||||
max_groups (int): 返回的最多结果组数(按时间倒序,优先最新消息)
|
||||
|
||||
Returns:
|
||||
list[dict]: 搜索结果列表,每个元素包含匹配关键词、锚点消息及上下文消息
|
||||
"""
|
||||
if not keywords:
|
||||
return []
|
||||
|
||||
if isinstance(keywords, str):
|
||||
keywords = [keywords]
|
||||
|
||||
normalized_keywords = []
|
||||
for kw in keywords:
|
||||
if kw is None:
|
||||
continue
|
||||
kw_str = str(kw).strip()
|
||||
if kw_str:
|
||||
normalized_keywords.append((kw_str, kw_str.lower()))
|
||||
|
||||
if not normalized_keywords:
|
||||
return []
|
||||
|
||||
try:
|
||||
context_window = int(context_window)
|
||||
except (TypeError, ValueError):
|
||||
context_window = 5
|
||||
context_window = max(0, min(context_window, 10)) # 限制上下文窗口大小,避免过长
|
||||
|
||||
try:
|
||||
max_groups = int(max_groups)
|
||||
except (TypeError, ValueError):
|
||||
max_groups = 20
|
||||
max_groups = max(1, min(max_groups, 20))
|
||||
|
||||
messages = self.get_messages(chat_id)
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
results = []
|
||||
total_messages = len(messages)
|
||||
used_indices = set()
|
||||
|
||||
for idx in range(total_messages - 1, -1, -1):
|
||||
message = messages[idx]
|
||||
content = message.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
|
||||
lower_content = content.lower()
|
||||
matched_keywords = [orig for orig, lower in normalized_keywords if lower in lower_content]
|
||||
if not matched_keywords:
|
||||
continue
|
||||
|
||||
if idx in used_indices:
|
||||
continue
|
||||
|
||||
start = max(0, idx - context_window)
|
||||
end = min(total_messages, idx + context_window + 1)
|
||||
segment_messages = []
|
||||
|
||||
for pos in range(start, end):
|
||||
msg = messages[pos]
|
||||
segment_messages.append({
|
||||
"time": msg.get("time"),
|
||||
"sender": msg.get("sender"),
|
||||
"sender_wxid": msg.get("sender_wxid"),
|
||||
"content": msg.get("content"),
|
||||
"relative_offset": pos - idx,
|
||||
"is_match": pos == idx
|
||||
})
|
||||
|
||||
results.append({
|
||||
"matched_keywords": matched_keywords,
|
||||
"anchor_index": idx,
|
||||
"anchor_time": message.get("time"),
|
||||
"anchor_sender": message.get("sender"),
|
||||
"anchor_sender_wxid": message.get("sender_wxid"),
|
||||
"messages": segment_messages
|
||||
})
|
||||
|
||||
for off in range(start, end):
|
||||
used_indices.add(off)
|
||||
|
||||
if len(results) >= max_groups:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def _basic_summarize(self, messages):
|
||||
"""基本的消息总结逻辑,不使用AI
|
||||
|
||||
|
||||
Reference in New Issue
Block a user