Files
Bubbles/ai_providers/ai_chatglm.py
2025-04-23 13:30:10 +08:00

200 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
import random
import logging
from datetime import datetime
from typing import Optional
import httpx
from openai import OpenAI
from ai_providers.chatglm.code_kernel import CodeKernel, execute
from ai_providers.chatglm.tool_registry import dispatch_tool, extract_code, get_tools
from wcferry import Wcf
# 获取模块级 logger
logger = logging.getLogger(__name__)
functions = get_tools()
class ChatGLM:
def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None:
key = config.get("key", 'empty')
api = config.get("api")
proxy = config.get("proxy")
if proxy:
self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
else:
self.client = OpenAI(api_key=key, base_url=api)
self.conversation_list = {}
self.chat_type = {}
self.max_retry = max_retry
self.wcf = wcf
self.filePath = config["file_path"]
self.kernel = CodeKernel()
self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}],
"tool": [{"role": "system",
"content": "Answer the following questions as best as you can. You have access to the following tools:"}],
"code": [{"role": "system",
"content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(
self.filePath)}]}
def __repr__(self):
return 'ChatGLM'
@staticmethod
def value_check(conf: dict) -> bool:
if conf:
if conf.get("api") and conf.get("prompt") and conf.get("file_path"):
return True
return False
def get_answer(self, question: str, wxid: str) -> str:
# wxid或者roomid,个人时为微信id群消息时为群id
if '#帮助' == question:
return '本助手有三种模式,#聊天模式 = #1 #工具模式 = #2 #代码模式 = #3 , #清除模式会话 = #4 , #清除全部会话 = #5 可用发送#对应模式 或者 #编号 进行切换'
elif '#聊天模式' == question or '#1' == question:
self.chat_type[wxid] = 'chat'
return '已切换#聊天模式'
elif '#工具模式' == question or '#2' == question:
self.chat_type[wxid] = 'tool'
return '已切换#工具模式 \n工具有:查看天气,日期,新闻,comfyUI文生图。例如\n帮我生成一张小鸟的图片,提示词必须是英文'
elif '#代码模式' == question or '#3' == question:
self.chat_type[wxid] = 'code'
return '已切换#代码模式 \n代码模式可以用于写python代码例如\n用python画一个爱心'
elif '#清除模式会话' == question or '#4' == question:
self.conversation_list[wxid][self.chat_type[wxid]
] = self.system_content_msg[self.chat_type[wxid]]
return '已清除'
elif '#清除全部会话' == question or '#5' == question:
self.conversation_list[wxid] = self.system_content_msg
return '已清除'
self.updateMessage(wxid, question, "user")
try:
params = dict(model="chatglm3", temperature=1.0,
messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
if 'tool' == self.chat_type[wxid]:
params["tools"] = [dict(type='function', function=d) for d in functions.values()]
response = self.client.chat.completions.create(**params)
for _ in range(self.max_retry):
if response.choices[0].message.get("function_call"):
function_call = response.choices[0].message.function_call
logger.debug(
f"Function Call Response: {function_call.to_dict_recursive()}")
function_args = json.loads(function_call.arguments)
observation = dispatch_tool(
function_call.name, function_args)
if isinstance(observation, dict):
res_type = observation['res_type'] if 'res_type' in observation else 'text'
res = observation['res'] if 'res_type' in observation else str(
observation)
if res_type == 'image':
filename = observation['filename']
filePath = os.path.join(self.filePath, filename)
res.save(filePath)
self.wcf and self.wcf.send_image(filePath, wxid)
tool_response = '[Image]' if res_type == 'image' else res
else:
tool_response = observation if isinstance(
observation, str) else str(observation)
logger.debug(f"Tool Call Response: {tool_response}")
params["messages"].append(response.choices[0].message)
params["messages"].append(
{
"role": "function",
"name": function_call.name,
"content": tool_response, # 调用函数返回结果
}
)
self.updateMessage(wxid, tool_response, "function")
response = self.client.chat.completions.create(**params)
elif response.choices[0].message.content.find('interpreter') != -1:
output_text = response.choices[0].message.content
code = extract_code(output_text)
self.wcf and self.wcf.send_text('代码如下:\n' + code, wxid)
self.wcf and self.wcf.send_text('执行代码...', wxid)
try:
res_type, res = execute(code, self.kernel)
except Exception as e:
rsp = f'代码执行错误: {e}'
break
if res_type == 'image':
filename = '{}.png'.format(''.join(random.sample(
'abcdefghijklmnopqrstuvwxyz1234567890', 8)))
filePath = os.path.join(self.filePath, filename)
res.save(filePath)
self.wcf and self.wcf.send_image(filePath, wxid)
else:
self.wcf and self.wcf.send_text("执行结果:\n" + res, wxid)
tool_response = '[Image]' if res_type == 'image' else res
logger.debug("Received: %s %s", res_type, res)
params["messages"].append(response.choices[0].message)
params["messages"].append(
{
"role": "function",
"name": "interpreter",
"content": tool_response, # 调用函数返回结果
}
)
self.updateMessage(wxid, tool_response, "function")
response = self.client.chat.completions.create(**params)
else:
rsp = response.choices[0].message.content
break
self.updateMessage(wxid, rsp, "assistant")
except Exception as e0:
rsp = "发生未知错误:" + str(e0)
return rsp
def updateMessage(self, wxid: str, question: str, role: str) -> None:
now_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# 初始化聊天记录,组装系统信息
if wxid not in self.conversation_list.keys():
self.conversation_list[wxid] = self.system_content_msg
if wxid not in self.chat_type.keys():
self.chat_type[wxid] = 'chat'
# 当前问题
content_question_ = {"role": role, "content": question}
self.conversation_list[wxid][self.chat_type[wxid]].append(
content_question_)
# 只存储10条记录超过滚动清除
i = len(self.conversation_list[wxid][self.chat_type[wxid]])
if i > 10:
logger.info("滚动清除微信记录:%s", wxid)
# 删除多余的记录,倒着删,且跳过第一个的系统消息
del self.conversation_list[wxid][self.chat_type[wxid]][1]
if __name__ == "__main__":
from configuration import Config
config = Config().CHATGLM
if not config:
exit(0)
chat = ChatGLM(config)
while True:
q = input(">>> ")
try:
time_start = datetime.now() # 记录开始时间
logger.info(chat.get_answer(q, "wxid"))
time_end = datetime.now() # 记录结束时间
# 计算的时间差为程序的执行时间,单位为秒/s
logger.info(f"{round((time_end - time_start).total_seconds(), 2)}s")
except Exception as e:
logger.error("错误: %s", str(e), exc_info=True)