Refactored a bit

zihanjian
2025-09-25 11:54:16 +08:00
parent 4419f16843
commit 48cf486725
10 changed files with 181 additions and 1217 deletions


@@ -49,10 +49,12 @@ class FunctionCallLLM:
self.logger.error("无可用的AI模型")
return LLMRunResult(handled=False, error="no_model")
if not hasattr(chat_model, "call_with_functions"):
self.logger.error("当前模型不支持函数调用接口,请配置支持 function calling 的模型")
return LLMRunResult(handled=False, error="no_function_call_support")
try:
if hasattr(chat_model, "call_with_functions"):
return self._run_native_loop(ctx, chat_model, functions, executor, formatter)
return self._run_prompt_loop(ctx, chat_model, functions, executor)
return self._run_native_loop(ctx, chat_model, functions, executor, formatter)
except Exception as exc: # pragma: no cover - safeguard
self.logger.error(f"LLM 调用失败: {exc}")
return LLMRunResult(handled=False, error=str(exc))
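With the prompt-based fallback deleted, a model without call_with_functions no longer gets routed anywhere; the caller receives handled=False with error="no_function_call_support" and must degrade on its own. A minimal sketch of such a caller, using only the LLMRunResult fields and the get_answer signature visible in this diff (the handle_result wrapper itself is hypothetical):

    from typing import Any

    def handle_result(result: Any, chat_model: Any, ctx: Any) -> str:
        """Hypothetical caller-side fallback around FunctionCallLLM.run (sketch)."""
        if result.handled:
            # The native function-calling loop produced a final answer.
            return result.final_response
        if result.error == "no_function_call_support":
            # Model lacks call_with_functions; fall back to a plain chat
            # completion, roughly what the deleted prompt loop used to do.
            return chat_model.get_answer(ctx.text, wxid=ctx.get_receiver())
        # no_model / max_rounds / exception text: leave to the upper layer.
        return ""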
@@ -133,62 +135,6 @@ class FunctionCallLLM:
self.logger.warning("达到最大函数调用轮数,未得到最终回答")
return LLMRunResult(handled=False, error="max_rounds")
# ---------------------------------------------------------------------
# Prompt-based fallback workflow
# ---------------------------------------------------------------------
def _run_prompt_loop(
self,
ctx: MessageContext,
chat_model: Any,
functions: Dict[str, FunctionSpec],
executor: Callable[[FunctionSpec, Dict[str, Any]], FunctionResult],
) -> LLMRunResult:
system_prompt = self._build_prompt_system_text(functions)
user_input = f"用户输入:{ctx.text}"
ai_response = chat_model.get_answer(
user_input,
wxid=ctx.get_receiver(),
system_prompt_override=system_prompt,
)
json_match = re.search(r"\{.*\}", ai_response, re.DOTALL)
if not json_match:
self.logger.warning(f"提示词模式下无法解析JSON: {ai_response}")
return LLMRunResult(handled=False)
try:
decision = json.loads(json_match.group(0))
except json.JSONDecodeError as exc:
self.logger.error(f"提示词模式 JSON 解析失败: {exc}")
return LLMRunResult(handled=False)
action_type = decision.get("action_type")
if action_type == "chat":
# 提示词模式下无法获得模型最终回答,交给上层兜底
return LLMRunResult(handled=False)
if action_type != "function":
self.logger.warning(f"未知的action_type: {action_type}")
return LLMRunResult(handled=False)
function_name = decision.get("function_name")
if function_name not in functions:
self.logger.warning(f"未知的功能名 - {function_name}")
return LLMRunResult(handled=False)
arguments = decision.get("arguments", {})
result = executor(functions[function_name], arguments)
if not result.handled:
return LLMRunResult(handled=False)
return LLMRunResult(handled=True, final_response="\n".join(result.messages))
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _convert_assistant_message(message: Any) -> Dict[str, Any]:
entry: Dict[str, Any] = {
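For the record, the core trick of the deleted fallback was regex-extracting a JSON routing decision from free-form model output. A standalone sketch of that parse step (taken directly from the removed _run_prompt_loop), which also shows why the native path is preferable: the greedy r"\{.*\}" spans from the first "{" to the last "}", so any stray brace in surrounding chatter corrupts the match.

    import json
    import re
    from typing import Any, Dict, Optional

    def extract_decision(ai_response: str) -> Optional[Dict[str, Any]]:
        # Same extraction the removed _run_prompt_loop used: grab everything
        # from the first "{" to the last "}" (greedy, DOTALL) and decode it.
        match = re.search(r"\{.*\}", ai_response, re.DOTALL)
        if match is None:
            return None
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None

    # extract_decision('sure: {"action_type": "function", "function_name": "weather", "arguments": {"city": "Beijing"}}')
    # -> {'action_type': 'function', 'function_name': 'weather', 'arguments': {'city': 'Beijing'}}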
@@ -224,16 +170,6 @@ class FunctionCallLLM:
         )
         return openai_functions
 
-    @staticmethod
-    def _build_prompt_system_text(functions: Dict[str, FunctionSpec]) -> str:
-        prompt = """You are an intelligent routing assistant. Decide from the user's input whether one of the functions below needs to be called."""
-        for spec in functions.values():
-            prompt += f"\n- {spec.name}: {spec.description}"
-        prompt += """
-Output strict JSON only: {"action_type": "chat"} or {"action_type": "function", "function_name": "...", "arguments": {...}}
-"""
-        return prompt
-
     def validate_arguments(self, arguments: Dict[str, Any], schema: Dict[str, Any]) -> bool:
         try:
             required_fields = schema.get("required", [])
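The diff cuts off inside validate_arguments, but the visible line suggests a required-field check against a JSON-Schema-style dict. A guess at the shape of that check; the body below is an assumption, not the file's actual code:

    from typing import Any, Dict

    def validate_required(arguments: Dict[str, Any], schema: Dict[str, Any]) -> bool:
        # Assumed behaviour: every name listed in schema["required"] must be
        # present in the supplied arguments; types are not checked here.
        required_fields = schema.get("required", [])
        return all(field in arguments for field in required_fields)

    # validate_required({"city": "Beijing"}, {"required": ["city"]})  -> True
    # validate_required({}, {"required": ["city"]})                   -> False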