# Bubbles/function_calls/llm.py
"""LLM function-call orchestration utilities."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from commands.context import MessageContext
from .spec import FunctionResult, FunctionSpec
logger = logging.getLogger(__name__)
@dataclass
class LLMRunResult:
"""Result of the LLM routing pipeline."""
handled: bool
final_response: Optional[str] = None
error: Optional[str] = None
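

# A minimal sketch of the three outcomes this dataclass encodes (illustrative
# values, not taken from the repository):
#
#   LLMRunResult(handled=True, final_response="It is 21°C in Berlin.")
#   LLMRunResult(handled=False)                      # nothing to route
#   LLMRunResult(handled=False, error="max_rounds")  # loop gave up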


class FunctionCallLLM:
    """Coordinate function-call capable models with router handlers."""

    def __init__(self, max_function_rounds: int = 5) -> None:
        self.logger = logger
        self.max_function_rounds = max_function_rounds

    def run(
        self,
        ctx: MessageContext,
        functions: Dict[str, FunctionSpec],
        executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
        formatter: Callable[[FunctionResult], str],
    ) -> LLMRunResult:
        """Execute the function-call loop and return the final assistant response."""
        if not ctx.text:
            return LLMRunResult(handled=False)
        # Prefer the per-context model; fall back to the robot-wide one.
        chat_model = getattr(ctx, "chat", None)
        if not chat_model and ctx.robot:
            chat_model = getattr(ctx.robot, "chat", None)
        if not chat_model:
            self.logger.error("No AI model available")
            return LLMRunResult(handled=False, error="no_model")
        if not hasattr(chat_model, "call_with_functions"):
            self.logger.error(
                "The current model does not expose a function-call interface; "
                "configure a model that supports function calling"
            )
            return LLMRunResult(handled=False, error="no_function_call_support")
        try:
            return self._run_native_loop(ctx, chat_model, functions, executor, formatter)
        except Exception as exc:  # pragma: no cover - safeguard
            self.logger.error(f"LLM call failed: {exc}")
            return LLMRunResult(handled=False, error=str(exc))

    # ---------------------------------------------------------------------
    # Native function-call workflow
    # ---------------------------------------------------------------------
    def _run_native_loop(
        self,
        ctx: MessageContext,
        chat_model: Any,
        functions: Dict[str, FunctionSpec],
        executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
        formatter: Callable[[FunctionResult], str],
    ) -> LLMRunResult:
        openai_functions = self._build_functions_for_openai(functions)
        # Assemble the conversation: optional custom system prompt (with the
        # current time appended), tool instructions, history, then the user turn.
        messages: List[Dict[str, Any]] = []
        custom_prompt = None
        system_msg = getattr(chat_model, "system_content_msg", None)
        if isinstance(system_msg, dict):
            custom_prompt = system_msg.get("content")
        elif isinstance(system_msg, str):
            custom_prompt = system_msg
        if custom_prompt:
            from datetime import datetime

            now = datetime.now()
            suffix = now.strftime("%Y-%m-%d %H:%M:%S %A")
            enriched_prompt = (
                f"{custom_prompt}\n\n"
                f"Current time: {suffix}\n"
            )
            messages.append({"role": "system", "content": enriched_prompt})
        tool_prompt = (
            "You are an assistant that can call tools. "
            "When you invoke a function, wait for the tool response before replying to the user. "
            "Only deliver a final answer once you have enough information."
        )
        messages.append({"role": "system", "content": tool_prompt})
        history_messages = self._build_history_messages(ctx)
        if history_messages:
            messages.extend(history_messages)
        messages.append({"role": "user", "content": ctx.text})
        for round_index in range(self.max_function_rounds):
            response = chat_model.call_with_functions(
                messages=messages,
                functions=openai_functions,
                wxid=ctx.get_receiver(),
            )
            if not getattr(response, "choices", None):
                self.logger.warning("Function call returned an empty response")
                return LLMRunResult(handled=False)
            message = response.choices[0].message
            assistant_entry = self._convert_assistant_message(message)
            messages.append(assistant_entry)
            tool_calls = getattr(message, "tool_calls", None) or []
            if tool_calls:
                # Execute every requested tool and feed the results back to
                # the model for another round.
                for tool_call in tool_calls:
                    function_name = tool_call.function.name
                    if function_name not in functions:
                        self.logger.warning(f"Model requested unknown function: {function_name}")
                        tool_content = json.dumps(
                            {
                                "handled": False,
                                "messages": [f"Unknown function: {function_name}"],
                                "metadata": {"error": "unknown_function"},
                            },
                            ensure_ascii=False,
                        )
                    else:
                        try:
                            arguments = json.loads(tool_call.function.arguments or "{}")
                        except json.JSONDecodeError:
                            arguments = {}
                        spec = functions[function_name]
                        result = executor(spec, arguments)
                        tool_content = formatter(result)
                    self.logger.info(
                        "Function '%s' tool response payload: %s",
                        function_name,
                        tool_content,
                    )
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": tool_content,
                        }
                    )
                continue
            # No tool calls: treat this as the model's final answer.
            final_content = message.content or ""
            return LLMRunResult(handled=True, final_response=final_content)
        self.logger.warning("Reached the maximum number of function-call rounds without a final answer")
        return LLMRunResult(handled=False, error="max_rounds")
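
    # A sketch of how `messages` evolves across one tool round (illustrative
    # payloads; real content comes from the model, executor, and formatter):
    #
    #   [{"role": "system", "content": "...custom prompt + current time..."},
    #    {"role": "system", "content": "...tool instructions..."},
    #    {"role": "user", "content": "What's the weather in Berlin?"},
    #    {"role": "assistant", "content": None,
    #     "tool_calls": [{"id": "call_1", "type": "function",
    #                     "function": {"name": "get_weather",
    #                                  "arguments": '{"city": "Berlin"}'}}]},
    #    {"role": "tool", "tool_call_id": "call_1", "content": '{"temp": 21}'},
    #    {"role": "assistant", "content": "It is 21°C in Berlin."}]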

    @staticmethod
    def _convert_assistant_message(message: Any) -> Dict[str, Any]:
        entry: Dict[str, Any] = {
            "role": "assistant",
            "content": message.content,
        }
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            entry["tool_calls"] = []
            for tool_call in tool_calls:
                entry["tool_calls"].append(
                    {
                        "id": tool_call.id,
                        "type": "function",
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                    }
                )
        return entry

    @staticmethod
    def _build_functions_for_openai(functions: Dict[str, FunctionSpec]) -> List[Dict[str, Any]]:
        openai_functions = []
        for spec in functions.values():
            openai_functions.append(
                {
                    "name": spec.name,
                    "description": spec.description,
                    "parameters": spec.parameters_schema,
                }
            )
        return openai_functions
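
    # Each entry follows the OpenAI-style `functions` schema. A hypothetical
    # spec (not defined in this repository) would serialize as:
    #
    #   {"name": "get_weather",
    #    "description": "Look up current weather for a city",
    #    "parameters": {"type": "object",
    #                   "properties": {"city": {"type": "string"}},
    #                   "required": ["city"]}}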

    def validate_arguments(self, arguments: Dict[str, Any], schema: Dict[str, Any]) -> bool:
        try:
            required_fields = schema.get("required", [])
            properties = schema.get("properties", {})
            for field in required_fields:
                if field not in arguments:
                    self.logger.warning(f"Missing required parameter: {field}")
                    return False
            for field, value in arguments.items():
                if field not in properties:
                    continue
                expected_type = properties[field].get("type")
                if expected_type == "string" and not isinstance(value, str):
                    self.logger.warning(f"Parameter {field} has wrong type: expected string, got {type(value)}")
                    return False
                # Note: bool is a subclass of int in Python, so True/False
                # also satisfy the integer and number checks below.
                if expected_type == "integer" and not isinstance(value, int):
                    self.logger.warning(f"Parameter {field} has wrong type: expected integer, got {type(value)}")
                    return False
                if expected_type == "number" and not isinstance(value, (int, float)):
                    self.logger.warning(f"Parameter {field} has wrong type: expected number, got {type(value)}")
                    return False
            return True
        except Exception as exc:
            self.logger.error(f"Argument validation failed: {exc}")
            return False
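
    # A minimal usage sketch for validate_arguments (hypothetical schema,
    # mirroring the checks above):
    #
    #   llm = FunctionCallLLM()
    #   schema = {"type": "object",
    #             "properties": {"city": {"type": "string"},
    #                            "days": {"type": "integer"}},
    #             "required": ["city"]}
    #   llm.validate_arguments({"city": "Berlin", "days": 3}, schema)  # True
    #   llm.validate_arguments({"days": "3"}, schema)                  # False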

    def _build_history_messages(self, ctx: MessageContext) -> List[Dict[str, Any]]:
        message_summary = getattr(ctx.robot, "message_summary", None)
        if not message_summary:
            return []
        try:
            history = message_summary.get_messages(ctx.get_receiver())
        except Exception as exc:  # pragma: no cover - defensive
            if ctx.logger:
                ctx.logger.error(f"Failed to load history messages: {exc}")
            return []
        # A context-specific limit wins; otherwise fall back to the model-wide
        # cap. A positive limit keeps the most recent messages, zero disables
        # history entirely, and None keeps everything.
        limit = getattr(ctx, "specific_max_history", None)
        if limit is None:
            limit = getattr(ctx.chat, "max_history_messages", None)
        if limit is not None and limit > 0:
            history = history[-limit:]
        elif limit == 0:
            history = []
        formatted: List[Dict[str, Any]] = []
        for item in history:
            content = item.get("content")
            if not content:
                continue
            role = "assistant" if item.get("sender_wxid") == ctx.robot_wxid else "user"
            sender_name = item.get("sender", "Unknown user")
            content = f"[{sender_name}] {content}"
            formatted.append({"role": role, "content": content})
        return formatted
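

# A minimal wiring sketch for the whole pipeline. `ctx` comes from the router;
# the FunctionSpec fields are inferred from this module, and the executor,
# formatter, and reply helper below are hypothetical stand-ins:
#
#   spec = FunctionSpec(
#       name="get_weather",
#       description="Look up current weather for a city",
#       parameters_schema={"type": "object",
#                          "properties": {"city": {"type": "string"}},
#                          "required": ["city"]},
#   )
#   llm = FunctionCallLLM(max_function_rounds=3)
#   result = llm.run(
#       ctx,
#       functions={spec.name: spec},
#       executor=lambda spec, args: my_weather_handler(args),  # -> FunctionResult
#       formatter=lambda res: json.dumps(res.__dict__, ensure_ascii=False),
#   )
#   if result.handled:
#       ctx.send_text(result.final_response)  # hypothetical reply helper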