"""LLM function-call orchestration utilities."""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from dataclasses import dataclass
|
||
from typing import Any, Callable, Dict, List, Optional
|
||
|
||
from commands.context import MessageContext
|
||
|
||
from .spec import FunctionResult, FunctionSpec
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||

@dataclass
class LLMRunResult:
    """Result of the LLM routing pipeline."""

    handled: bool
    final_response: Optional[str] = None
    error: Optional[str] = None


class FunctionCallLLM:
    """Coordinate function-call-capable models with router handlers."""

    def __init__(self, max_function_rounds: int = 5) -> None:
        self.logger = logger
        self.max_function_rounds = max_function_rounds

    def run(
        self,
        ctx: MessageContext,
        functions: Dict[str, FunctionSpec],
        executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
        formatter: Callable[[FunctionResult], str],
    ) -> LLMRunResult:
        """Execute the function-call loop and return the final assistant response."""
        if not ctx.text:
            return LLMRunResult(handled=False)

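        # Prefer a chat model attached to the context; otherwise fall back to
        # the robot-level model.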
        chat_model = getattr(ctx, "chat", None)
        if not chat_model and ctx.robot:
            chat_model = getattr(ctx.robot, "chat", None)

        if not chat_model:
            self.logger.error("无可用的AI模型")
            return LLMRunResult(handled=False, error="no_model")

        try:
            if hasattr(chat_model, "call_with_functions"):
                return self._run_native_loop(ctx, chat_model, functions, executor, formatter)
            return self._run_prompt_loop(ctx, chat_model, functions, executor)
        except Exception as exc:  # pragma: no cover - safeguard
            self.logger.error(f"LLM 调用失败: {exc}")
            return LLMRunResult(handled=False, error=str(exc))

    # ---------------------------------------------------------------------
    # Native function-call workflow
    # ---------------------------------------------------------------------

    def _run_native_loop(
        self,
        ctx: MessageContext,
        chat_model: Any,
        functions: Dict[str, FunctionSpec],
        executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
        formatter: Callable[[FunctionResult], str],
    ) -> LLMRunResult:
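        """Drive the model's native tool-calling API until it yields a final answer.

        Each requested tool call is executed via ``executor`` and its result,
        rendered by ``formatter``, is appended back as a ``tool`` message.
        """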
        openai_functions = self._build_functions_for_openai(functions)
        messages: List[Dict[str, Any]] = []

        system_prompt = (
            "You are an assistant that can call tools. "
            "When you invoke a function, wait for the tool response before replying to the user. "
            "Only deliver a final answer once you have enough information."
        )
        messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": ctx.text})

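        # Each round sends the full transcript; the model either requests tool
        # calls (executed below and appended as "tool" messages) or gives a
        # final answer, which ends the loop.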
        for round_index in range(self.max_function_rounds):
            response = chat_model.call_with_functions(
                messages=messages,
                functions=openai_functions,
                wxid=ctx.get_receiver(),
            )

            if not getattr(response, "choices", None):
                self.logger.warning("函数调用返回空响应")
                return LLMRunResult(handled=False)

            message = response.choices[0].message
            assistant_entry = self._convert_assistant_message(message)
            messages.append(assistant_entry)

            tool_calls = getattr(message, "tool_calls", None) or []
            if tool_calls:
                for tool_call in tool_calls:
                    function_name = tool_call.function.name
                    if function_name not in functions:
                        self.logger.warning(f"模型请求未知函数: {function_name}")
                        tool_content = json.dumps(
                            {
                                "handled": False,
                                "messages": [f"Unknown function: {function_name}"],
                                "metadata": {"error": "unknown_function"},
                            },
                            ensure_ascii=False,
                        )
                    else:
                        try:
                            arguments = json.loads(tool_call.function.arguments or "{}")
                        except json.JSONDecodeError:
                            arguments = {}
                        spec = functions[function_name]
                        result = executor(spec, arguments)
                        tool_content = formatter(result)
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": tool_content,
                        }
                    )
                continue

            # No tool calls: treat the model's reply as the final answer.
            final_content = message.content or ""
            return LLMRunResult(handled=True, final_response=final_content)

        self.logger.warning("达到最大函数调用轮数,未得到最终回答")
        return LLMRunResult(handled=False, error="max_rounds")

    # ---------------------------------------------------------------------
    # Prompt-based fallback workflow
    # ---------------------------------------------------------------------

    def _run_prompt_loop(
        self,
        ctx: MessageContext,
        chat_model: Any,
        functions: Dict[str, FunctionSpec],
        executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
    ) -> LLMRunResult:
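        """Fallback routing for models without a native function-call API.

        A routing prompt asks the model for a single JSON decision; the chosen
        function is executed at most once and its messages become the reply.
        """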
        system_prompt = self._build_prompt_system_text(functions)
        user_input = f"用户输入:{ctx.text}"

        ai_response = chat_model.get_answer(
            user_input,
            wxid=ctx.get_receiver(),
            system_prompt_override=system_prompt,
        )

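        # The routing prompt asks for a bare JSON object; extract it from the
        # free-form reply before parsing.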
        json_match = re.search(r"\{.*\}", ai_response, re.DOTALL)
        if not json_match:
            self.logger.warning(f"提示词模式下无法解析JSON: {ai_response}")
            return LLMRunResult(handled=False)

        try:
            decision = json.loads(json_match.group(0))
        except json.JSONDecodeError as exc:
            self.logger.error(f"提示词模式 JSON 解析失败: {exc}")
            return LLMRunResult(handled=False)

        action_type = decision.get("action_type")
        if action_type == "chat":
            # Prompt mode cannot produce the model's final chat reply, so plain
            # chat is handed back to the caller's fallback handling.
            return LLMRunResult(handled=False)

        if action_type != "function":
            self.logger.warning(f"未知的action_type: {action_type}")
            return LLMRunResult(handled=False)

        function_name = decision.get("function_name")
        if function_name not in functions:
            self.logger.warning(f"未知的功能名 - {function_name}")
            return LLMRunResult(handled=False)

        arguments = decision.get("arguments", {})
        result = executor(functions[function_name], arguments)
        if not result.handled:
            return LLMRunResult(handled=False)

        return LLMRunResult(handled=True, final_response="\n".join(result.messages))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _convert_assistant_message(message: Any) -> Dict[str, Any]:
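        """Convert an SDK message object into a plain dict for the next request.

        Tool calls are copied through so each follow-up ``tool`` message can be
        matched to its ``tool_call_id``.
        """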
        entry: Dict[str, Any] = {
            "role": "assistant",
            "content": message.content,
        }
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            entry["tool_calls"] = []
            for tool_call in tool_calls:
                entry["tool_calls"].append(
                    {
                        "id": tool_call.id,
                        "type": "function",
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                    }
                )
        return entry

    @staticmethod
    def _build_functions_for_openai(functions: Dict[str, FunctionSpec]) -> List[Dict[str, Any]]:
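        """Map FunctionSpec entries onto the OpenAI-style function schema."""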
        openai_functions = []
        for spec in functions.values():
            openai_functions.append(
                {
                    "name": spec.name,
                    "description": spec.description,
                    "parameters": spec.parameters_schema,
                }
            )
        return openai_functions

    @staticmethod
    def _build_prompt_system_text(functions: Dict[str, FunctionSpec]) -> str:
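        """Build the routing system prompt that lists every available function."""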
prompt = """你是一个智能路由助手。根据用户输入判断是否需要调用以下函数之一。"""
|
||
for spec in functions.values():
|
||
prompt += f"\n- {spec.name}: {spec.description}"
|
||
prompt += """
|
||
请严格输出JSON:{"action_type": "chat"} 或 {"action_type": "function", "function_name": "...", "arguments": {...}}
|
||
"""
|
||
return prompt
|
||
|
||

    def validate_arguments(self, arguments: Dict[str, Any], schema: Dict[str, Any]) -> bool:
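        """Lightly validate arguments against a JSON-schema-style dict.

        Checks required fields and the primitive types string/integer/number,
        returning False on the first mismatch.
        """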
        try:
            required_fields = schema.get("required", [])
            properties = schema.get("properties", {})

            for field in required_fields:
                if field not in arguments:
                    self.logger.warning(f"缺少必需参数: {field}")
                    return False

            for field, value in arguments.items():
                if field not in properties:
                    continue
                expected_type = properties[field].get("type")
                if expected_type == "string" and not isinstance(value, str):
                    self.logger.warning(f"参数 {field} 类型不正确,期望 string,得到 {type(value)}")
                    return False
                if expected_type == "integer" and not isinstance(value, int):
                    self.logger.warning(f"参数 {field} 类型不正确,期望 integer,得到 {type(value)}")
                    return False
                if expected_type == "number" and not isinstance(value, (int, float)):
                    self.logger.warning(f"参数 {field} 类型不正确,期望 number,得到 {type(value)}")
                    return False
            return True
        except Exception as exc:
            self.logger.error(f"参数验证失败: {exc}")
            return False
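
# Minimal usage sketch (illustrative only): `my_specs`, `execute_spec`,
# `format_result`, and `send_reply` are hypothetical names for the router-side
# wiring that `run()` expects; they are not defined in this module.
#
#     llm = FunctionCallLLM(max_function_rounds=3)
#     outcome = llm.run(
#         ctx,
#         functions=my_specs,        # Dict[str, FunctionSpec]
#         executor=execute_spec,     # (FunctionSpec, args) -> FunctionResult
#         formatter=format_result,   # FunctionResult -> str for the tool message
#     )
#     if outcome.handled:
#         send_reply(outcome.final_response)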