From cea7fb7490c53454602bf05955a0e9f059bcf0fd Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sat, 7 Feb 2026 20:42:41 +0800 Subject: [PATCH] fix: add intelligent context cleanup #2663 --- agent/protocol/agent.py | 62 ++++++++++++++--- agent/protocol/agent_stream.py | 123 ++++++++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 14 deletions(-) diff --git a/agent/protocol/agent.py b/agent/protocol/agent.py index d477c15..7dcf73d 100644 --- a/agent/protocol/agent.py +++ b/agent/protocol/agent.py @@ -247,27 +247,67 @@ class Agent: def _estimate_message_tokens(self, message: dict) -> int: """ - Estimate token count for a message using chars/4 heuristic. - This is a conservative estimate (tends to overestimate). + Estimate token count for a message. + + Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content, + plus per-block overhead for tool_use / tool_result structures. :param message: Message dict with 'role' and 'content' :return: Estimated token count """ content = message.get('content', '') if isinstance(content, str): - return max(1, len(content) // 4) + return max(1, self._estimate_text_tokens(content)) elif isinstance(content, list): - # Handle multi-part content (text + images) - total_chars = 0 + total_tokens = 0 for part in content: - if isinstance(part, dict) and part.get('type') == 'text': - total_chars += len(part.get('text', '')) - elif isinstance(part, dict) and part.get('type') == 'image': - # Estimate images as ~1200 tokens - total_chars += 4800 - return max(1, total_chars // 4) + if not isinstance(part, dict): + continue + block_type = part.get('type', '') + if block_type == 'text': + total_tokens += self._estimate_text_tokens(part.get('text', '')) + elif block_type == 'image': + total_tokens += 1200 + elif block_type == 'tool_use': + # tool_use has id + name + input (JSON-encoded) + total_tokens += 50 # overhead for structure + input_data = part.get('input', {}) + if isinstance(input_data, dict): + import json + input_str = json.dumps(input_data, ensure_ascii=False) + total_tokens += self._estimate_text_tokens(input_str) + elif block_type == 'tool_result': + # tool_result has tool_use_id + content + total_tokens += 30 # overhead for structure + result_content = part.get('content', '') + if isinstance(result_content, str): + total_tokens += self._estimate_text_tokens(result_content) + else: + # Unknown block type, estimate conservatively + total_tokens += 10 + return max(1, total_tokens) return 1 + @staticmethod + def _estimate_text_tokens(text: str) -> int: + """ + Estimate token count for a text string. + + Chinese / CJK characters typically use ~1.5 tokens each, + while ASCII uses ~0.25 tokens per char (4 chars/token). + We use a weighted average based on the character mix. + + :param text: Input text + :return: Estimated token count + """ + if not text: + return 0 + # Count non-ASCII characters (CJK, emoji, etc.) + non_ascii = sum(1 for c in text if ord(c) > 127) + ascii_count = len(text) - non_ascii + # CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char + return int(non_ascii * 1.5 + ascii_count * 0.25) + 1 + def _find_tool(self, tool_name: str): """Find and return a tool with the specified name""" for tool in self.tools: diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index d7664b2..4c37892 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -479,7 +479,8 @@ class AgentStreamExecutor: return final_response - def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> Tuple[str, List[Dict]]: + def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3, + _overflow_retry: bool = False) -> Tuple[str, List[Dict]]: """ Call LLM with streaming and automatic retry on errors @@ -487,6 +488,7 @@ class AgentStreamExecutor: retry_on_empty: Whether to retry once if empty response is received retry_count: Current retry attempt (internal use) max_retries: Maximum number of retries for API errors + _overflow_retry: Internal flag indicating this is a retry after context overflow Returns: (response_text, tool_calls) @@ -638,10 +640,23 @@ class AgentStreamExecutor: if is_context_overflow or is_message_format_error: error_type = "context overflow" if is_context_overflow else "message format error" logger.error(f"💥 {error_type} detected: {e}") - # Clear message history to recover + + # Strategy: try aggressive trimming first, only clear as last resort + if is_context_overflow and not _overflow_retry: + trimmed = self._aggressive_trim_for_overflow() + if trimmed: + logger.warning("🔄 Aggressively trimmed context, retrying...") + return self._call_llm_stream( + retry_on_empty=retry_on_empty, + retry_count=retry_count, + max_retries=max_retries, + _overflow_retry=True + ) + + # Aggressive trim didn't help or this is a message format error + # -> clear everything logger.warning("🔄 Clearing conversation history to recover") self.messages.clear() - # Raise special exception with user-friendly message if is_context_overflow: raise Exception( "抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。" @@ -1015,6 +1030,108 @@ class AgentStreamExecutor: if truncated_count > 0: logger.info(f"📎 Truncated {truncated_count} historical tool result(s) to {MAX_HISTORY_RESULT_CHARS} chars") + def _aggressive_trim_for_overflow(self) -> bool: + """ + Aggressively trim context when a real overflow error is returned by the API. + + This method goes beyond normal _trim_messages by: + 1. Truncating all tool results (including current turn) to a small limit + 2. Keeping only the last 5 complete conversation turns + 3. Truncating overly long user messages + + Returns: + True if messages were trimmed (worth retrying), False if nothing left to trim + """ + if not self.messages: + return False + + original_count = len(self.messages) + + # Step 1: Aggressively truncate ALL tool results to 5K chars + AGGRESSIVE_LIMIT = 10000 + truncated = 0 + for msg in self.messages: + content = msg.get("content", []) + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict): + continue + # Truncate tool_result blocks + if block.get("type") == "tool_result": + result_str = block.get("content", "") + if isinstance(result_str, str) and len(result_str) > AGGRESSIVE_LIMIT: + block["content"] = ( + result_str[:AGGRESSIVE_LIMIT] + + f"\n\n[Truncated for context recovery: " + f"{len(result_str)} -> {AGGRESSIVE_LIMIT} chars]" + ) + truncated += 1 + # Truncate tool_use input blocks (e.g. large write content) + if block.get("type") == "tool_use" and isinstance(block.get("input"), dict): + input_str = json.dumps(block["input"], ensure_ascii=False) + if len(input_str) > AGGRESSIVE_LIMIT: + # Keep only a summary of the input + for key, val in block["input"].items(): + if isinstance(val, str) and len(val) > 1000: + block["input"][key] = ( + val[:1000] + + f"... [truncated {len(val)} chars]" + ) + truncated += 1 + + # Step 2: Truncate overly long user text messages (e.g. pasted content) + USER_MSG_LIMIT = 10000 + for msg in self.messages: + if msg.get("role") != "user": + continue + content = msg.get("content", []) + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + if len(text) > USER_MSG_LIMIT: + block["text"] = ( + text[:USER_MSG_LIMIT] + + f"\n\n[Message truncated for context recovery: " + f"{len(text)} -> {USER_MSG_LIMIT} chars]" + ) + truncated += 1 + elif isinstance(content, str) and len(content) > USER_MSG_LIMIT: + msg["content"] = ( + content[:USER_MSG_LIMIT] + + f"\n\n[Message truncated for context recovery: " + f"{len(content)} -> {USER_MSG_LIMIT} chars]" + ) + truncated += 1 + + # Step 3: Keep only the last 5 complete turns + turns = self._identify_complete_turns() + if len(turns) > 5: + kept_turns = turns[-5:] + new_messages = [] + for turn in kept_turns: + new_messages.extend(turn["messages"]) + removed = len(turns) - 5 + self.messages[:] = new_messages + logger.info( + f"🔧 Aggressive trim: removed {removed} old turns, " + f"truncated {truncated} large blocks, " + f"{original_count} -> {len(self.messages)} messages" + ) + return True + + if truncated > 0: + logger.info( + f"🔧 Aggressive trim: truncated {truncated} large blocks " + f"(no turns removed, only {len(turns)} turn(s) left)" + ) + return True + + # Nothing left to trim + logger.warning("🔧 Aggressive trim: nothing to trim, will clear history") + return False + def _trim_messages(self): """ 智能清理消息历史,保持对话完整性