diff --git a/commands/handlers.py b/commands/handlers.py index 171b244..69c83aa 100644 --- a/commands/handlers.py +++ b/commands/handlers.py @@ -229,6 +229,34 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: } } + time_window_tool = { + "type": "function", + "function": { + "name": "fetch_chat_history_time_window", + "description": ( + "Fetch historical messages that occurred between two timestamps. " + "Provide precise start_time and end_time (e.g., 2025-05-01 08:00:00). " + "If start_time is later than end_time they will be swapped. " + "Only messages beyond the most recent 30 items are considered." + ), + "parameters": { + "type": "object", + "properties": { + "start_time": { + "type": "string", + "description": "Start timestamp (supports formats like YYYY-MM-DD HH:MM[:SS])." + }, + "end_time": { + "type": "string", + "description": "End timestamp (supports formats like YYYY-MM-DD HH:MM[:SS])." + } + }, + "required": ["start_time", "end_time"], + "additionalProperties": False + } + } + } + def handle_tool_call(tool_name: str, arguments: Dict[str, Any]) -> str: try: if tool_name == "search_chat_history": @@ -369,6 +397,43 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: return json.dumps(response_payload, ensure_ascii=False) + elif tool_name == "fetch_chat_history_time_window": + if "start_time" not in arguments or "end_time" not in arguments: + return json.dumps({"error": "start_time and end_time are required."}, ensure_ascii=False) + + start_time = arguments.get("start_time") + end_time = arguments.get("end_time") + + print(f"[fetch_chat_history_time_window] chat_id={chat_id}, start_time={start_time}, end_time={end_time}") + if ctx.logger: + ctx.logger.info( + f"[fetch_chat_history_time_window] start_time={start_time}, end_time={end_time}" + ) + + time_lines = message_summary.get_messages_by_time_window( + chat_id=chat_id, + start_time=start_time, + end_time=end_time + ) + + response_payload = { + "start_time": start_time, + "end_time": end_time, + "messages": time_lines, + "returned_count": len(time_lines) + } + + print(f"[fetch_chat_history_time_window] returned_count={response_payload['returned_count']}") + if ctx.logger: + ctx.logger.info( + f"[fetch_chat_history_time_window] returned_count={response_payload['returned_count']}" + ) + + if response_payload["returned_count"] == 0: + response_payload["notice"] = "No messages found within the requested time window." + + return json.dumps(response_payload, ensure_ascii=False) + else: return json.dumps({"error": f"Unknown tool '{tool_name}'"}, ensure_ascii=False) @@ -380,7 +445,7 @@ def handle_chitchat(ctx: 'MessageContext', match: Optional[Match]) -> bool: ensure_ascii=False ) - tools = [search_history_tool, range_history_tool] + tools = [search_history_tool, range_history_tool, time_window_tool] tool_handler = handle_tool_call rsp = chat_model.get_answer( diff --git a/function/func_summary.py b/function/func_summary.py index 278b1eb..1912e2f 100644 --- a/function/func_summary.py +++ b/function/func_summary.py @@ -2,6 +2,7 @@ import logging import time +import datetime import re from collections import deque # from threading import Lock # 不再需要锁,使用SQLite的事务机制 @@ -608,3 +609,90 @@ class MessageSummary: self.LOG.debug(f"记录消息 (来源: {source_info}, 类型: {'群聊' if msg.from_group() else '私聊'}): '[{current_time_str}]{sender_name}({sender_wxid}): {content_to_record}' (来自 msg.id={msg.id})") # 调用 record_message 时传入 sender_wxid self.record_message(chat_id, sender_name, sender_wxid, content_to_record, current_time_str) + @staticmethod + def _parse_datetime(dt_value): + """解析多种常见时间格式""" + if isinstance(dt_value, datetime.datetime): + return dt_value + if not dt_value: + return None + + candidates = [ + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M" + ] + dt_str = str(dt_value).strip() + for fmt in candidates: + try: + return datetime.datetime.strptime(dt_str, fmt) + except (ValueError, TypeError): + continue + return None + + def get_messages_by_time_window( + self, + chat_id, + start_time, + end_time, + exclude_recent=30, + max_messages=500 + ): + """根据时间窗口获取消息 + + Args: + chat_id (str): 聊天ID + start_time (Union[str, datetime]): 起始时间 + end_time (Union[str, datetime]): 结束时间 + exclude_recent (int): 跳过最新的消息数量 + max_messages (int): 返回的最大消息数 + + Returns: + list[str]: 已格式化的消息行 + """ + start_dt = self._parse_datetime(start_time) + end_dt = self._parse_datetime(end_time) + if not start_dt or not end_dt: + return [] + + try: + max_messages = int(max_messages) + except (TypeError, ValueError): + max_messages = 500 + max_messages = max(1, min(max_messages, 500)) + + messages = self.get_messages(chat_id) + if not messages: + return [] + + total_messages = len(messages) + cutoff_index = total_messages - max(exclude_recent, 0) + if cutoff_index <= 0: + return [] + + collected = [] + # 确保 start <= end + if start_dt > end_dt: + start_dt, end_dt = end_dt, start_dt + + for idx in range(cutoff_index - 1, -1, -1): + msg = messages[idx] + content = msg.get("content") + if self._is_internal_tool_message(content): + continue + + time_str = msg.get("time") + dt = self._parse_datetime(time_str) + if not dt: + continue + + if start_dt <= dt <= end_dt: + collected.append(f"{time_str} {msg.get('sender')} {content}") + if len(collected) >= max_messages: + break + + collected.reverse() + return collected