mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-14 08:16:32 +08:00
fix: add intelligent context cleanup #2663
This commit is contained in:
@@ -247,27 +247,67 @@ class Agent:
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message using chars/4 heuristic.
|
||||
This is a conservative estimate (tends to overestimate).
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, len(content) // 4)
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
# Handle multi-part content (text + images)
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get('type') == 'text':
|
||||
total_chars += len(part.get('text', ''))
|
||||
elif isinstance(part, dict) and part.get('type') == 'image':
|
||||
# Estimate images as ~1200 tokens
|
||||
total_chars += 4800
|
||||
return max(1, total_chars // 4)
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
|
||||
@@ -479,7 +479,8 @@ class AgentStreamExecutor:
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> Tuple[str, List[Dict]]:
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3,
|
||||
_overflow_retry: bool = False) -> Tuple[str, List[Dict]]:
|
||||
"""
|
||||
Call LLM with streaming and automatic retry on errors
|
||||
|
||||
@@ -487,6 +488,7 @@ class AgentStreamExecutor:
|
||||
retry_on_empty: Whether to retry once if empty response is received
|
||||
retry_count: Current retry attempt (internal use)
|
||||
max_retries: Maximum number of retries for API errors
|
||||
_overflow_retry: Internal flag indicating this is a retry after context overflow
|
||||
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
@@ -638,10 +640,23 @@ class AgentStreamExecutor:
|
||||
if is_context_overflow or is_message_format_error:
|
||||
error_type = "context overflow" if is_context_overflow else "message format error"
|
||||
logger.error(f"💥 {error_type} detected: {e}")
|
||||
# Clear message history to recover
|
||||
|
||||
# Strategy: try aggressive trimming first, only clear as last resort
|
||||
if is_context_overflow and not _overflow_retry:
|
||||
trimmed = self._aggressive_trim_for_overflow()
|
||||
if trimmed:
|
||||
logger.warning("🔄 Aggressively trimmed context, retrying...")
|
||||
return self._call_llm_stream(
|
||||
retry_on_empty=retry_on_empty,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries,
|
||||
_overflow_retry=True
|
||||
)
|
||||
|
||||
# Aggressive trim didn't help or this is a message format error
|
||||
# -> clear everything
|
||||
logger.warning("🔄 Clearing conversation history to recover")
|
||||
self.messages.clear()
|
||||
# Raise special exception with user-friendly message
|
||||
if is_context_overflow:
|
||||
raise Exception(
|
||||
"抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。"
|
||||
@@ -1015,6 +1030,108 @@ class AgentStreamExecutor:
|
||||
if truncated_count > 0:
|
||||
logger.info(f"📎 Truncated {truncated_count} historical tool result(s) to {MAX_HISTORY_RESULT_CHARS} chars")
|
||||
|
||||
def _aggressive_trim_for_overflow(self) -> bool:
|
||||
"""
|
||||
Aggressively trim context when a real overflow error is returned by the API.
|
||||
|
||||
This method goes beyond normal _trim_messages by:
|
||||
1. Truncating all tool results (including current turn) to a small limit
|
||||
2. Keeping only the last 5 complete conversation turns
|
||||
3. Truncating overly long user messages
|
||||
|
||||
Returns:
|
||||
True if messages were trimmed (worth retrying), False if nothing left to trim
|
||||
"""
|
||||
if not self.messages:
|
||||
return False
|
||||
|
||||
original_count = len(self.messages)
|
||||
|
||||
# Step 1: Aggressively truncate ALL tool results to 5K chars
|
||||
AGGRESSIVE_LIMIT = 10000
|
||||
truncated = 0
|
||||
for msg in self.messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
# Truncate tool_result blocks
|
||||
if block.get("type") == "tool_result":
|
||||
result_str = block.get("content", "")
|
||||
if isinstance(result_str, str) and len(result_str) > AGGRESSIVE_LIMIT:
|
||||
block["content"] = (
|
||||
result_str[:AGGRESSIVE_LIMIT]
|
||||
+ f"\n\n[Truncated for context recovery: "
|
||||
f"{len(result_str)} -> {AGGRESSIVE_LIMIT} chars]"
|
||||
)
|
||||
truncated += 1
|
||||
# Truncate tool_use input blocks (e.g. large write content)
|
||||
if block.get("type") == "tool_use" and isinstance(block.get("input"), dict):
|
||||
input_str = json.dumps(block["input"], ensure_ascii=False)
|
||||
if len(input_str) > AGGRESSIVE_LIMIT:
|
||||
# Keep only a summary of the input
|
||||
for key, val in block["input"].items():
|
||||
if isinstance(val, str) and len(val) > 1000:
|
||||
block["input"][key] = (
|
||||
val[:1000]
|
||||
+ f"... [truncated {len(val)} chars]"
|
||||
)
|
||||
truncated += 1
|
||||
|
||||
# Step 2: Truncate overly long user text messages (e.g. pasted content)
|
||||
USER_MSG_LIMIT = 10000
|
||||
for msg in self.messages:
|
||||
if msg.get("role") != "user":
|
||||
continue
|
||||
content = msg.get("content", [])
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if len(text) > USER_MSG_LIMIT:
|
||||
block["text"] = (
|
||||
text[:USER_MSG_LIMIT]
|
||||
+ f"\n\n[Message truncated for context recovery: "
|
||||
f"{len(text)} -> {USER_MSG_LIMIT} chars]"
|
||||
)
|
||||
truncated += 1
|
||||
elif isinstance(content, str) and len(content) > USER_MSG_LIMIT:
|
||||
msg["content"] = (
|
||||
content[:USER_MSG_LIMIT]
|
||||
+ f"\n\n[Message truncated for context recovery: "
|
||||
f"{len(content)} -> {USER_MSG_LIMIT} chars]"
|
||||
)
|
||||
truncated += 1
|
||||
|
||||
# Step 3: Keep only the last 5 complete turns
|
||||
turns = self._identify_complete_turns()
|
||||
if len(turns) > 5:
|
||||
kept_turns = turns[-5:]
|
||||
new_messages = []
|
||||
for turn in kept_turns:
|
||||
new_messages.extend(turn["messages"])
|
||||
removed = len(turns) - 5
|
||||
self.messages[:] = new_messages
|
||||
logger.info(
|
||||
f"🔧 Aggressive trim: removed {removed} old turns, "
|
||||
f"truncated {truncated} large blocks, "
|
||||
f"{original_count} -> {len(self.messages)} messages"
|
||||
)
|
||||
return True
|
||||
|
||||
if truncated > 0:
|
||||
logger.info(
|
||||
f"🔧 Aggressive trim: truncated {truncated} large blocks "
|
||||
f"(no turns removed, only {len(turns)} turn(s) left)"
|
||||
)
|
||||
return True
|
||||
|
||||
# Nothing left to trim
|
||||
logger.warning("🔧 Aggressive trim: nothing to trim, will clear history")
|
||||
return False
|
||||
|
||||
def _trim_messages(self):
|
||||
"""
|
||||
智能清理消息历史,保持对话完整性
|
||||
|
||||
Reference in New Issue
Block a user