function call

This commit is contained in:
zihanjian
2025-10-13 14:42:39 +08:00
parent 16e62c8eec
commit 34055175bc
6 changed files with 467 additions and 66 deletions

View File

@@ -223,6 +223,101 @@ class MessageSummary:
return messages
def search_messages_with_context(self, chat_id, keywords, context_window=5, max_groups=20):
    """Search a chat's messages by keyword and return each hit with its context.

    Messages returned by ``self.get_messages(chat_id)`` are scanned from the
    last element to the first. A message becomes an *anchor* when its
    ``content`` is a string containing at least one keyword
    (case-insensitive substring match). Each anchor yields one result group
    made of the anchor plus up to ``context_window`` messages on either side.
    A message already included in an earlier group is never used as an
    anchor again, although context windows of neighbouring groups may still
    overlap.

    Args:
        chat_id (str): Chat identifier (group ID or user ID).
        keywords (Union[str, list[str]]): A keyword or an iterable of
            keywords. A non-iterable scalar is treated as a single keyword.
            ``None`` entries and blank strings are ignored; keywords that
            differ only by case or surrounding whitespace are collapsed
            (first spelling wins).
        context_window (int): Number of extra messages to include before
            and after each match. Clamped to ``0..10``; falls back to 5 when
            it cannot be converted to an integer.
        max_groups (int): Maximum number of result groups. Clamped to
            ``1..20``; falls back to 20 when it cannot be converted to an
            integer.

    Returns:
        list[dict]: One dict per anchor, in scan order (last message first),
        with keys ``matched_keywords``, ``anchor_index``, ``anchor_time``,
        ``anchor_sender``, ``anchor_sender_wxid`` and ``messages``. Each
        entry of ``messages`` carries ``time``, ``sender``, ``sender_wxid``,
        ``content``, ``relative_offset`` (position relative to the anchor)
        and ``is_match`` (True only for the anchor itself). An empty list is
        returned when there are no usable keywords or no messages.
    """

    def _to_bounded_int(value, default, low, high):
        # OverflowError covers int(float("inf")), which the narrower
        # (TypeError, ValueError) pair would let escape.
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            number = default
        return max(low, min(number, high))

    if not keywords:
        return []
    # A bare string must not be iterated character by character, and a
    # non-iterable scalar (e.g. a number) is accepted as a single keyword.
    if isinstance(keywords, str) or not hasattr(keywords, "__iter__"):
        keywords = [keywords]

    normalized_keywords = []
    seen_lowered = set()
    for kw in keywords:
        if kw is None:
            continue
        kw_str = str(kw).strip()
        kw_lower = kw_str.lower()
        if not kw_str or kw_lower in seen_lowered:
            continue
        seen_lowered.add(kw_lower)
        normalized_keywords.append((kw_str, kw_lower))
    if not normalized_keywords:
        return []

    # Cap the window and group count so a single call cannot return an
    # overly long payload.
    context_window = _to_bounded_int(context_window, 5, 0, 10)
    max_groups = _to_bounded_int(max_groups, 20, 1, 20)

    messages = self.get_messages(chat_id)
    if not messages:
        return []

    results = []
    total_messages = len(messages)
    used_indices = set()
    for idx in range(total_messages - 1, -1, -1):
        # Cheap membership test first: anything already emitted as part of
        # an earlier group is never promoted to an anchor.
        if idx in used_indices:
            continue
        message = messages[idx]
        content = message.get("content", "")
        # Only textual content can be searched; other payloads are skipped
        # as anchors but may still appear as context below.
        if not content or not isinstance(content, str):
            continue
        lower_content = content.lower()
        matched_keywords = [
            orig for orig, lowered in normalized_keywords if lowered in lower_content
        ]
        if not matched_keywords:
            continue

        start = max(0, idx - context_window)
        end = min(total_messages, idx + context_window + 1)
        segment_messages = [
            {
                "time": messages[pos].get("time"),
                "sender": messages[pos].get("sender"),
                "sender_wxid": messages[pos].get("sender_wxid"),
                "content": messages[pos].get("content"),
                "relative_offset": pos - idx,
                "is_match": pos == idx,
            }
            for pos in range(start, end)
        ]
        results.append({
            "matched_keywords": matched_keywords,
            "anchor_index": idx,
            "anchor_time": message.get("time"),
            "anchor_sender": message.get("sender"),
            "anchor_sender_wxid": message.get("sender_wxid"),
            "messages": segment_messages,
        })
        used_indices.update(range(start, end))
        if len(results) >= max_groups:
            break
    return results
def _basic_summarize(self, messages):
"""基本的消息总结逻辑不使用AI