"""
|
|
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
|
|
|
|
Provides streaming output, event system, and complete tool-call loop
|
|
"""
|
|
import json
|
|
import time
|
|
from typing import List, Dict, Any, Optional, Callable
|
|
|
|
from common.log import logger
|
|
from agent.protocol.models import LLMRequest, LLMModel
|
|
from agent.tools.base_tool import BaseTool, ToolResult
|
|
|
|
|
|


class AgentStreamExecutor:
    """
    Agent stream executor.

    Runs the multi-turn reasoning loop based on tool calls:
    1. The LLM generates a response (which may include tool calls)
    2. The tools are executed
    3. The results are returned to the LLM
    4. Repeat until there are no more tool calls
    """
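
    # Illustrative usage sketch (hypothetical names; assumes an Agent, an LLMModel
    # instance, and BaseTool subclasses are constructed elsewhere; the event shape
    # matches _emit_event below):
    #
    #   executor = AgentStreamExecutor(
    #       agent=my_agent,
    #       model=my_model,
    #       system_prompt="You are a helpful assistant.",
    #       tools=[search_tool, file_tool],
    #       on_event=lambda event: print(event["type"], event["data"]),
    #   )
    #   reply = executor.run_stream("Summarize today's unread messages")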

    def __init__(
        self,
        agent,  # Agent instance
        model: LLMModel,
        system_prompt: str,
        tools: List[BaseTool],
        max_turns: int = 50,
        on_event: Optional[Callable] = None,
        messages: Optional[List[Dict]] = None
    ):
        """
        Initialize the stream executor.

        Args:
            agent: Agent instance (for accessing context)
            model: LLM model
            system_prompt: System prompt
            tools: List of available tools
            max_turns: Maximum number of turns
            on_event: Event callback function
            messages: Optional existing message history (for persistent conversations)
        """
        self.agent = agent
        self.model = model
        self.system_prompt = system_prompt
        # Convert the tools list to a dict keyed by tool name
        self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
        self.max_turns = max_turns
        self.on_event = on_event

        # Message history - use the provided messages or create a new list
        self.messages = messages if messages is not None else []

    def _emit_event(self, event_type: str, data: Optional[dict] = None):
        """Emit an event to the registered callback, if any."""
        if self.on_event:
            try:
                self.on_event({
                    "type": event_type,
                    "timestamp": time.time(),
                    "data": data or {}
                })
            except Exception as e:
                logger.error(f"Event callback error: {e}")
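
    # Event payload shape (see _emit_event above). The event types emitted by this
    # module are: agent_start, turn_start, memory_flush_start, message_start,
    # message_update, message_end, tool_execution_start, tool_execution_end,
    # turn_end, error, agent_end. Illustrative example of what an on_event callback
    # receives (field values are hypothetical):
    #
    #   {
    #       "type": "tool_execution_end",
    #       "timestamp": 1718000000.0,           # time.time() at emission
    #       "data": {
    #           "tool_call_id": "toolu_abc123",
    #           "tool_name": "search",
    #           "status": "success",
    #           "result": "...",
    #           "execution_time": 0.42,
    #       },
    #   }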

    def run_stream(self, user_message: str) -> str:
        """
        Execute the streaming reasoning loop.

        Args:
            user_message: User message

        Returns:
            Final response text
        """
        # Log the user message
        logger.info(f"\n{'='*50}")
        logger.info(f"👤 User: {user_message}")
        logger.info(f"{'='*50}")

        # Add the user message (Claude format - use content blocks for consistency)
        self.messages.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": user_message
                }
            ]
        })

        self._emit_event("agent_start")

        final_response = ""
        turn = 0

        try:
            while turn < self.max_turns:
                turn += 1
                logger.info(f"\n{'='*50} Turn {turn} {'='*50}")
                self._emit_event("turn_start", {"turn": turn})

                # Check whether a memory flush is needed (before calling the LLM)
                if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
                    usage = self.agent.last_usage
                    if usage and 'input_tokens' in usage:
                        current_tokens = usage.get('input_tokens', 0)
                        context_window = self.agent._get_model_context_window()
                        reserve_tokens = self.agent.context_reserve_tokens or 20000

                        if self.agent.memory_manager.should_flush_memory(
                            current_tokens=current_tokens,
                            context_window=context_window,
                            reserve_tokens=reserve_tokens
                        ):
                            self._emit_event("memory_flush_start", {
                                "current_tokens": current_tokens,
                                "threshold": context_window - reserve_tokens - 4000
                            })

                            # TODO: Execute the memory flush in the background
                            # (this would require async support)
                            logger.info(f"Memory flush recommended at {current_tokens} tokens")

                # Call the LLM
                assistant_msg, tool_calls = self._call_llm_stream()
                final_response = assistant_msg

                # No tool calls: end the loop
                if not tool_calls:
                    if assistant_msg:
                        logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                    logger.info("✅ Done (no tool calls)")
                    self._emit_event("turn_end", {
                        "turn": turn,
                        "has_tool_calls": False
                    })
                    break

                # Log the tool calls in compact format
                tool_names = [tc['name'] for tc in tool_calls]
                logger.info(f"🔧 Calling tools: {', '.join(tool_names)}")

                # Execute the tools
                tool_results = []
                tool_result_blocks = []

                for tool_call in tool_calls:
                    result = self._execute_tool(tool_call)
                    tool_results.append(result)

                    # Log the tool result in compact format
                    status_emoji = "✅" if result.get("status") == "success" else "❌"
                    result_str = str(result.get('result', ''))
                    logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

                    # Build the tool result block (Claude format).
                    # Content should be a string representation of the result.
                    result_content = json.dumps(result) if not isinstance(result, str) else result
                    tool_result_blocks.append({
                        "type": "tool_result",
                        "tool_use_id": tool_call["id"],
                        "content": result_content
                    })

                # Add the tool results to the message history as a user message (Claude format)
                self.messages.append({
                    "role": "user",
                    "content": tool_result_blocks
                })

                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": True,
                    "tool_count": len(tool_calls)
                })

            if turn >= self.max_turns:
                logger.warning(f"⚠️ Reached the maximum number of turns: {self.max_turns}")

        except Exception as e:
            logger.error(f"❌ Agent execution error: {e}")
            self._emit_event("error", {"error": str(e)})
            raise

        finally:
            logger.info(f"{'='*50} Done ({turn} turns) {'='*50}\n")
            self._emit_event("agent_end", {"final_response": final_response})

        return final_response
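
    # For reference, after one tool-calling turn the message history built by
    # run_stream and _call_llm_stream looks roughly like this (illustrative values):
    #
    #   [
    #       {"role": "user", "content": [{"type": "text", "text": "..."}]},
    #       {"role": "assistant", "content": [
    #           {"type": "text", "text": "Let me check."},
    #           {"type": "tool_use", "id": "toolu_1", "name": "search", "input": {"query": "..."}},
    #       ]},
    #       {"role": "user", "content": [
    #           {"type": "tool_result", "tool_use_id": "toolu_1", "content": "{\"status\": \"success\", ...}"},
    #       ]},
    #   ]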

    def _call_llm_stream(self) -> tuple[str, List[Dict]]:
        """
        Call the LLM with streaming.

        Returns:
            (response_text, tool_calls)
        """
        # Trim messages if needed (using the agent's context management)
        self._trim_messages()

        # Prepare messages
        messages = self._prepare_messages()

        # Debug: log the message structure
        logger.debug(f"Sending {len(messages)} messages to LLM")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                content_types = [c.get("type") for c in content if isinstance(c, dict)]
                logger.debug(f"  Message {i}: role={role}, content_blocks={content_types}")
            else:
                logger.debug(f"  Message {i}: role={role}, content_length={len(str(content))}")

        # Prepare tool definitions (OpenAI/Claude format)
        tools_schema = None
        if self.tools:
            tools_schema = []
            for tool in self.tools.values():
                tools_schema.append({
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.params  # Claude uses input_schema
                })

        # Create the request
        request = LLMRequest(
            messages=messages,
            temperature=0,
            stream=True,
            tools=tools_schema,
            system=self.system_prompt  # Pass the system prompt separately for the Claude API
        )

        self._emit_event("message_start", {"role": "assistant"})

        # Streaming response
        full_content = ""
        tool_calls_buffer = {}  # {index: {id, name, arguments}}

        try:
            stream = self.model.call_stream(request)

            for chunk in stream:
                # Check for errors
                if isinstance(chunk, dict) and chunk.get("error"):
                    error_msg = chunk.get("message", "Unknown error")
                    status_code = chunk.get("status_code", "N/A")
                    logger.error(f"API Error: {error_msg} (Status: {status_code})")
                    logger.error(f"Full error chunk: {chunk}")
                    raise Exception(f"{error_msg} (Status: {status_code})")

                # Parse the chunk
                if isinstance(chunk, dict) and "choices" in chunk:
                    choice = chunk["choices"][0]
                    delta = choice.get("delta", {})

                    # Handle text content
                    if "content" in delta and delta["content"]:
                        content_delta = delta["content"]
                        full_content += content_delta
                        self._emit_event("message_update", {"delta": content_delta})

                    # Handle tool calls
                    if "tool_calls" in delta:
                        for tc_delta in delta["tool_calls"]:
                            index = tc_delta.get("index", 0)

                            if index not in tool_calls_buffer:
                                tool_calls_buffer[index] = {
                                    "id": "",
                                    "name": "",
                                    "arguments": ""
                                }

                            if "id" in tc_delta:
                                tool_calls_buffer[index]["id"] = tc_delta["id"]

                            if "function" in tc_delta:
                                func = tc_delta["function"]
                                if "name" in func:
                                    tool_calls_buffer[index]["name"] = func["name"]
                                if "arguments" in func:
                                    tool_calls_buffer[index]["arguments"] += func["arguments"]

        except Exception as e:
            logger.error(f"LLM call error: {e}")
            raise

        # Parse the buffered tool calls
        tool_calls = []
        for idx in sorted(tool_calls_buffer.keys()):
            tc = tool_calls_buffer[idx]
            try:
                arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
            except json.JSONDecodeError:
                logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
                arguments = {}

            tool_calls.append({
                "id": tc["id"],
                "name": tc["name"],
                "arguments": arguments
            })

        # Add the assistant message to history (Claude format uses content blocks)
        assistant_msg = {"role": "assistant", "content": []}

        # Add a text content block if present
        if full_content:
            assistant_msg["content"].append({
                "type": "text",
                "text": full_content
            })

        # Add tool_use blocks if present
        if tool_calls:
            for tc in tool_calls:
                assistant_msg["content"].append({
                    "type": "tool_use",
                    "id": tc["id"],
                    "name": tc["name"],
                    "input": tc["arguments"]
                })

        # Only append if the content is not empty
        if assistant_msg["content"]:
            self.messages.append(assistant_msg)

        self._emit_event("message_end", {
            "content": full_content,
            "tool_calls": tool_calls
        })

        return full_content, tool_calls
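
    # The streaming parser above expects OpenAI-style chunks. An illustrative
    # tool-call delta chunk (field values are hypothetical):
    #
    #   {
    #       "choices": [{
    #           "delta": {
    #               "tool_calls": [{
    #                   "index": 0,
    #                   "id": "call_abc",
    #                   "function": {"name": "search", "arguments": "{\"query\": \"wea"},
    #               }]
    #           }
    #       }]
    #   }
    #
    # The partial "arguments" strings are concatenated across chunks and parsed
    # with json.loads once the stream ends.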

    def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
        """
        Execute a tool.

        Args:
            tool_call: {"id": str, "name": str, "arguments": dict}

        Returns:
            Tool execution result
        """
        tool_name = tool_call["name"]
        tool_id = tool_call["id"]
        arguments = tool_call["arguments"]

        self._emit_event("tool_execution_start", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            "arguments": arguments
        })

        try:
            tool = self.tools.get(tool_name)
            if not tool:
                raise ValueError(f"Tool '{tool_name}' not found")

            # Set the tool context
            tool.model = self.model
            tool.context = self.agent

            # Execute the tool
            start_time = time.time()
            result: ToolResult = tool.execute_tool(arguments)
            execution_time = time.time() - start_time

            result_dict = {
                "status": result.status,
                "result": result.result,
                "execution_time": execution_time
            }

            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **result_dict
            })

            return result_dict

        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            error_result = {
                "status": "error",
                "result": str(e),
                "execution_time": 0
            }
            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **error_result
            })
            return error_result
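
    # Minimal sketch of a tool that _execute_tool can drive (hypothetical; the
    # exact BaseTool/ToolResult constructor signatures live in
    # agent.tools.base_tool and may differ):
    #
    #   class EchoTool(BaseTool):
    #       name = "echo"
    #       description = "Echo the given text back to the caller"
    #       params = {
    #           "type": "object",
    #           "properties": {"text": {"type": "string"}},
    #           "required": ["text"],
    #       }
    #
    #       def execute_tool(self, arguments):
    #           return ToolResult(status="success", result=arguments.get("text", ""))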

    def _trim_messages(self):
        """
        Trim the message history to stay within context limits.

        Uses the agent's context management configuration.
        """
        if not self.messages or not self.agent:
            return

        # Get the context window and reserve tokens from the agent
        context_window = self.agent._get_model_context_window()
        reserve_tokens = self.agent._get_context_reserve_tokens()
        max_tokens = context_window - reserve_tokens

        # Estimate the current token count
        current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)

        # Add the system prompt tokens
        system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
        current_tokens += system_tokens

        # If under the limit, there is no need to trim
        if current_tokens <= max_tokens:
            return

        # Keep messages from newest to oldest, accumulating tokens
        available_tokens = max_tokens - system_tokens
        kept_messages = []
        accumulated_tokens = 0

        for msg in reversed(self.messages):
            msg_tokens = self.agent._estimate_message_tokens(msg)
            if accumulated_tokens + msg_tokens <= available_tokens:
                kept_messages.insert(0, msg)
                accumulated_tokens += msg_tokens
            else:
                break

        old_count = len(self.messages)
        self.messages = kept_messages
        new_count = len(self.messages)

        if old_count > new_count:
            logger.info(
                f"Context trimmed: {old_count} -> {new_count} messages "
                f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
                f"limit: {max_tokens})"
            )
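
    # Worked example for the trimming logic above (hypothetical numbers): with a
    # 200,000-token context window and 20,000 reserved tokens, max_tokens is
    # 180,000; if the system prompt is ~2,000 tokens, up to ~178,000 tokens of the
    # newest messages are kept and older ones are dropped.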

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """
        Prepare the messages to send to the LLM.

        Note: For the Claude API, the system prompt is passed separately via the
        system parameter, not as a message. The AgentLLMModel handles this.
        """
        # Don't add a system message here - it is handled separately by the LLM adapter
        return self.messages