diff --git a/function_calls/llm.py b/function_calls/llm.py
index a364068..39356ec 100644
--- a/function_calls/llm.py
+++ b/function_calls/llm.py
@@ -74,12 +74,22 @@ class FunctionCallLLM:
         openai_functions = self._build_functions_for_openai(functions)
 
         messages: List[Dict[str, Any]] = []
-        system_prompt = (
+        custom_prompt = None
+        system_msg = getattr(chat_model, "system_content_msg", None)
+        if isinstance(system_msg, dict):
+            custom_prompt = system_msg.get("content")
+        elif isinstance(system_msg, str):
+            custom_prompt = system_msg
+        if custom_prompt:
+            messages.append({"role": "system", "content": custom_prompt})
+
+        tool_prompt = (
             "You are an assistant that can call tools. "
             "When you invoke a function, wait for the tool response before replying to the user. "
             "Only deliver a final answer once you have enough information."
         )
-        messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "system", "content": tool_prompt})
+
         messages.append({"role": "user", "content": ctx.text})
 
         for round_index in range(self.max_function_rounds):
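
A minimal standalone sketch of the resulting message order, assuming the same attribute shape the patch probes for (`system_content_msg` as either a dict with a "content" key or a plain string); the helper and variable names here (build_messages, chat_model, user_text) are illustrative, not part of the patch:

    from types import SimpleNamespace
    from typing import Any, Dict, List

    def build_messages(chat_model: Any, user_text: str) -> List[Dict[str, Any]]:
        """Mirror the patched order: optional custom system message first,
        then the fixed tool prompt, then the user turn."""
        messages: List[Dict[str, Any]] = []

        # Optional per-model system message, accepted as a dict or a string.
        custom_prompt = None
        system_msg = getattr(chat_model, "system_content_msg", None)
        if isinstance(system_msg, dict):
            custom_prompt = system_msg.get("content")
        elif isinstance(system_msg, str):
            custom_prompt = system_msg
        if custom_prompt:
            messages.append({"role": "system", "content": custom_prompt})

        # Fixed tool-calling instructions always follow as a second system message.
        tool_prompt = (
            "You are an assistant that can call tools. "
            "When you invoke a function, wait for the tool response before replying to the user. "
            "Only deliver a final answer once you have enough information."
        )
        messages.append({"role": "system", "content": tool_prompt})
        messages.append({"role": "user", "content": user_text})
        return messages

    # Example: a model object carrying a dict-shaped system_content_msg.
    model = SimpleNamespace(system_content_msg={"role": "system", "content": "Answer in French."})
    for m in build_messages(model, "What's the weather in Paris?"):
        print(m["role"], "->", m["content"][:40])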