"""
Agent Stream Execution Module - multi-turn reasoning driven by tool calls.
Provides streaming output, an event system, and a complete tool-call loop.
"""
import json
import time
from typing import List, Dict, Any, Optional, Callable
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult
class AgentStreamExecutor:
"""
Agent Stream Executor
    Handles the multi-turn reasoning loop driven by tool calls:
1. LLM generates response (may include tool calls)
2. Execute tools
3. Return results to LLM
4. Repeat until no more tool calls
"""
def __init__(
self,
agent, # Agent instance
model: LLMModel,
system_prompt: str,
tools: List[BaseTool],
max_turns: int = 50,
on_event: Optional[Callable] = None,
messages: Optional[List[Dict]] = None
):
"""
Initialize stream executor
Args:
agent: Agent instance (for accessing context)
model: LLM model
system_prompt: System prompt
tools: List of available tools
max_turns: Maximum number of turns
on_event: Event callback function
messages: Optional existing message history (for persistent conversations)
"""
self.agent = agent
self.model = model
self.system_prompt = system_prompt
# Convert tools list to dict
self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
self.max_turns = max_turns
self.on_event = on_event
# Message history - use provided messages or create new list
self.messages = messages if messages is not None else []
    def _emit_event(self, event_type: str, data: Optional[dict] = None):
"""Emit event"""
if self.on_event:
try:
self.on_event({
"type": event_type,
"timestamp": time.time(),
"data": data or {}
})
except Exception as e:
logger.error(f"Event callback error: {e}")
def run_stream(self, user_message: str) -> str:
"""
Execute streaming reasoning loop
Args:
user_message: User message
Returns:
Final response text
"""
# Log user message
logger.info(f"\n{'='*50}")
logger.info(f"👤 用户: {user_message}")
logger.info(f"{'='*50}")
# Add user message (Claude format - use content blocks for consistency)
self.messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": user_message
}
]
})
self._emit_event("agent_start")
final_response = ""
turn = 0
try:
while turn < self.max_turns:
turn += 1
logger.info(f"\n{'='*50}{turn}{'='*50}")
self._emit_event("turn_start", {"turn": turn})
# Check if memory flush is needed (before calling LLM)
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
usage = self.agent.last_usage
if usage and 'input_tokens' in usage:
current_tokens = usage.get('input_tokens', 0)
context_window = self.agent._get_model_context_window()
reserve_tokens = self.agent.context_reserve_tokens or 20000
if self.agent.memory_manager.should_flush_memory(
current_tokens=current_tokens,
context_window=context_window,
reserve_tokens=reserve_tokens
):
self._emit_event("memory_flush_start", {
"current_tokens": current_tokens,
"threshold": context_window - reserve_tokens - 4000
})
# TODO: Execute memory flush in background
# This would require async support
logger.info(f"Memory flush recommended at {current_tokens} tokens")
# Call LLM
assistant_msg, tool_calls = self._call_llm_stream()
final_response = assistant_msg
# No tool calls, end loop
if not tool_calls:
if assistant_msg:
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
logger.info(f"✅ 完成 (无工具调用)")
self._emit_event("turn_end", {
"turn": turn,
"has_tool_calls": False
})
break
# Log tool calls in compact format
tool_names = [tc['name'] for tc in tool_calls]
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
# Execute tools
tool_results = []
tool_result_blocks = []
for tool_call in tool_calls:
result = self._execute_tool(tool_call)
tool_results.append(result)
# Log tool result in compact format
status_emoji = "" if result.get("status") == "success" else ""
result_str = str(result.get('result', ''))
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
# Build tool result block (Claude format)
# Content should be a string representation of the result
result_content = json.dumps(result) if not isinstance(result, str) else result
tool_result_blocks.append({
"type": "tool_result",
"tool_use_id": tool_call["id"],
"content": result_content
})
# Add tool results to message history as user message (Claude format)
self.messages.append({
"role": "user",
"content": tool_result_blocks
})
self._emit_event("turn_end", {
"turn": turn,
"has_tool_calls": True,
"tool_count": len(tool_calls)
})
if turn >= self.max_turns:
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
except Exception as e:
logger.error(f"❌ Agent执行错误: {e}")
self._emit_event("error", {"error": str(e)})
raise
finally:
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
self._emit_event("agent_end", {"final_response": final_response})
return final_response
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
"""
Call LLM with streaming
Returns:
(response_text, tool_calls)
"""
# Trim messages if needed (using agent's context management)
self._trim_messages()
# Prepare messages
messages = self._prepare_messages()
# Debug: log message structure
logger.debug(f"Sending {len(messages)} messages to LLM")
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")
content = msg.get("content", "")
if isinstance(content, list):
content_types = [c.get("type") for c in content if isinstance(c, dict)]
logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
else:
logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")
        # Prepare tool definitions (Claude format: name/description/input_schema)
tools_schema = None
if self.tools:
tools_schema = []
for tool in self.tools.values():
tools_schema.append({
"name": tool.name,
"description": tool.description,
"input_schema": tool.params # Claude uses input_schema
})
# Create request
request = LLMRequest(
messages=messages,
temperature=0,
stream=True,
tools=tools_schema,
system=self.system_prompt # Pass system prompt separately for Claude API
)
self._emit_event("message_start", {"role": "assistant"})
# Streaming response
full_content = ""
tool_calls_buffer = {} # {index: {id, name, arguments}}
try:
stream = self.model.call_stream(request)
for chunk in stream:
# Check for errors
if isinstance(chunk, dict) and chunk.get("error"):
error_msg = chunk.get("message", "Unknown error")
status_code = chunk.get("status_code", "N/A")
logger.error(f"API Error: {error_msg} (Status: {status_code})")
logger.error(f"Full error chunk: {chunk}")
raise Exception(f"{error_msg} (Status: {status_code})")
# Parse chunk
if isinstance(chunk, dict) and "choices" in chunk:
choice = chunk["choices"][0]
delta = choice.get("delta", {})
# Handle text content
if "content" in delta and delta["content"]:
content_delta = delta["content"]
full_content += content_delta
self._emit_event("message_update", {"delta": content_delta})
# Handle tool calls
if "tool_calls" in delta:
for tc_delta in delta["tool_calls"]:
index = tc_delta.get("index", 0)
if index not in tool_calls_buffer:
tool_calls_buffer[index] = {
"id": "",
"name": "",
"arguments": ""
}
if "id" in tc_delta:
tool_calls_buffer[index]["id"] = tc_delta["id"]
if "function" in tc_delta:
func = tc_delta["function"]
if "name" in func:
tool_calls_buffer[index]["name"] = func["name"]
if "arguments" in func:
tool_calls_buffer[index]["arguments"] += func["arguments"]
except Exception as e:
logger.error(f"LLM call error: {e}")
raise
# Parse tool calls
tool_calls = []
for idx in sorted(tool_calls_buffer.keys()):
tc = tool_calls_buffer[idx]
try:
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse tool arguments: {tc['arguments']} ({e})")
                arguments = {}
tool_calls.append({
"id": tc["id"],
"name": tc["name"],
"arguments": arguments
})
# Add assistant message to history (Claude format uses content blocks)
assistant_msg = {"role": "assistant", "content": []}
# Add text content block if present
if full_content:
assistant_msg["content"].append({
"type": "text",
"text": full_content
})
# Add tool_use blocks if present
if tool_calls:
for tc in tool_calls:
assistant_msg["content"].append({
"type": "tool_use",
"id": tc["id"],
"name": tc["name"],
"input": tc["arguments"]
})
# Only append if content is not empty
if assistant_msg["content"]:
self.messages.append(assistant_msg)
self._emit_event("message_end", {
"content": full_content,
"tool_calls": tool_calls
})
return full_content, tool_calls
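
    # Illustrative chunk shapes the parser above expects from model.call_stream
    # (OpenAI-style streaming deltas; the exact fields depend on the model adapter):
    #   {"choices": [{"delta": {"content": "partial text"}}]}
    #   {"choices": [{"delta": {"tool_calls": [
    #       {"index": 0, "id": "call_1",
    #        "function": {"name": "search", "arguments": "{\"query\": \"..."}}]}}]}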
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
"""
Execute tool
Args:
tool_call: {"id": str, "name": str, "arguments": dict}
Returns:
Tool execution result
"""
tool_name = tool_call["name"]
tool_id = tool_call["id"]
arguments = tool_call["arguments"]
self._emit_event("tool_execution_start", {
"tool_call_id": tool_id,
"tool_name": tool_name,
"arguments": arguments
})
try:
tool = self.tools.get(tool_name)
if not tool:
raise ValueError(f"Tool '{tool_name}' not found")
# Set tool context
tool.model = self.model
tool.context = self.agent
# Execute tool
start_time = time.time()
result: ToolResult = tool.execute_tool(arguments)
execution_time = time.time() - start_time
result_dict = {
"status": result.status,
"result": result.result,
"execution_time": execution_time
}
self._emit_event("tool_execution_end", {
"tool_call_id": tool_id,
"tool_name": tool_name,
**result_dict
})
return result_dict
except Exception as e:
logger.error(f"Tool execution error: {e}")
error_result = {
"status": "error",
"result": str(e),
"execution_time": 0
}
self._emit_event("tool_execution_end", {
"tool_call_id": tool_id,
"tool_name": tool_name,
**error_result
})
return error_result
def _trim_messages(self):
"""
Trim message history to stay within context limits.
Uses agent's context management configuration.
"""
if not self.messages or not self.agent:
return
# Get context window and reserve tokens from agent
context_window = self.agent._get_model_context_window()
reserve_tokens = self.agent._get_context_reserve_tokens()
max_tokens = context_window - reserve_tokens
# Estimate current tokens
current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
# Add system prompt tokens
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
current_tokens += system_tokens
# If under limit, no need to trim
if current_tokens <= max_tokens:
return
# Keep messages from newest, accumulating tokens
available_tokens = max_tokens - system_tokens
kept_messages = []
accumulated_tokens = 0
for msg in reversed(self.messages):
msg_tokens = self.agent._estimate_message_tokens(msg)
if accumulated_tokens + msg_tokens <= available_tokens:
kept_messages.insert(0, msg)
accumulated_tokens += msg_tokens
else:
break
old_count = len(self.messages)
self.messages = kept_messages
new_count = len(self.messages)
if old_count > new_count:
logger.info(
f"Context trimmed: {old_count} -> {new_count} messages "
f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
f"limit: {max_tokens})"
)
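
    # Worked example (illustrative numbers): with context_window=200000 and
    # reserve_tokens=20000, max_tokens=180000. Messages are scanned newest-first
    # and kept while their estimated tokens (plus the system prompt's) fit within
    # that limit; everything older is dropped in a single pass.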
def _prepare_messages(self) -> List[Dict[str, Any]]:
"""
Prepare messages to send to LLM
Note: For Claude API, system prompt should be passed separately via system parameter,
not as a message. The AgentLLMModel will handle this.
"""
# Don't add system message here - it will be handled separately by the LLM adapter
return self.messages
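

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only). _StubAgent and _StubModel are
# hypothetical stand-ins for real Agent/LLMModel instances, and this assumes
# LLMRequest accepts the keyword arguments used in _call_llm_stream above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubAgent:
        """Bare-minimum agent surface used by the executor."""
        memory_manager = None

        def _get_model_context_window(self):
            return 128000

        def _get_context_reserve_tokens(self):
            return 20000

        def _estimate_message_tokens(self, msg):
            # Crude estimate: ~4 characters per token
            return len(str(msg)) // 4

    class _StubModel:
        """Yields a single OpenAI-style text delta, then stops (no tool calls)."""

        def call_stream(self, request):
            yield {"choices": [{"delta": {"content": "Hello from the stub model."}}]}

    executor = AgentStreamExecutor(
        agent=_StubAgent(),
        model=_StubModel(),
        system_prompt="You are a helpful assistant.",
        tools=[],
        on_event=lambda ev: print("event:", ev["type"]),
    )
    print(executor.run_stream("Say hello."))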