mirror of
https://github.com/Zippland/Bubbles.git
synced 2026-01-19 01:21:15 +08:00
Update the configuration file: remove model configurations that are no longer supported, adjust the default model settings, and add a max-history-messages option. Also refactor the code for readability, removing redundant model-handling logic to keep the structure lean.
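The headline change is the new max_history_messages setting read by the ChatGPT and DeepSeek providers. A minimal sketch of how the entry might look in the YAML config; the exact section name and the surrounding keys are assumptions modeled on the provider blocks shown in this diff:

```yaml
# Hypothetical excerpt; only max_history_messages is the new key.
chatgpt:
  key: sk-xxxx                     # API key
  api: https://api.openai.com/v1   # OpenAI-compatible endpoint
  model: gpt-4o                    # default model
  max_history_messages: 10         # cap on history messages sent per request; 0 disables history
```

Both providers fall back to 10 when the key is absent, and treat 0 as "send no history at all".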
@@ -47,14 +47,8 @@ Bubbles is a feature-rich WeChat bot framework built on [wcferry](https:/
#### 🤖 Flexible model configuration
- Different AI models and system prompts can be set per group chat and per private chat
- OpenAI (ChatGPT)
- Google Gemini
- Zhipu AI (ChatGLM)
- iFLYTEK Spark LLM
- Alibaba Cloud Tongyi Qianwen
- TigerBot
- DeepSeek
- Perplexity
- Ollama (locally deployed models)

#### 🛠️ Rich command system
- A powerful command-routing system that makes adding new features trivial
@@ -1,44 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import os

import google.generativeai as genai


class BardAssistant:
    def __init__(self, conf: dict) -> None:
        self._api_key = conf["api_key"]
        self._model_name = conf["model_name"]
        self._prompt = conf['prompt']
        self._proxy = conf['proxy']

        genai.configure(api_key=self._api_key)
        self._bard = genai.GenerativeModel(self._model_name)

    def __repr__(self):
        return 'BardAssistant'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            if conf.get("api_key") and conf.get("model_name") and conf.get("prompt"):
                return True
        return False

    def get_answer(self, msg: str, sender: str = None) -> str:
        response = self._bard.generate_content([{'role': 'user', 'parts': [msg]}])
        return response.text


if __name__ == "__main__":
    from configuration import Config
    config = Config().BardAssistant
    if not config:
        exit(0)

    bard_assistant = BardAssistant(config)
    if bard_assistant._proxy:
        os.environ['HTTP_PROXY'] = bard_assistant._proxy
        os.environ['HTTPS_PROXY'] = bard_assistant._proxy
    rsp = bard_assistant.get_answer(bard_assistant._prompt)
    print(rsp)
@@ -1,199 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import os
import random
import logging
from datetime import datetime
from typing import Optional
import httpx
from openai import OpenAI
from ai_providers.chatglm.code_kernel import CodeKernel, execute
from ai_providers.chatglm.tool_registry import dispatch_tool, extract_code, get_tools
from wcferry import Wcf

# Module-level logger
logger = logging.getLogger(__name__)

functions = get_tools()


class ChatGLM:

    def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None:
        key = config.get("key", 'empty')
        api = config.get("api")
        proxy = config.get("proxy")
        if proxy:
            self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
        else:
            self.client = OpenAI(api_key=key, base_url=api)
        self.conversation_list = {}
        self.chat_type = {}
        self.max_retry = max_retry
        self.wcf = wcf
        self.filePath = config["file_path"]
        self.kernel = CodeKernel()
        self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}],
                                   "tool": [{"role": "system",
                                             "content": "Answer the following questions as best as you can. You have access to the following tools:"}],
                                   "code": [{"role": "system",
                                             "content": "你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是{}。".format(
                                                 self.filePath)}]}

    def __repr__(self):
        return 'ChatGLM'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            if conf.get("api") and conf.get("prompt") and conf.get("file_path"):
                return True
        return False

    def get_answer(self, question: str, wxid: str) -> str:
        # wxid or roomid: the personal WeChat id for private chats, the room id for group messages
        if '#帮助' == question:
            return '本助手有三种模式,#聊天模式 = #1 ,#工具模式 = #2 ,#代码模式 = #3 , #清除模式会话 = #4 , #清除全部会话 = #5 可用发送#对应模式 或者 #编号 进行切换'
        elif '#聊天模式' == question or '#1' == question:
            self.chat_type[wxid] = 'chat'
            return '已切换#聊天模式'
        elif '#工具模式' == question or '#2' == question:
            self.chat_type[wxid] = 'tool'
            return '已切换#工具模式 \n工具有:查看天气,日期,新闻,comfyUI文生图。例如:\n帮我生成一张小鸟的图片,提示词必须是英文'
        elif '#代码模式' == question or '#3' == question:
            self.chat_type[wxid] = 'code'
            return '已切换#代码模式 \n代码模式可以用于写python代码,例如:\n用python画一个爱心'
        elif '#清除模式会话' == question or '#4' == question:
            self.conversation_list[wxid][self.chat_type[wxid]] = self.system_content_msg[self.chat_type[wxid]]
            return '已清除'
        elif '#清除全部会话' == question or '#5' == question:
            self.conversation_list[wxid] = self.system_content_msg
            return '已清除'

        self.updateMessage(wxid, question, "user")

        try:
            params = dict(model="chatglm3", temperature=1.0,
                          messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
            if 'tool' == self.chat_type[wxid]:
                params["tools"] = [dict(type='function', function=d) for d in functions.values()]
            response = self.client.chat.completions.create(**params)
            for _ in range(self.max_retry):
                if response.choices[0].message.get("function_call"):
                    function_call = response.choices[0].message.function_call
                    logger.debug(f"Function Call Response: {function_call.to_dict_recursive()}")

                    function_args = json.loads(function_call.arguments)
                    observation = dispatch_tool(function_call.name, function_args)
                    if isinstance(observation, dict):
                        res_type = observation['res_type'] if 'res_type' in observation else 'text'
                        res = observation['res'] if 'res_type' in observation else str(observation)
                        if res_type == 'image':
                            filename = observation['filename']
                            filePath = os.path.join(self.filePath, filename)
                            res.save(filePath)
                            self.wcf and self.wcf.send_image(filePath, wxid)
                        tool_response = '[Image]' if res_type == 'image' else res
                    else:
                        tool_response = observation if isinstance(observation, str) else str(observation)
                    logger.debug(f"Tool Call Response: {tool_response}")

                    params["messages"].append(response.choices[0].message)
                    params["messages"].append(
                        {
                            "role": "function",
                            "name": function_call.name,
                            "content": tool_response,  # result returned by the tool call
                        }
                    )
                    self.updateMessage(wxid, tool_response, "function")
                    response = self.client.chat.completions.create(**params)
                elif response.choices[0].message.content.find('interpreter') != -1:
                    output_text = response.choices[0].message.content
                    code = extract_code(output_text)
                    self.wcf and self.wcf.send_text('代码如下:\n' + code, wxid)
                    self.wcf and self.wcf.send_text('执行代码...', wxid)
                    try:
                        res_type, res = execute(code, self.kernel)
                    except Exception as e:
                        rsp = f'代码执行错误: {e}'
                        break
                    if res_type == 'image':
                        filename = '{}.png'.format(''.join(random.sample(
                            'abcdefghijklmnopqrstuvwxyz1234567890', 8)))
                        filePath = os.path.join(self.filePath, filename)
                        res.save(filePath)
                        self.wcf and self.wcf.send_image(filePath, wxid)
                    else:
                        self.wcf and self.wcf.send_text("执行结果:\n" + res, wxid)
                    tool_response = '[Image]' if res_type == 'image' else res
                    logger.debug("Received: %s %s", res_type, res)
                    params["messages"].append(response.choices[0].message)
                    params["messages"].append(
                        {
                            "role": "function",
                            "name": "interpreter",
                            "content": tool_response,  # result returned by the interpreter
                        }
                    )
                    self.updateMessage(wxid, tool_response, "function")
                    response = self.client.chat.completions.create(**params)
                else:
                    rsp = response.choices[0].message.content
                    break

            self.updateMessage(wxid, rsp, "assistant")
        except Exception as e0:
            rsp = "发生未知错误:" + str(e0)

        return rsp

    def updateMessage(self, wxid: str, question: str, role: str) -> None:
        now_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # Initialize the chat history and attach the system messages
        if wxid not in self.conversation_list.keys():
            self.conversation_list[wxid] = self.system_content_msg
        if wxid not in self.chat_type.keys():
            self.chat_type[wxid] = 'chat'

        # Current question
        content_question_ = {"role": role, "content": question}
        self.conversation_list[wxid][self.chat_type[wxid]].append(content_question_)

        # Keep at most 10 records; roll off the oldest beyond that
        i = len(self.conversation_list[wxid][self.chat_type[wxid]])
        if i > 10:
            logger.info("Rolling off WeChat history: %s", wxid)
            # Delete the surplus record, skipping the system message at index 0
            del self.conversation_list[wxid][self.chat_type[wxid]][1]


if __name__ == "__main__":
    from configuration import Config

    config = Config().CHATGLM
    if not config:
        exit(0)

    chat = ChatGLM(config)

    while True:
        q = input(">>> ")
        try:
            time_start = datetime.now()  # record the start time
            logger.info(chat.get_answer(q, "wxid"))
            time_end = datetime.now()  # record the end time

            # The difference is the execution time in seconds
            logger.info(f"{round((time_end - time_start).total_seconds(), 2)}s")
        except Exception as e:
            logger.error("错误: %s", str(e), exc_info=True)
@@ -1,3 +1,4 @@
# ai_providers/ai_chatgpt.py
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

@@ -5,119 +6,145 @@ import logging
import base64
import os
from datetime import datetime
import time  # time module for the timestamp prompt

import httpx
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI

# Import MessageSummary for type hints (when stricter type checking is wanted)
try:
    from function.func_summary import MessageSummary
except ImportError:
    MessageSummary = object  # Fallback if import fails or for simplified typing


class ChatGPT():
    def __init__(self, conf: dict) -> None:
    # ---- __init__ updated ----
    def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
        key = conf.get("key")
        api = conf.get("api")
        proxy = conf.get("proxy")
        prompt = conf.get("prompt")
        self.model = conf.get("model", "gpt-3.5-turbo")
        # ---- New: read the max-history-messages setting ----
        self.max_history_messages = conf.get("max_history_messages", 10)  # default: 10 messages
        # ---- end of new code ----
        self.LOG = logging.getLogger("ChatGPT")

        # ---- Store the injected instance and wxid ----
        self.message_summary = message_summary_instance
        self.bot_wxid = bot_wxid
        if not self.message_summary:
            self.LOG.warning("MessageSummary 实例未提供给 ChatGPT,上下文功能将不可用!")
        if not self.bot_wxid:
            self.LOG.warning("bot_wxid 未提供给 ChatGPT,可能无法正确识别机器人自身消息!")
        # ---- end of store ----

        if proxy:
            self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
        else:
            self.client = OpenAI(api_key=key, base_url=api)
        self.conversation_list = {}
        self.system_content_msg = {"role": "system", "content": prompt}
        # Check whether a vision-capable model is in use

        # ---- Removed: self.conversation_list ----
        # self.conversation_list = {}
        # ---- end of removal ----

        self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."}  # provide a default
        self.support_vision = self.model == "gpt-4-vision-preview" or self.model == "gpt-4o" or "-vision" in self.model
        # ---- end of __init__ updates ----

    def __repr__(self):
        return 'ChatGPT'

    @staticmethod
    def value_check(conf: dict) -> bool:
        # prompt is no longer checked; a default prompt is optional
        if conf:
            if conf.get("key") and conf.get("api") and conf.get("prompt"):
            # ---- Changed: could also check max_history_messages (it has a default) ----
            if conf.get("key") and conf.get("api"):  # and conf.get("max_history_messages") is not None:  # if the setting must be enforced
                return True
        return False

    # ---- get_answer updated ----
    def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
        # wxid or roomid: the personal WeChat id for private chats, the room id for group messages

        # Check whether this is a new conversation
        is_new_conversation = wxid not in self.conversation_list

        # Track the temporary system-prompt state
        temp_system_used = False
        original_prompt = None

        if system_prompt_override:
            # Only new conversations temporarily swap the system prompt
            if is_new_conversation:
                # Save the original system prompt so it can be restored
                original_prompt = self.system_content_msg["content"]
                # Apply the temporary system prompt
                self.system_content_msg["content"] = system_prompt_override
                temp_system_used = True
                self.LOG.debug(f"为新对话 {wxid} 临时设置系统提示")
            else:
                # For existing conversations, apply the override only for this API call, without touching history
                self.LOG.debug(f"对话 {wxid} 已存在,系统提示覆盖将仅用于本次API调用")

        # Append the user question to the conversation history
        self.updateMessage(wxid, question, "user")

        # Restore the system prompt if it was swapped
        if temp_system_used and original_prompt is not None:
            self.system_content_msg["content"] = original_prompt
            self.LOG.debug(f"已恢复默认系统提示")

        # ---- Removed: the #清除对话 (clear conversation) logic ----
        # ... (code removed) ...
        # ---- end of removal ----

        # ---- Fetch and format history from the database ----
        api_messages = []

        # 1. Add the system prompt
        effective_system_prompt = system_prompt_override if system_prompt_override else self.system_content_msg["content"]
        if effective_system_prompt:  # only add when there is content
            api_messages.append({"role": "system", "content": effective_system_prompt})

        # Add the current-time hint (optional, kept from the original code)
        now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        time_mk = "Current time is: "  # or any other suitable prefix
        api_messages.append({"role": "system", "content": f"{time_mk}{now_time}"})

        # 2. Fetch and format the history messages
        if self.message_summary and self.bot_wxid:
            history = self.message_summary.get_messages(wxid)

            # ---- New: cap the number of history messages ----
            if self.max_history_messages is not None and self.max_history_messages > 0:
                history = history[-self.max_history_messages:]  # keep the newest N
            elif self.max_history_messages == 0:  # 0 means no history at all
                history = []
            # ---- end of new code ----

            for msg in history:
                role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
                formatted_content = msg.get('content', '')
                if formatted_content:  # skip empty content
                    api_messages.append({"role": role, "content": formatted_content})
        else:
            self.LOG.warning(f"无法为 wxid={wxid} 获取历史记录,因为 message_summary 或 bot_wxid 未设置。")

        # 3. Add the current user question
        if question:  # make sure the question is non-empty
            api_messages.append({"role": "user", "content": question})
        # ---- end of fetch and format ----

        rsp = ""
        try:
            # Build the message list for the API call
            api_messages = []

            # For existing conversations, temporarily apply the override (if any)
            if not is_new_conversation and system_prompt_override:
                # The first message may be the system prompt
                has_system = self.conversation_list[wxid][0]["role"] == "system"

                # Substitute the temporary system prompt for the original one
                if has_system:
                    # Copy every message except the system prompt
                    api_messages = [{"role": "system", "content": system_prompt_override}]
                    api_messages.extend(self.conversation_list[wxid][1:])
                else:
                    # No system prompt present; add one
                    api_messages = [{"role": "system", "content": system_prompt_override}]
                    api_messages.extend(self.conversation_list[wxid])
            else:
                # New conversation or no override: use the original history
                api_messages = self.conversation_list[wxid]

            # o-series models do not support a custom temperature; only the default of 1
            # ---- Use the formatted api_messages ----
            params = {
                "model": self.model,
                "messages": api_messages
                "messages": api_messages  # the message list built from the database
            }

            # Only non-o-series models get an explicit temperature
            if not self.model.startswith("o"):
                params["temperature"] = 0.2

            ret = self.client.chat.completions.create(**params)
            rsp = ret.choices[0].message.content
            rsp = rsp[2:] if rsp.startswith("\n\n") else rsp
            rsp = rsp.replace("\n\n", "\n")
            self.updateMessage(wxid, rsp, "assistant")

            # ---- Removed: the updateMessage call ----
            # ... (code removed) ...
            # ---- end of removal ----

        except AuthenticationError:
            self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
            rsp = "API认证失败"
        except APIConnectionError:
            self.LOG.error("无法连接到 OpenAI API,请检查网络连接")
            rsp = "网络连接错误"
        except APIError as e1:
            self.LOG.error(f"OpenAI API 返回了错误:{str(e1)}")
            rsp = "无法从 ChatGPT 获得答案"
            rsp = f"API错误: {str(e1)}"
        except Exception as e0:
            self.LOG.error(f"发生未知错误:{str(e0)}")
            rsp = "无法从 ChatGPT 获得答案"
            rsp = "发生未知错误"

        return rsp
    # ---- end of get_answer updates ----

    def encode_image_to_base64(self, image_path: str) -> str:
        """Convert an image file to a Base64-encoded string
@@ -148,21 +175,21 @@ class ChatGPT():
        if not self.support_vision:
            self.LOG.error(f"当前模型 {self.model} 不支持图片理解,请使用gpt-4-vision-preview或gpt-4o")
            return "当前模型不支持图片理解功能,请联系管理员配置支持视觉的模型(如gpt-4-vision-preview或gpt-4o)"

        if not os.path.exists(image_path):
            self.LOG.error(f"图片文件不存在: {image_path}")
            return "无法读取图片文件"

        try:
            base64_image = self.encode_image_to_base64(image_path)
            if not base64_image:
                return "图片编码失败"

            # Build the image-bearing message
            # Build the image-bearing message (history is not used here)
            messages = [
                {"role": "system", "content": "你是一个图片分析专家,擅长分析图片内容并提供详细描述。"},
                {"role": "system", "content": "你是一个图片分析专家,擅长分析图片内容并提供详细描述。"},  # self.system_content_msg could be used here if applicable
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
@@ -174,25 +201,23 @@ class ChatGPT():
                    ]
                }
            ]

            # Use the GPT-4 Vision model
            params = {
                "model": self.model,
                "messages": messages,
                "max_tokens": 1000  # cap the output length
                "max_tokens": 1000
            }

            # Vision-capable models may have different parameter requirements
            if not self.model.startswith("o"):
                params["temperature"] = 0.7

            response = self.client.chat.completions.create(**params)
            description = response.choices[0].message.content
            description = description[2:] if description.startswith("\n\n") else description
            description = description.replace("\n\n", "\n")

            return description

        except AuthenticationError:
            self.LOG.error("OpenAI API 认证失败,请检查 API 密钥是否正确")
            return "API认证失败,无法分析图片"
@@ -206,63 +231,13 @@ class ChatGPT():
            self.LOG.error(f"分析图片时发生未知错误:{str(e0)}")
            return f"处理图片时出错:{str(e0)}"

    def updateMessage(self, wxid: str, content: str, role: str) -> None:
        now_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        time_mk = "当需要回答时间时请直接参考回复:"
        # Initialize the chat history and attach the system messages
        if wxid not in self.conversation_list.keys():
            # self.system_content_msg may have been temporarily modified by get_answer,
            # which is fine, because it is restored before get_answer returns
            question_ = [
                self.system_content_msg,
                {"role": "system", "content": "" + time_mk + now_time}
            ]
            self.conversation_list[wxid] = question_

        # Current question or answer
        content_message = {"role": role, "content": content}
        self.conversation_list[wxid].append(content_message)

        # Refresh the time marker
        for cont in self.conversation_list[wxid]:
            if cont["role"] != "system":
                continue
            if cont["content"].startswith(time_mk):
                cont["content"] = time_mk + now_time

        # Bound the conversation-history length:
        # keep at most 10 records and roll off the oldest beyond that
        max_history = 12  # including 1 system prompt and 1 time marker
        i = len(self.conversation_list[wxid])
        if i > max_history:
            # Work out how many records to drop
            if self.conversation_list[wxid][0]["role"] == "system" and self.conversation_list[wxid][1]["role"] == "system":
                # If the first two are system messages, keep them and drop the oldest user/assistant messages
                to_delete = i - max_history
                del self.conversation_list[wxid][2:2+to_delete]
                self.LOG.debug(f"滚动清除微信记录:{wxid},删除了{to_delete}条历史消息")
            else:
                # If the structure is unexpected, simply keep the most recent messages
                self.conversation_list[wxid] = self.conversation_list[wxid][-max_history:]
                self.LOG.debug(f"滚动清除微信记录:{wxid},只保留最近{max_history}条消息")
    # ---- Removed: updateMessage ----
    # ... (code removed) ...
    # ---- end of removal ----


if __name__ == "__main__":
    from configuration import Config
    config = Config().CHATGPT
    if not config:
        exit(0)

    chat = ChatGPT(config)

    while True:
        q = input(">>> ")
        try:
            time_start = datetime.now()  # record the start time
            print(chat.get_answer(q, "wxid"))
            time_end = datetime.now()  # record the end time

            print(f"{round((time_end - time_start).total_seconds(), 2)}s")  # execution time in seconds
        except Exception as e:
            print(e)
    # --- The test code needs adjusting ---
    # Testing requires mocking MessageSummary and supplying bot_wxid
    print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
    pass  # avoid errors when run directly
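The new context pipeline in both the ChatGPT and DeepSeek providers replaces the in-memory conversation list with history pulled from MessageSummary and capped by max_history_messages. Below is a standalone sketch of that formatting step with a stubbed history list; the sender_wxid and content field names follow the diff above, while the ids and messages are made up:

```python
# Illustrative sketch of the history-capping logic from the diff above.
BOT_WXID = "bot_wxid_placeholder"   # assumed id of the bot itself
max_history_messages = 3            # the new config value

# Stub of what MessageSummary.get_messages(wxid) is assumed to return.
history = [
    {"sender_wxid": "user_1", "content": "hello"},
    {"sender_wxid": BOT_WXID, "content": "hi, how can I help?"},
    {"sender_wxid": "user_1", "content": "what's the weather?"},
    {"sender_wxid": BOT_WXID, "content": "sunny"},
]

if max_history_messages > 0:
    history = history[-max_history_messages:]   # keep only the newest N
elif max_history_messages == 0:                 # 0 disables history entirely
    history = []

api_messages = [{"role": "system", "content": "You are a helpful assistant."}]
for msg in history:
    # Messages sent by the bot become "assistant"; everything else is "user".
    role = "assistant" if msg.get("sender_wxid") == BOT_WXID else "user"
    if msg.get("content"):
        api_messages.append({"role": role, "content": msg["content"]})

api_messages.append({"role": "user", "content": "and tomorrow?"})
print(api_messages)  # the list handed to chat.completions.create(...)
```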
@@ -1,164 +1,139 @@
# ai_providers/ai_deepseek.py
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
from datetime import datetime
import time  # time module for the timestamp prompt

import httpx
from openai import APIConnectionError, APIError, AuthenticationError, OpenAI

# Import MessageSummary for type hints
try:
    from function.func_summary import MessageSummary
except ImportError:
    MessageSummary = object


class DeepSeek():
    def __init__(self, conf: dict) -> None:
    # ---- __init__ updated ----
    def __init__(self, conf: dict, message_summary_instance: MessageSummary = None, bot_wxid: str = None) -> None:
        key = conf.get("key")
        api = conf.get("api", "https://api.deepseek.com")
        proxy = conf.get("proxy")
        prompt = conf.get("prompt")
        self.model = conf.get("model", "deepseek-chat")
        # ---- New: read the max-history-messages setting ----
        self.max_history_messages = conf.get("max_history_messages", 10)  # default: 10 messages
        # ---- end of new code ----
        self.LOG = logging.getLogger("DeepSeek")

        self.reasoning_supported = (self.model == "deepseek-reasoner")

        if conf.get("enable_reasoning", False) and not self.reasoning_supported:
            self.LOG.warning("思维链功能只在使用 deepseek-reasoner 模型时可用,当前模型不支持此功能")

        self.enable_reasoning = conf.get("enable_reasoning", False) and self.reasoning_supported
        self.show_reasoning = conf.get("show_reasoning", False) and self.enable_reasoning

        # ---- Store the injected instance and wxid ----
        self.message_summary = message_summary_instance
        self.bot_wxid = bot_wxid
        if not self.message_summary:
            self.LOG.warning("MessageSummary 实例未提供给 DeepSeek,上下文功能将不可用!")
        if not self.bot_wxid:
            self.LOG.warning("bot_wxid 未提供给 DeepSeek,可能无法正确识别机器人自身消息!")
        # ---- end of store ----

        # ---- Removed: chain-of-thought related logic ----
        # ... (code removed) ...
        # ---- end of removal ----

        if proxy:
            self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
        else:
            self.client = OpenAI(api_key=key, base_url=api)

        self.conversation_list = {}

        self.system_content_msg = {"role": "system", "content": prompt}

        # ---- Removed: self.conversation_list ----
        # ... (code removed) ...
        # ---- end of removal ----

        self.system_content_msg = {"role": "system", "content": prompt if prompt else "You are a helpful assistant."}  # provide a default
        # ---- end of __init__ updates ----

    def __repr__(self):
        return 'DeepSeek'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            if conf.get("key") and conf.get("prompt"):
            # ---- Changed: could also check max_history_messages (it has a default) ----
            if conf.get("key"):  # and conf.get("max_history_messages") is not None:  # if the setting must be enforced
                return True
        return False

    # ---- get_answer updated ----
    def get_answer(self, question: str, wxid: str, system_prompt_override=None) -> str:
        if question == "#清除对话":
            if wxid in self.conversation_list.keys():
                del self.conversation_list[wxid]
            return "已清除上下文"

        if question.lower() in ["#开启思维链", "#enable reasoning"]:
            if not self.reasoning_supported:
                return "当前模型不支持思维链功能,请使用 deepseek-reasoner 模型"
            self.enable_reasoning = True
            self.show_reasoning = True
            return "已开启思维链模式,将显示完整的推理过程"

        if question.lower() in ["#关闭思维链", "#disable reasoning"]:
            if not self.reasoning_supported:
                return "当前模型不支持思维链功能,无需关闭"
            self.enable_reasoning = False
            self.show_reasoning = False
            return "已关闭思维链模式"

        if question.lower() in ["#隐藏思维链", "#hide reasoning"]:
            if not self.enable_reasoning:
                return "思维链功能未开启,无法设置隐藏/显示"
            self.show_reasoning = False
            return "已设置隐藏思维链,但模型仍会进行深度思考"

        if question.lower() in ["#显示思维链", "#show reasoning"]:
            if not self.enable_reasoning:
                return "思维链功能未开启,无法设置隐藏/显示"
            self.show_reasoning = True
            return "已设置显示思维链"

        # Initialize the conversation history (the system prompt is only added on first contact)
        if wxid not in self.conversation_list:
            self.conversation_list[wxid] = []
            # Only here is the default system prompt added to the history
            if self.system_content_msg["content"]:
                self.conversation_list[wxid].append(self.system_content_msg)

        # Append the user question to the conversation history
        self.conversation_list[wxid].append({"role": "user", "content": question})
        # ---- Removed: the #清除对话 and chain-of-thought commands ----
        # ... (code removed) ...
        # ---- end of removal ----

        # ---- Fetch and format history from the database ----
        api_messages = []

        # 1. Add the system prompt
        effective_system_prompt = system_prompt_override if system_prompt_override else self.system_content_msg["content"]
        if effective_system_prompt:
            api_messages.append({"role": "system", "content": effective_system_prompt})

        # Add the current-time hint (optional)
        now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        time_mk = "Current time is: "
        api_messages.append({"role": "system", "content": f"{time_mk}{now_time}"})

        # 2. Fetch and format the history messages
        if self.message_summary and self.bot_wxid:
            history = self.message_summary.get_messages(wxid)

            # ---- New: cap the number of history messages ----
            if self.max_history_messages is not None and self.max_history_messages > 0:
                history = history[-self.max_history_messages:]  # keep the newest N
            elif self.max_history_messages == 0:  # 0 means no history at all
                history = []
            # ---- end of new code ----

            for msg in history:
                role = "assistant" if msg.get("sender_wxid") == self.bot_wxid else "user"
                formatted_content = msg.get('content', '')
                if formatted_content:
                    api_messages.append({"role": role, "content": formatted_content})
        else:
            self.LOG.warning(f"无法为 wxid={wxid} 获取历史记录,因为 message_summary 或 bot_wxid 未设置。")

        # 3. Add the current user question
        if question:
            api_messages.append({"role": "user", "content": question})
        # ---- end of fetch and format ----

        try:
            # Build the message list for the API call
            api_messages = []

            # Check whether a temporary system prompt is needed
            if system_prompt_override:
                # Use the override for this API call only (the history is untouched)
                api_messages.append({"role": "system", "content": system_prompt_override})
                # Add every history message except system prompts
                for msg in self.conversation_list[wxid]:
                    if msg["role"] != "system":
                        api_messages.append({"role": msg["role"], "content": msg["content"]})
            else:
                # No override: use the full conversation history
                for msg in self.conversation_list[wxid]:
                    api_messages.append({"role": msg["role"], "content": msg["content"]})

            # ---- Use the formatted api_messages ----
            response = self.client.chat.completions.create(
                model=self.model,
                messages=api_messages,
                messages=api_messages,  # use the built message list
                stream=False
            )

            if self.reasoning_supported and self.enable_reasoning:
                # deepseek-reasoner returns the special fields reasoning_content and content;
                # handle the chain-of-thought response separately
                reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
                content = response.choices[0].message.content
                # ---- Removed: chain-of-thought handling and local history updates ----
                # ... (code removed) ...
                final_response = response.choices[0].message.content
                # ... (code removed) ...
                # ---- end of removal ----

                if self.show_reasoning and reasoning_content:
                    final_response = f"🤔思考过程:\n{reasoning_content}\n\n🎉最终答案:\n{content}"
                    # Better not to drop the emoji: WeChat messages cannot be custom-styled, so
                    # these two act as separators between the reasoning and the final answer! 💡
                else:
                    final_response = content
                self.conversation_list[wxid].append({"role": "assistant", "content": content})
            else:
                final_response = response.choices[0].message.content
                self.conversation_list[wxid].append({"role": "assistant", "content": final_response})

            # Bound the conversation length, keeping the most recent history:
            # the system message (if any) plus the latest 9 rounds (a question and an answer each count as one)
            max_history = 11
            if len(self.conversation_list[wxid]) > max_history:
                has_system = self.conversation_list[wxid][0]["role"] == "system"
                if has_system:
                    self.conversation_list[wxid] = [self.conversation_list[wxid][0]] + self.conversation_list[wxid][-(max_history-1):]
                else:
                    self.conversation_list[wxid] = self.conversation_list[wxid][-max_history:]

            return final_response

        except (APIConnectionError, APIError, AuthenticationError) as e1:
            self.LOG.error(f"DeepSeek API 返回了错误:{str(e1)}")
            return f"DeepSeek API 返回了错误:{str(e1)}"
        except Exception as e0:
            self.LOG.error(f"发生未知错误:{str(e0)}")
            return "抱歉,处理您的请求时出现了错误"
    # ---- end of get_answer updates ----


if __name__ == "__main__":
    from configuration import Config
    config = Config().DEEPSEEK
    if not config:
        exit(0)

    chat = DeepSeek(config)

    while True:
        q = input(">>> ")
        try:
            time_start = datetime.now()
            print(chat.get_answer(q, "wxid"))
            time_end = datetime.now()
            print(f"{round((time_end - time_start).total_seconds(), 2)}s")
        except Exception as e:
            print(e)
    # --- The test code needs adjusting ---
    print("请注意:直接运行此文件进行测试需要模拟 MessageSummary 并提供 bot_wxid。")
    pass
@@ -1,81 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
from datetime import datetime
import re

import ollama


class Ollama():
    def __init__(self, conf: dict) -> None:
        enable = conf.get("enable")
        self.model = conf.get("model")
        self.prompt = conf.get("prompt")

        self.LOG = logging.getLogger("Ollama")
        self.conversation_list = {}

    def __repr__(self):
        return 'Ollama'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            if conf.get("enable") and conf.get("model") and conf.get("prompt"):
                return True
        return False

    def get_answer(self, question: str, wxid: str) -> str:
        try:
            self.conversation_list[wxid]
        except KeyError:
            res = ollama.generate(model=self.model, prompt=self.prompt, keep_alive="30m")
            self.updateMessage(wxid, res["context"], "assistant")
        # wxid or roomid: the personal WeChat id for private chats, the room id for group messages
        rsp = ""
        try:
            res = ollama.generate(model=self.model, prompt=question, context=self.conversation_list[wxid], keep_alive="30m")
            self.updateMessage(wxid, res["context"], "user")
            res_message = res["response"]
            # Strip the <think> tag pair and its contents
            # res_message = res_message.split("</think>")[-1]
            # Strip the leading \n and spaces
            # return res_message[2:]
            return res_message
        except Exception as e0:
            self.LOG.error(f"发生未知错误:{str(e0)}", exc_info=True)

        return rsp

    def updateMessage(self, wxid: str, context: str, role: str) -> None:
        # Store the current context
        self.conversation_list[wxid] = context


if __name__ == "__main__":
    from configuration import Config
    # Logging setup for the test run
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    )

    config = Config().OLLAMA
    if not config:
        exit(0)

    chat = Ollama(config)

    while True:
        q = input(">>> ")
        try:
            time_start = datetime.now()  # record the start time
            logger = logging.getLogger(__name__)
            logger.info(chat.get_answer(q, "wxid"))
            time_end = datetime.now()  # record the end time

            logger.info(f"{round((time_end - time_start).total_seconds(), 2)}s")  # execution time in seconds
        except Exception as e:
            logger.error(f"错误: {e}", exc_info=True)
@@ -1,49 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import logging

import requests
from random import randint


class TigerBot:
    def __init__(self, tbconf=None) -> None:
        self.LOG = logging.getLogger(__file__)
        self.tburl = "https://api.tigerbot.com/bot-service/ai_service/gpt"
        self.tbheaders = {"Authorization": "Bearer " + tbconf["key"]}
        self.tbmodel = tbconf["model"]
        self.fallback = ["滚", "快滚", "赶紧滚"]

    def __repr__(self):
        return 'TigerBot'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            return all(conf.values())
        return False

    def get_answer(self, msg: str, sender: str = None) -> str:
        payload = {
            "text": msg,
            "modelVersion": self.tbmodel
        }
        rsp = ""
        try:
            rsp = requests.post(self.tburl, headers=self.tbheaders, json=payload).json()
            rsp = rsp["data"]["result"][0]
        except Exception as e:
            self.LOG.error(f"{e}: {payload}\n{rsp}")
            idx = randint(0, len(self.fallback) - 1)
            rsp = self.fallback[idx]

        return rsp


if __name__ == "__main__":
    from configuration import Config
    c = Config()
    tbot = TigerBot(c.TIGERBOT)
    rsp = tbot.get_answer("你还活着?")
    print(rsp)
@@ -1,38 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
from sparkdesk_web.core import SparkWeb


class XinghuoWeb:
    def __init__(self, xhconf=None) -> None:
        self._sparkWeb = SparkWeb(
            cookie=xhconf["cookie"],
            fd=xhconf["fd"],
            GtToken=xhconf["GtToken"],
        )
        self._chat = self._sparkWeb.create_continuous_chat()
        # If a system prompt is configured, send it first
        if xhconf["prompt"]:
            self._chat.chat(xhconf["prompt"])

    def __repr__(self):
        return 'XinghuoWeb'

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf:
            return all(conf.values())
        return False

    def get_answer(self, msg: str, sender: str = None) -> str:
        answer = self._chat.chat(msg)
        return answer


if __name__ == "__main__":
    from configuration import Config
    c = Config()
    xinghuo = XinghuoWeb(c.XINGHUO_WEB)
    rsp = xinghuo.get_answer("你还活着?")
    print(rsp)
@@ -1,46 +0,0 @@
from zhipuai import ZhipuAI


class ZhiPu():
    def __init__(self, conf: dict) -> None:
        self.api_key = conf.get("api_key")
        self.model = conf.get("model", "glm-4")  # defaults to the glm-4 model
        self.client = ZhipuAI(api_key=self.api_key)
        self.conversation_list = {}

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf and conf.get("api_key"):
            return True
        return False

    def __repr__(self):
        return 'ZhiPu'

    def get_answer(self, msg: str, wxid: str, **args) -> str:
        self._update_message(wxid, str(msg), "user")
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.conversation_list[wxid]
        )
        resp_msg = response.choices[0].message
        answer = resp_msg.content
        self._update_message(wxid, answer, "assistant")
        return answer

    def _update_message(self, wxid: str, msg: str, role: str) -> None:
        if wxid not in self.conversation_list.keys():
            self.conversation_list[wxid] = []
        content = {"role": role, "content": str(msg)}
        self.conversation_list[wxid].append(content)


if __name__ == "__main__":
    from configuration import Config
    config = Config().ZHIPU
    if not config:
        exit(0)

    zhipu = ZhiPu(config)
    rsp = zhipu.get_answer("你好")
    print(rsp)
@@ -1,45 +0,0 @@
# ChatGLM3 integration guide

1. Uncomment the chatglm section in the config and fill in the fields. This uses [ChatGLM3](https://github.com/THUDM/ChatGLM3); run openai_api.py in the root of the latest ChatGLM3 to obtain the API address:
```yaml
# To use chatglm, uncomment the lines below and fill them in
chatglm:
  key: sk-012345678901234567890123456789012345678901234567  # do your own key validation as needed
  api: http://localhost:8000/v1  # change to your own chatglm address
  proxy:  # inside mainland China you may need a proxy, roughly: http://domain-or-IP:port
  prompt: 你是智能聊天机器人,你叫小薇  # set the persona as needed
  file_path: F:/Pictures/temp  # folder used for generated images and code
```

2. Adjust the settings inside chatglm/tool_registry.py (the comfyUI address), or register your own tools as needed: decorate the function with @register_tool, describe it in a '''docstring''', and annotate every parameter with Annotated[str, 'description', True] (the type, the parameter description, and whether it is required), plus a -> return type annotation. A minimal registration sketch follows the example below.
```python
@register_tool
def get_confyui_image(prompt: Annotated[str, '要生成图片的提示词,注意必须是英文', True]) -> dict:
    '''
    生成图片
    '''
    with open("func_chatglm\\base.json", "r", encoding="utf-8") as f:
        data2 = json.load(f)
    data2['prompt']['3']['inputs']['seed'] = ''.join(
        random.sample('123456789012345678901234567890', 14))
    # Model name
    data2['prompt']['4']['inputs']['ckpt_name'] = 'chilloutmix_NiPrunedFp32Fix.safetensors'
    data2['prompt']['6']['inputs']['text'] = prompt  # positive prompt
    # data2['prompt']['7']['inputs']['text'] = ''  # negative prompt
    cfui = ComfyUIApi(server_address="127.0.0.1:8188")  # change to your own comfyUI address
    images = cfui.get_images(data2['prompt'])
    return {'res': images[0]['image'], 'res_type': 'image', 'filename': images[0]['filename']}
```
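For contrast with the full comfyUI tool above, here is a minimal sketch of the same registration pattern; the echo_text tool is a made-up illustration, while @register_tool and the Annotated convention come from chatglm/tool_registry.py:

```python
from typing import Annotated

from ai_providers.chatglm.tool_registry import register_tool


@register_tool
def echo_text(text: Annotated[str, 'Text to send back unchanged', True]) -> str:
    '''
    Return the input text unchanged
    '''
    return text
```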
3. Using the Code Interpreter also requires installing a Jupyter kernel, named chatglm3 by default:
```
ipython kernel install --name chatglm3 --user
```

To customize the name, set the IPYKERNEL environment variable or edit chatglm/code_kernel.py:
```
IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3')
```

4. After startup, send #帮助 to list the modes and common commands.
@@ -1,13 +0,0 @@
import sys


class UnsupportedPythonVersionError(Exception):
    def __init__(self, error_msg: str):
        super().__init__(error_msg)


python_version_info = sys.version_info
if not sys.version_info >= (3, 9):
    msg = "当前Python版本: " + ".".join(map(str, python_version_info[:3])) + (', 需要python版本 >= 3.9, 前往下载: '
                                                                             'https://www.python.org/downloads/')
    raise UnsupportedPythonVersionError(msg)
@@ -1,88 +0,0 @@
{
  "prompt": {
    "3": {
      "inputs": {
        "seed": 1000573256060686,
        "steps": 20,
        "cfg": 8,
        "sampler_name": "euler",
        "scheduler": "normal",
        "denoise": 1,
        "model": [
          "4",
          0
        ],
        "positive": [
          "6",
          0
        ],
        "negative": [
          "7",
          0
        ],
        "latent_image": [
          "5",
          0
        ]
      },
      "class_type": "KSampler"
    },
    "4": {
      "inputs": {
        "ckpt_name": "(修复)512-inpainting-ema.safetensors"
      },
      "class_type": "CheckpointLoaderSimple"
    },
    "5": {
      "inputs": {
        "width": 512,
        "height": 512,
        "batch_size": 1
      },
      "class_type": "EmptyLatentImage"
    },
    "6": {
      "inputs": {
        "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,dress, ",
        "clip": [
          "4",
          1
        ]
      },
      "class_type": "CLIPTextEncode"
    },
    "7": {
      "inputs": {
        "text": "text, watermark",
        "clip": [
          "4",
          1
        ]
      },
      "class_type": "CLIPTextEncode"
    },
    "8": {
      "inputs": {
        "samples": [
          "3",
          0
        ],
        "vae": [
          "4",
          2
        ]
      },
      "class_type": "VAEDecode"
    },
    "9": {
      "inputs": {
        "filename_prefix": "ComfyUI",
        "images": [
          "8",
          0
        ]
      },
      "class_type": "SaveImage"
    }
  }
}
@@ -1,203 +0,0 @@
import base64
import os
import queue
import re
import logging
from io import BytesIO
from subprocess import PIPE
from typing import Optional, Union

import jupyter_client
from PIL import Image

# Module-level logger
logger = logging.getLogger(__name__)

IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3')


class CodeKernel(object):
    def __init__(self,
                 kernel_name='kernel',
                 kernel_id=None,
                 kernel_config_path="",
                 python_path=None,
                 ipython_path=None,
                 init_file_path="./startup.py",
                 verbose=1):

        self.kernel_name = kernel_name
        self.kernel_id = kernel_id
        self.kernel_config_path = kernel_config_path
        self.python_path = python_path
        self.ipython_path = ipython_path
        self.init_file_path = init_file_path
        self.verbose = verbose

        if python_path is None and ipython_path is None:
            env = None
        else:
            env = {"PATH": self.python_path + ":$PATH",
                   "PYTHONPATH": self.python_path}

        # Initialize the backend kernel
        self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
                                                           connection_file=self.kernel_config_path,
                                                           exec_files=[
                                                               self.init_file_path],
                                                           env=env)
        if self.kernel_config_path:
            self.kernel_manager.load_connection_file()
            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
            logger.info("Backend kernel started with the configuration: %s",
                        self.kernel_config_path)
        else:
            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
            logger.info("Backend kernel started with the configuration: %s",
                        self.kernel_manager.connection_file)

        if verbose:
            logger.debug(self.kernel_manager.get_connection_info())

        # Initialize the code kernel
        self.kernel = self.kernel_manager.blocking_client()
        # self.kernel.load_connection_file()
        self.kernel.start_channels()
        logger.info("Code kernel started.")

    def execute(self, code):
        self.kernel.execute(code)
        try:
            shell_msg = self.kernel.get_shell_msg(timeout=40)
            io_msg_content = self.kernel.get_iopub_msg(timeout=40)['content']
            while True:
                msg_out = io_msg_content
                # Poll the message
                try:
                    io_msg_content = self.kernel.get_iopub_msg(timeout=40)[
                        'content']
                    if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
                        break
                except queue.Empty:
                    break

            return shell_msg, msg_out
        except Exception as e:
            logger.error("执行代码时出错: %s", str(e), exc_info=True)
            return None

    def execute_interactive(self, code, verbose=False):
        shell_msg = self.kernel.execute_interactive(code)
        if shell_msg is queue.Empty:
            if verbose:
                logger.warning("Timeout waiting for shell message.")
        self.check_msg(shell_msg, verbose=verbose)

        return shell_msg

    def inspect(self, code, verbose=False):
        msg_id = self.kernel.inspect(code)
        shell_msg = self.kernel.get_shell_msg(timeout=30)
        if shell_msg is queue.Empty:
            if verbose:
                logger.warning("Timeout waiting for shell message.")
        self.check_msg(shell_msg, verbose=verbose)

        return shell_msg

    def get_error_msg(self, msg, verbose=False) -> Optional[str]:
        if msg['content']['status'] == 'error':
            try:
                error_msg = msg['content']['traceback']
            except BaseException:
                try:
                    error_msg = msg['content']['traceback'][-1].strip()
                except BaseException:
                    error_msg = "Traceback Error"
            if verbose:
                logger.error("Error: %s", error_msg)
            return error_msg
        return None

    def check_msg(self, msg, verbose=False):
        status = msg['content']['status']
        if status == 'ok':
            if verbose:
                logger.info("Execution succeeded.")
        elif status == 'error':
            for line in msg['content']['traceback']:
                if verbose:
                    logger.error(line)

    def shutdown(self):
        # Shutdown the backend kernel
        self.kernel_manager.shutdown_kernel()
        logger.info("Backend kernel shutdown.")
        # Shutdown the code kernel
        self.kernel.shutdown()
        logger.info("Code kernel shutdown.")

    def restart(self):
        # Restart the backend kernel
        self.kernel_manager.restart_kernel()
        # logger.info("Backend kernel restarted.")

    def interrupt(self):
        # Interrupt the backend kernel
        self.kernel_manager.interrupt_kernel()
        # logger.info("Backend kernel interrupted.")

    def is_alive(self):
        return self.kernel.is_alive()


def b64_2_img(data):
    buff = BytesIO(base64.b64decode(data))
    return Image.open(buff)


def clean_ansi_codes(input_string):
    ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
    return ansi_escape.sub('', input_string)


def execute(code, kernel: CodeKernel) -> tuple[str, Union[str, Image.Image]]:
    res = ""
    res_type = None
    code = code.replace("<|observation|>", "")
    code = code.replace("<|assistant|>interpreter", "")
    code = code.replace("<|assistant|>", "")
    code = code.replace("<|user|>", "")
    code = code.replace("<|system|>", "")
    msg, output = kernel.execute(code)

    if msg['metadata']['status'] == "timeout":
        return res_type, 'Timed out'
    elif msg['metadata']['status'] == 'error':
        return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))

    if 'text' in output:
        res_type = "text"
        res = output['text']
    elif 'data' in output:
        for key in output['data']:
            if 'image/png' in key:
                res_type = "image"
                res = output['data'][key]
                break
            elif 'text/plain' in key:
                res_type = "text"
                res = output['data'][key]

    if res_type == "image":
        return res_type, b64_2_img(res)
    elif res_type == "text" or res_type == "traceback":
        res = res

    return res_type, res


def extract_code(text: str) -> str:
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[-1][1]
@@ -1,186 +0,0 @@
# This is an example that uses the websockets api to know when a prompt execution is done
# Once the prompt execution is done it downloads the images using the /history endpoint

import io
import json
import random
import urllib.parse
import uuid

import requests
# NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import websocket
from PIL import Image


class ComfyUIApi():
    def __init__(self, server_address="127.0.0.1:8188"):
        self.server_address = server_address
        self.client_id = str(uuid.uuid4())
        self.ws = websocket.WebSocket()
        self.ws.connect(
            "ws://{}/ws?clientId={}".format(server_address, self.client_id))

    def queue_prompt(self, prompt):
        p = {"prompt": prompt, "client_id": self.client_id}
        data = json.dumps(p).encode('utf-8')
        req = requests.post(
            "http://{}/prompt".format(self.server_address), data=data)
        print(req.text)
        return json.loads(req.text)

    def get_image(self, filename, subfolder, folder_type):
        data = {"filename": filename,
                "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with requests.get("http://{}/view?{}".format(self.server_address, url_values)) as response:
            image = Image.open(io.BytesIO(response.content))
            return image

    def get_image_url(self, filename, subfolder, folder_type):
        data = {"filename": filename,
                "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        return "http://{}/view?{}".format(self.server_address, url_values)

    def get_history(self, prompt_id):
        with requests.get("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.text)

    def get_images(self, prompt, isUrl=False):
        prompt_id = self.queue_prompt(prompt)['prompt_id']
        output_images = []
        while True:
            out = self.ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data['node'] is None and data['prompt_id'] == prompt_id:
                        break  # Execution is done
            else:
                continue  # previews are binary data

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
            node_output = history['outputs'][node_id]
            if 'images' in node_output:
                for image in node_output['images']:
                    image_data = self.get_image_url(image['filename'], image['subfolder'], image['type']) if isUrl else self.get_image(
                        image['filename'], image['subfolder'], image['type'])
                    image['image'] = image_data
                    output_images.append(image)

        return output_images


prompt_text = """
{
  "3": {
    "class_type": "KSampler",
    "inputs": {
      "cfg": 8,
      "denoise": 1,
      "latent_image": [
        "5",
        0
      ],
      "model": [
        "4",
        0
      ],
      "negative": [
        "7",
        0
      ],
      "positive": [
        "6",
        0
      ],
      "sampler_name": "euler",
      "scheduler": "normal",
      "seed": 8566257,
      "steps": 20
    }
  },
  "4": {
    "class_type": "CheckpointLoaderSimple",
    "inputs": {
      "ckpt_name": "chilloutmix_NiPrunedFp32Fix.safetensors"
    }
  },
  "5": {
    "class_type": "EmptyLatentImage",
    "inputs": {
      "batch_size": 1,
      "height": 512,
      "width": 512
    }
  },
  "6": {
    "class_type": "CLIPTextEncode",
    "inputs": {
      "clip": [
        "4",
        1
      ],
      "text": "masterpiece best quality girl"
    }
  },
  "7": {
    "class_type": "CLIPTextEncode",
    "inputs": {
      "clip": [
        "4",
        1
      ],
      "text": "bad hands"
    }
  },
  "8": {
    "class_type": "VAEDecode",
    "inputs": {
      "samples": [
        "3",
        0
      ],
      "vae": [
        "4",
        2
      ]
    }
  },
  "9": {
    "class_type": "SaveImage",
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": [
        "8",
        0
      ]
    }
  }
}
"""
if __name__ == '__main__':
    prompt = json.loads(prompt_text)
    # set the text prompt for our positive CLIPTextEncode
    prompt["6"]["inputs"]["text"] = "masterpiece best quality man"

    # set the seed for our KSampler node
    prompt["3"]["inputs"]["seed"] = ''.join(
        random.sample('123456789012345678901234567890', 14))

    cfui = ComfyUIApi()
    images = cfui.get_images(prompt)

    # Commented out code to display the output images:
    # for node_id in images:
    #     for image_data in images[node_id]:
    #         import io
    #         from PIL import Image
    #         image = Image.open(io.BytesIO(image_data))
    #         image.show()
@@ -1,167 +0,0 @@
|
||||
import inspect
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from types import GenericAlias
|
||||
from typing import Annotated, get_origin

from ai_providers.chatglm.comfyUI_api import ComfyUIApi
from function.func_news import News
from zhdate import ZhDate

_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}


def extract_code(text: str) -> str:
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[-1][1]


def register_tool(func: callable):
    tool_name = func.__name__
    tool_description = inspect.getdoc(func).strip()
    python_params = inspect.signature(func).parameters
    tool_params = []
    for name, param in python_params.items():
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter `{name}` missing type annotation")
        if get_origin(annotation) != Annotated:
            raise TypeError(
                f"Annotation type for `{name}` must be typing.Annotated")

        typ, (description, required) = annotation.__origin__, annotation.__metadata__
        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
        if not isinstance(description, str):
            raise TypeError(f"Description for `{name}` must be a string")
        if not isinstance(required, bool):
            raise TypeError(f"Required for `{name}` must be a bool")

        tool_params.append({
            "name": name,
            "description": description,
            "type": typ,
            "required": required
        })
    tool_def = {
        "name": tool_name,
        "description": tool_description,
        "parameters": tool_params
    }

    # print("[registered tool] " + pformat(tool_def))
    _TOOL_HOOKS[tool_name] = func
    _TOOL_DESCRIPTIONS[tool_name] = tool_def

    return func


def dispatch_tool(tool_name: str, tool_params: dict) -> str:
    if tool_name not in _TOOL_HOOKS:
        return f"Tool `{tool_name}` not found. Please use a provided tool."
    tool_call = _TOOL_HOOKS[tool_name]
    try:
        ret = tool_call(**tool_params)
    except BaseException:
        ret = traceback.format_exc()
    return ret


def get_tools() -> dict:
    return deepcopy(_TOOL_DESCRIPTIONS)


# Tool Definitions

# @register_tool
# def random_number_generator(
#         seed: Annotated[int, 'The random seed used by the generator', True],
#         range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
# ) -> int:
#     """
#     Generates a random number x, s.t. range[0] <= x < range[1]
#     """
#     if not isinstance(seed, int):
#         raise TypeError("Seed must be an integer")
#     if not isinstance(range, tuple):
#         raise TypeError("Range must be a tuple")
#     if not isinstance(range[0], int) or not isinstance(range[1], int):
#         raise TypeError("Range must be a tuple of integers")
#
#     import random
#     return random.Random(seed).randint(*range)


@register_tool
def get_weather(
        city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the current weather for `city_name`
    """
    if not isinstance(city_name, str):
        raise TypeError("City name must be a string")

    key_selection = {
        "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
    }
    import requests
    try:
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
        resp.raise_for_status()
        resp = resp.json()
        ret = {k: {_v: resp[k][0][_v] for _v in v}
               for k, v in key_selection.items()}
    except BaseException:
        import traceback
        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()

    return str(ret)


@register_tool
def get_confyui_image(prompt: Annotated[str, 'The prompt describing the image to generate; must be in English', True]) -> dict:
    '''
    Generate an image
    '''
    with open("ai_providers/chatglm/base.json", "r", encoding="utf-8") as f:
        data2 = json.load(f)
    data2['prompt']['3']['inputs']['seed'] = ''.join(
        random.sample('123456789012345678901234567890', 14))
    # Model (checkpoint) name
    data2['prompt']['4']['inputs']['ckpt_name'] = 'chilloutmix_NiPrunedFp32Fix.safetensors'
    data2['prompt']['6']['inputs']['text'] = prompt  # positive prompt
    # data2['prompt']['7']['inputs']['text'] = ''  # negative prompt
    cfui = ComfyUIApi(server_address="127.0.0.1:8188")  # change to your own ComfyUI address
    images = cfui.get_images(data2['prompt'])
    return {'res': images[0]['image'], 'res_type': 'image', 'filename': images[0]['filename']}


@register_tool
def get_news() -> str:
    '''
    Get the latest news
    '''
    news = News()
    return news.get_important_news()


@register_tool
def get_time() -> str:
    '''
    Get the current date, time, lunar (Chinese calendar) date, and day of the week
    '''
    time = datetime.now()
    date2 = ZhDate.from_datetime(time)
    week_list = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]

    return '{} {} {}'.format(time.strftime("%Y年%m月%d日 %H:%M:%S"), week_list[time.weekday()], '农历:' + date2.chinese())


if __name__ == "__main__":
    print(dispatch_tool("get_weather", {"city_name": "beijing"}))
    print(get_tools())
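For orientation, here is how the registry above is meant to be driven end to end. The `echo` tool is a hypothetical example written for this illustration, not a tool that ships with the repository:

```python
from typing import Annotated

# Hypothetical tool: each parameter carries (description, required) metadata,
# which register_tool turns into the schema that is offered to the model.
@register_tool
def echo(
        text: Annotated[str, 'The text to repeat back', True],
) -> str:
    """
    Repeat the given text back to the caller
    """
    return text

# get_tools() exposes the registered schemas; when the model answers with a
# tool call, dispatch_tool() routes it to the matching Python function.
print(get_tools()["echo"]["parameters"])
print(dispatch_tool("echo", {"text": "hello"}))  # -> hello
```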
@@ -9,8 +9,6 @@ from function.func_duel import DuelRankSystem

# Import AI models
from ai_providers.ai_deepseek import DeepSeek
from ai_providers.ai_chatgpt import ChatGPT
from ai_providers.ai_chatglm import ChatGLM
from ai_providers.ai_ollama import Ollama

# Forward references to avoid circular imports
from typing import TYPE_CHECKING

@@ -129,29 +127,6 @@ def handle_reset_memory(ctx: 'MessageContext', match: Optional[Match]) -> bool:
            result = "✅ 已重置ChatGPT对话记忆,保留系统提示,开始新的对话"
        else:
            result = f"⚠️ {model_name} 对话记忆已为空,无需重置"

    elif isinstance(chat_model, ChatGLM):
        # ChatGLM model
        if hasattr(chat_model, 'chat_type') and chat_id in chat_model.chat_type:
            chat_type = chat_model.chat_type[chat_id]
            # Keep the system prompt, drop the conversation history
            if chat_type in chat_model.conversation_list[chat_id]:
                chat_model.conversation_list[chat_id][chat_type] = []
                if ctx.logger: ctx.logger.info(f"已重置ChatGLM对话记忆: {chat_id}")
                result = "✅ 已重置ChatGLM对话记忆,开始新的对话"
            else:
                result = f"⚠️ 未找到与 {model_name} 的对话记忆,无需重置"
        else:
            result = f"⚠️ 未找到与 {model_name} 的对话记忆,无需重置"

    elif isinstance(chat_model, Ollama):
        # Ollama model
        if chat_id in chat_model.conversation_list:
            chat_model.conversation_list[chat_id] = []
            if ctx.logger: ctx.logger.info(f"已重置Ollama对话记忆: {chat_id}")
            result = "✅ 已重置Ollama对话记忆,开始新的对话"
        else:
            result = f"⚠️ 未找到与 {model_name} 的对话记忆,无需重置"

    else:
        # Generic fallback: delete the conversation record directly
@@ -54,28 +54,23 @@ groups:
  models:
    # Model ID reference:
    # 0: automatically select the first available model
    # 1: TigerBot
    # 2: ChatGPT
    # 3: XunFei XingHuo (iFLYTEK Spark)
    # 4: ChatGLM
    # 5: BardAssistant/Gemini
    # 6: ZhiPu
    # 7: Ollama
    # 8: DeepSeek
    # 9: Perplexity
    # 1: ChatGPT
    # 2: DeepSeek
    default: 0  # default model ID (0 = automatically select the first available model)
    # Group chat mapping
    mapping:
      - room_id: example12345@chatroom
        model: 2  # corresponds to ChatType.CHATGPT
        model: 2
      - room_id: example12345@chatroom
        model: 7  # corresponds to ChatType.OLLAMA
        model: 7
    # Private chat mapping
    private_mapping:
      - wxid: filehelper
        model: 2  # corresponds to ChatType.CHATGPT
        model: 2
      - wxid: wxid_example12345
        model: 8  # corresponds to ChatType.DEEPSEEK
        model: 8

MAX_HISTORY: 300  # number of messages kept in the history database

news:
  receivers: ["filehelper"]  # recipients of the scheduled news digest (roomid or wxid)
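As a rough sketch of how a mapping like this can be resolved at runtime (the helper below is illustrative only and assumes the `models` block has been parsed into a plain dict; it is not the project's actual loader):

```python
def resolve_model_id(models_cfg: dict, room_id: str = None, wxid: str = None) -> int:
    """Pick a model ID for a chat: group mapping first, then private mapping,
    falling back to the configured default (0 = first available model)."""
    if room_id:
        for item in models_cfg.get("mapping", []):
            if item.get("room_id") == room_id:
                return item["model"]
    if wxid:
        for item in models_cfg.get("private_mapping", []):
            if item.get("wxid") == wxid:
                return item["model"]
    return models_cfg.get("default", 0)

# e.g. resolve_model_id(cfg, room_id="example12345@chatroom") -> 2
```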
@@ -96,41 +91,7 @@ chatgpt: # -----chatgpt config; do not put a value on this line-----
  model: gpt-3.5-turbo  # options: gpt-3.5-turbo, gpt-4, gpt-4-turbo, gpt-4.1-mini, o4-mini
  proxy:  # if you are in mainland China you may need a proxy, roughly: http://host-or-ip:port
  prompt: 你是智能聊天机器人,你叫 wcferry  # adjust the persona as needed

chatglm: # -----chatglm config; do not put a value on this line-----
  key:  # this normally does not need to be changed
  api: http://localhost:8000/v1  # change to your own chatglm address
  proxy:  # if you are in mainland China you may need a proxy, roughly: http://host-or-ip:port
  prompt: 你是智能聊天机器人,你叫小薇  # adjust the persona as needed
  file_path: F:/Pictures/temp  # folder used for generated images and code

ollama: # -----ollama config; do not put a value on this line-----
  enable: true  # whether to enable ollama
  model: deepseek-r1:1.5b  # ollama-7b-sft
  prompt: 你是智能聊天机器人,你叫 梅好事  # adjust the persona as needed
  file_path: d:/pictures/temp  # folder used for generated images and code

tigerbot: # -----tigerbot config; do not put a value on this line-----
  key:  # key
  model:  # tigerbot-7b-sft

xinghuo_web: # -----XunFei XingHuo web-mode API config; do not put a value on this line; see https://www.bilibili.com/read/cv27066577 for how to capture these values-----
  cookie:  # cookie
  fd:  # fd
  GtToken:  # GtToken
  prompt: 你是智能聊天机器人,你叫 wcferry。请用这个角色回答我的问题  # adjust the persona as needed

bard: # -----bard config; do not put a value on this line-----
  api_key:  # api-key; create one at https://ai.google.dev/pricing?hl=en and paste it here
  model_name: gemini-pro  # switch models as new ones come online
  proxy: http://127.0.0.1:7890  # if you are in mainland China you may need a proxy, roughly: http://host-or-ip:port
  # Write prompts in English where possible; Bard handles Chinese prompts poorly. The commented example below sets up an English teacher; adjust to your needs. The default prompt describes an AI large language model created by Google.
  # I want you to act as a spoken English teacher and improver. I will speak to you in English and you will reply to me in English to practice my spoken English. I want you to keep your reply neat, limiting the reply to 100 words. I want you to strictly correct my grammar mistakes, typos, and factual errors. I want you to ask me a question in your reply. Now let's start practicing, you could ask me a question first. Remember, I want you to strictly correct my grammar mistakes, typos, and factual errors.
  prompt: You are a large language model, trained by Google.

zhipu: # -----zhipu config; do not put a value on this line-----
  api_key:  # api key
  model:  # model type
  max_history_messages: 20  # <--- add this line; ChatGPT will look back at most 20 history messages

deepseek: # -----deepseek config; do not put a value on this line-----
  # chain-of-thought features are off by default; enabling them increases response time and token usage
@@ -140,17 +101,7 @@ deepseek: # -----deepseek config; do not put a value on this line-----
  prompt: 你是智能聊天机器人,你叫 DeepSeek 助手  # adjust the persona as needed
  enable_reasoning: false  # whether to enable chain-of-thought; only effective with the deepseek-reasoner model
  show_reasoning: false  # whether to show the reasoning process in replies; only effective when chain-of-thought is enabled

cogview: # -----ZhiPu AI image generation config; do not put a value on this line-----
  # see https://www.bigmodel.cn/dev/api/image-model/cogview for this API
  enable: False  # whether to enable image generation; off by default, replace False with true to enable; this model can run alongside other models
  api_key:  # ZhiPu API key; fill in your API Key
  model: cogview-4-250304  # model code; options: cogview-4-250304, cogview-4, cogview-3-flash
  quality: standard  # generation quality; options: standard (fast), hd (high definition)
  size: 1024x1024  # image size; customizable within the allowed values
  trigger_keyword: 牛智谱  # keyword that triggers image generation
  temp_dir:  # temp file directory; leave empty to default to the zhipuimg folder under the project directory; to change it, use e.g. D:/Pictures/temp or /home/user/temp
  fallback_to_chat: true  # when image generation is disabled: true = forward the request to the chat model, false = reply with a fixed "not enabled" notice
  max_history_messages: 10  # <--- add this line; DeepSeek will look back at most 10 history messages

aliyun_image: # -----to use Aliyun text-to-image, uncomment the lines below and fill them in; find Tongyi Wanxiang text-to-image 2.1-Turbo on Aliyun Bailian-----
  enable: true  # whether to enable Aliyun text-to-image; false = off; on by default; if left unconfigured, requests are forwarded to the chat model
@@ -37,15 +37,8 @@ class Config(object):
        self.REPORT_REMINDERS = yconfig["report_reminder"]["receivers"]

        self.CHATGPT = yconfig.get("chatgpt", {})
        self.OLLAMA = yconfig.get("ollama", {})
        self.TIGERBOT = yconfig.get("tigerbot", {})
        self.XINGHUO_WEB = yconfig.get("xinghuo_web", {})
        self.CHATGLM = yconfig.get("chatglm", {})
        self.BardAssistant = yconfig.get("bard", {})
        self.ZhiPu = yconfig.get("zhipu", {})
        self.DEEPSEEK = yconfig.get("deepseek", {})
        self.PERPLEXITY = yconfig.get("perplexity", {})
        self.COGVIEW = yconfig.get("cogview", {})
        self.ALIYUN_IMAGE = yconfig.get("aliyun_image", {})
        self.GEMINI_IMAGE = yconfig.get("gemini_image", {})
        self.SEND_RATE_LIMIT = yconfig.get("send_rate_limit", 0)
constants.py
@@ -4,22 +4,14 @@ from enum import IntEnum, unique
@unique
class ChatType(IntEnum):
    # UnKnown = 0  # unknown, i.e. not set
    TIGER_BOT = 1  # TigerBot
    CHATGPT = 2  # ChatGPT
    XINGHUO_WEB = 3  # XunFei XingHuo
    CHATGLM = 4  # ChatGLM
    BardAssistant = 5  # Google Bard
    ZhiPu = 6  # ZhiPu
    OLLAMA = 7  # Ollama
    DEEPSEEK = 8  # DeepSeek
    PERPLEXITY = 9  # Perplexity
    CHATGPT = 1  # ChatGPT
    DEEPSEEK = 2  # DeepSeek
    PERPLEXITY = 3  # Perplexity

    @staticmethod
    def is_in_chat_types(chat_type: int) -> bool:
        if chat_type in [ChatType.TIGER_BOT.value, ChatType.CHATGPT.value,
                         ChatType.XINGHUO_WEB.value, ChatType.CHATGLM.value,
                         ChatType.BardAssistant.value, ChatType.ZhiPu.value,
                         ChatType.OLLAMA.value, ChatType.DEEPSEEK.value,
        if chat_type in [ChatType.CHATGPT.value,
                         ChatType.DEEPSEEK.value,
                         ChatType.PERPLEXITY.value]:
            return True
        return False
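A quick sanity check of the renumbered enum (illustrative usage):

```python
from constants import ChatType

assert ChatType.CHATGPT.value == 1 and ChatType.DEEPSEEK.value == 2
assert ChatType.is_in_chat_types(3)      # Perplexity is still supported
assert not ChatType.is_in_chat_types(7)  # the old Ollama ID is no longer valid
```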
@@ -8,74 +8,84 @@ from collections import deque
import sqlite3  # for SQLite persistence
import os  # for handling file paths
from function.func_xml_process import XmlProcessor
from commands.registry import COMMANDS  # import the command list
# from commands.registry import COMMANDS  # no longer needed

class MessageSummary:
    """Message summary feature (persisted with SQLite)

    Records, manages, and summarizes chat message history.
    """

    def __init__(self, max_history=300, db_path="data/message_history.db"):

    def __init__(self, max_history=200, db_path="data/message_history.db"):  # default max_history changed to 200
        """Initialize the message summary feature

        Args:
            max_history: maximum number of messages kept per chat
            db_path: path to the SQLite database file
        """
        self.LOG = logging.getLogger("MessageSummary")
        self.max_history = max_history
        self.max_history = max_history  # use the max_history that was passed in
        self.db_path = db_path

        # Instantiate the XML processor for extracting quoted messages
        self.xml_processor = XmlProcessor(self.LOG)

        # The old in-memory storage has been removed
        # self._msg_history = {}  # dict keyed by group ID or user ID
        # self._msg_history_lock = Lock()  # lock for thread safety

        try:
            # Make sure the directory containing the database file exists
            db_dir = os.path.dirname(self.db_path)
            if db_dir and not os.path.exists(db_dir):
                os.makedirs(db_dir)
                self.LOG.info(f"创建数据库目录: {db_dir}")

            # Connect to the database (the file is created automatically if missing)
            # check_same_thread=False allows the connection to be used from different threads;
            # this is necessary in a multi-threaded bot, but transactions must be managed carefully

            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
            self.cursor = self.conn.cursor()
            self.LOG.info(f"已连接到 SQLite 数据库: {self.db_path}")

            # Create the messages table (if it does not exist)
            # INTEGER PRIMARY KEY AUTOINCREMENT aliases rowid, which simplifies management
            # timestamp_float is used for ordering and trimming
            # timestamp_str is used for display

            # ---- Schema migration ----
            # Check for the sender_wxid column and add it if missing
            self.cursor.execute("PRAGMA table_info(messages)")
            columns = [col[1] for col in self.cursor.fetchall()]
            if 'sender_wxid' not in columns:
                try:
                    self.cursor.execute("ALTER TABLE messages ADD COLUMN sender_wxid TEXT")
                    self.conn.commit()
                    self.LOG.info("已向 messages 表添加 sender_wxid 列")
                except sqlite3.OperationalError as e:
                    # If the table is empty, dropping and recreating it may be simpler
                    self.LOG.warning(f"添加 sender_wxid 列失败 ({e}),可能是因为表非空且有主键?尝试重建表。")
                    # Note: this discards existing data!
                    self.cursor.execute("DROP TABLE IF EXISTS messages")
                    self.conn.commit()

            self.cursor.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    chat_id TEXT NOT NULL,
                    sender TEXT NOT NULL,
                    sender_wxid TEXT, -- new: stores the sender's wxid
                    content TEXT NOT NULL,
                    timestamp_float REAL NOT NULL,
                    timestamp_str TEXT NOT NULL
                    timestamp_str TEXT NOT NULL -- stores the full time format YYYY-MM-DD HH:MM:SS
                )
            """)
            # Index chat_id and timestamp_float to speed up queries
            # ---- End of schema migration ----

            self.cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_chat_time ON messages (chat_id, timestamp_float)
            """)
            # New sender_wxid index (optional; useful if queries by wxid are frequent)
            self.cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_sender_wxid ON messages (sender_wxid)
            """)
            self.conn.commit()  # commit the changes
            self.LOG.info("消息表已准备就绪")

        except sqlite3.Error as e:
            self.LOG.error(f"数据库初始化失败: {e}")
            # If the database connection fails, raise (or apply other error handling)
            raise ConnectionError(f"无法连接或初始化数据库: {e}") from e
        except OSError as e:
            self.LOG.error(f"创建数据库目录失败: {e}")
            raise OSError(f"无法创建数据库目录: {e}") from e
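The `check_same_thread=False` caveat above deserves emphasis: a shared sqlite3 connection is not thread-safe on its own, so concurrent writers should be serialized. A minimal hardening pattern, shown as an assumption about how one might do it rather than as code from this repository:

```python
import sqlite3
import threading

conn = sqlite3.connect("data/message_history.db", check_same_thread=False)
db_lock = threading.Lock()

def safe_execute(sql: str, params: tuple = ()):
    # One writer at a time: SQLite serializes at the file level anyway,
    # but the lock keeps each execute/commit pair atomic on the Python side.
    with db_lock:
        cur = conn.execute(sql, params)
        conn.commit()
        return cur
```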
    def close_db(self):
        """Close the database connection"""
        if hasattr(self, 'conn') and self.conn:
@@ -85,34 +95,46 @@ class MessageSummary:
                self.LOG.info("数据库连接已关闭")
            except sqlite3.Error as e:
                self.LOG.error(f"关闭数据库连接时出错: {e}")

    def record_message(self, chat_id, sender_name, content, timestamp=None):

    # ---- record_message updated ----
    def record_message(self, chat_id, sender_name, sender_wxid, content, timestamp=None):
        """Record a single message to the database

        Args:
            chat_id: chat ID (group ID or user ID)
            sender_name: sender's display name
            sender_wxid: sender's wxid
            content: message content
            timestamp: timestamp; defaults to the current time
            timestamp: externally supplied time string (used if given), otherwise generated
        """
        try:
            # Float timestamp used for ordering
            current_time_float = time.time()

            # Generate a time string, or use the one passed in

            # ---- Time format updated ----
            if not timestamp:
                timestamp_str = time.strftime("%H:%M", time.localtime(current_time_float))
                # Default to the full time format
                timestamp_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(current_time_float))
            else:
                timestamp_str = timestamp

            # Insert the new message
                # If the supplied timestamp only has hours and minutes, expand it to the full format
                if len(timestamp) <= 5:  # "HH:MM" format
                    today = time.strftime("%Y-%m-%d", time.localtime(current_time_float))
                    timestamp_str = f"{today} {timestamp}:00"  # pad the seconds
                elif len(timestamp) == 8 and timestamp.count(':') == 2:  # "HH:MM:SS" format
                    today = time.strftime("%Y-%m-%d", time.localtime(current_time_float))
                    timestamp_str = f"{today} {timestamp}"
                elif len(timestamp) == 16 and timestamp.count('-') == 2 and timestamp.count(':') == 1:  # "YYYY-MM-DD HH:MM"
                    timestamp_str = f"{timestamp}:00"  # pad the seconds
                else:
                    timestamp_str = timestamp  # assume it is already the full format

            # Insert the new message, including sender_wxid
            self.cursor.execute("""
                INSERT INTO messages (chat_id, sender, content, timestamp_float, timestamp_str)
                VALUES (?, ?, ?, ?, ?)
            """, (chat_id, sender_name, content, current_time_float, timestamp_str))

                INSERT INTO messages (chat_id, sender, sender_wxid, content, timestamp_float, timestamp_str)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (chat_id, sender_name, sender_wxid, content, current_time_float, timestamp_str))
            # ---- End of time format and insert changes ----

            # Delete old messages beyond max_history
            # A subquery finds the ids of the newest N messages to keep; every other message for this chat_id is then deleted
            self.cursor.execute("""
                DELETE FROM messages
                WHERE chat_id = ? AND id NOT IN (
@@ -123,275 +145,267 @@ class MessageSummary:
                    LIMIT ?
                )
            """, (chat_id, chat_id, self.max_history))

            self.conn.commit()  # commit the transaction

        except sqlite3.Error as e:
            self.LOG.error(f"记录消息到数据库时出错: {e}")
            # Consider rolling back the transaction
            try:
                self.conn.rollback()
            except:
                pass
    # ---- End of record_message changes ----
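To make the retention query concrete, here is a self-contained demonstration of the `DELETE ... WHERE id NOT IN (SELECT ... LIMIT ?)` pattern against an in-memory table (a sketch with a reduced column set; the newest-first ORDER BY in the subquery is an assumption that matches the intent described in the comments):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, chat_id TEXT, ts REAL)")
for i in range(10):
    conn.execute("INSERT INTO messages (chat_id, ts) VALUES (?, ?)", ("room1", float(i)))

MAX_HISTORY = 3
conn.execute("""
    DELETE FROM messages
    WHERE chat_id = ? AND id NOT IN (
        SELECT id FROM messages
        WHERE chat_id = ?
        ORDER BY ts DESC
        LIMIT ?
    )
""", ("room1", "room1", MAX_HISTORY))

# Only the 3 newest rows survive: ts 7.0, 8.0, 9.0
print(conn.execute("SELECT ts FROM messages ORDER BY ts").fetchall())
```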
    def clear_message_history(self, chat_id):
        """Clear the message history of a given chat

        Args:
            chat_id: chat ID (group ID or user ID)

        Returns:
            bool: whether the history was cleared successfully
        """
        try:
            # Delete every message for this chat_id
            self.cursor.execute("DELETE FROM messages WHERE chat_id = ?", (chat_id,))
            rows_deleted = self.cursor.rowcount  # number of rows deleted
            rows_deleted = self.cursor.rowcount
            self.conn.commit()
            self.LOG.info(f"为 chat_id={chat_id} 清除了 {rows_deleted} 条历史消息")
            return True  # deleting 0 rows still counts as a successful operation

            return True

        except sqlite3.Error as e:
            self.LOG.error(f"清除消息历史时出错 (chat_id={chat_id}): {e}")
            return False

    def get_message_count(self, chat_id):
        """Get the number of messages in a given chat

        Args:
            chat_id: chat ID (group ID or user ID)

        Returns:
            int: message count
        """
        try:
            # COUNT query for the number of messages
            self.cursor.execute("SELECT COUNT(*) FROM messages WHERE chat_id = ?", (chat_id,))
            result = self.cursor.fetchone()  # fetchone() returns a tuple, e.g. (5,)
            result = self.cursor.fetchone()
            return result[0] if result else 0

        except sqlite3.Error as e:
            self.LOG.error(f"获取消息数量时出错 (chat_id={chat_id}): {e}")
            return 0

    # ---- get_messages updated ----
    def get_messages(self, chat_id):
        """Get all messages of a given chat (ascending by time)

        """Get all messages of a given chat (ascending by time), including sender wxid and full timestamp

        Args:
            chat_id: chat ID (group ID or user ID)

        Returns:
            list: messages in the form [{"sender": ..., "content": ..., "time": ...}]
            list: messages in the form [{"sender": ..., "sender_wxid": ..., "content": ..., "time": ...}]
        """
        messages = []
        try:
            # Select the needed columns, ordered by float timestamp ascending, limited in count
            # Select the needed columns, including sender_wxid and timestamp_str
            self.cursor.execute("""
                SELECT sender, content, timestamp_str
                SELECT sender, sender_wxid, content, timestamp_str
                FROM messages
                WHERE chat_id = ?
                ORDER BY timestamp_float ASC
                LIMIT ?
            """, (chat_id, self.max_history))

            rows = self.cursor.fetchall()  # fetchall() returns a list of tuples

            rows = self.cursor.fetchall()

            # Convert database rows into the expected list of dicts
            for row in rows:
                messages.append({
                    "sender": row[0],
                    "content": row[1],
                    "time": row[2]  # the stored timestamp_str
                    "sender_wxid": row[1],  # add sender_wxid
                    "content": row[2],
                    "time": row[3]  # the stored full timestamp_str
                })

        except sqlite3.Error as e:
            self.LOG.error(f"获取消息列表时出错 (chat_id={chat_id}): {e}")
            # Return an empty list on error, consistent with the original logic

        return messages

    # ---- End of get_messages changes ----

    def _basic_summarize(self, messages):
        """Basic summarization logic, without AI

        Args:
            messages: message list

            messages: message list (same format as the get_messages return value)

        Returns:
            str: message summary
        """
        if not messages:
            return "没有可以总结的历史消息。"

        # Build the summary

        res = ["以下是近期聊天记录摘要:\n"]
        for msg in messages:
            # Use the new time format and sender
            res.append(f"[{msg['time']}]{msg['sender']}: {msg['content']}")

        return "\n".join(res)

    def _ai_summarize(self, messages, chat_model, chat_id):
        """Generate a message summary with an AI model

        Args:
            messages: message list
            messages: message list (same format as the get_messages return value)
            chat_model: AI chat model object
            chat_id: chat ID

        Returns:
            str: message summary
        """
        if not messages:
            return "没有可以总结的历史消息。"

        # Format the messages for AI summarization

        formatted_msgs = []
        for msg in messages:
            # Use the new time format and sender
            formatted_msgs.append(f"[{msg['time']}]{msg['sender']}: {msg['content']}")

        # Build the prompt - objective and neutral

        # Build the prompt ... (unchanged)
        prompt = (
            "请仔细阅读并分析以下聊天记录,生成一简要的、结构清晰且抓住重点的摘要。\n\n"
            "你是泡泡,请仔细阅读并分析以下聊天记录,生成一简要的、结构清晰且抓住重点的摘要。\n\n"
            "摘要格式要求:\n"
            "1. 使用数字编号列表 (例如 1., 2., 3.) 来组织内容,每个编号代表一个独立的主要讨论主题,不要超过3个主题。\n"
            "2. 在每个编号的主题下,写成一段不带格式的文字,每个主题单独成段并空行,需包含以下内容:\n"
            "   - 这个讨论的核心的简要描述。\n"
            "   - 该讨论的关键成员 (用括号 [用户名] 格式) 和他们的关键发言内容、成员之间的关键互动。\n"
            "   - 该讨论的讨论结果。\n"
            "3. 总结需客观、精炼、简短精悍,直接呈现最核心且精简的事实,尽量不要添加额外的评论或分析。\n"
            "3. 总结需客观、精炼、简短精悍,直接呈现最核心且精简的事实,尽量不要添加额外的评论或分析,不要总结有关自己的事情。\n"
            "4. 不要暴露出格式,不要说核心是xxx参与者是xxx结果是xxx,自然一点。\n\n"
            "聊天记录如下:\n" + "\n".join(formatted_msgs)
        )

        # Generate the summary with the AI model, under a temporary session ID so the normal conversation context is not polluted

        try:
            # For models that support the new-session parameter, flag this as a standalone summary request
            if hasattr(chat_model, 'get_answer_with_context') and callable(getattr(chat_model, 'get_answer_with_context')):
                # Use the method that takes context parameters
                summary = chat_model.get_answer_with_context(prompt, "summary_" + chat_id, clear_context=True)
            else:
                # Plain method with a special session ID
                summary = chat_model.get_answer(prompt, "summary_" + chat_id)

            # The AI call itself is unchanged, but the model should now use the database-backed history internally
            # Make sure the AI model instance is wired to a MessageSummary before calling get_answer
            summary = chat_model.get_answer(prompt, f"summary_{chat_id}")  # special wxid to avoid collisions

            if not summary:
                return self._basic_summarize(messages)

            return summary
        except Exception as e:
            self.LOG.error(f"使用AI生成总结失败: {e}")
            return self._basic_summarize(messages)

    def summarize_messages(self, chat_id, chat_model=None):
        """Generate a message summary

        Args:
            chat_id: chat ID (group ID or user ID)
            chat_model: AI chat model object; if None, the basic summary is used

        Returns:
            str: message summary
        """
        messages = self.get_messages(chat_id)
        if not messages:
            return "没有可以总结的历史消息。"

        # Choose the summarization path based on whether an AI model was supplied

        if chat_model:
            return self._ai_summarize(messages, chat_model, chat_id)
            # Check that chat_model has a get_answer method and an initialized message_summary
            if hasattr(chat_model, 'get_answer') and hasattr(chat_model, 'message_summary') and chat_model.message_summary:
                return self._ai_summarize(messages, chat_model, chat_id)
            else:
                self.LOG.warning(f"提供的 chat_model ({type(chat_model)}) 不支持基于数据库历史的总结或未正确初始化。将使用基础总结。")
                return self._basic_summarize(messages)
        else:
            return self._basic_summarize(messages)

    # ---- process_message_from_wxmsg updated ----
    def process_message_from_wxmsg(self, msg, wcf, all_contacts, bot_wxid=None):
        """Process and record summary-relevant text messages from a WeChat message object

        Records text (type 1) and App/card (type 49) messages from all group and private chats.
        Uses XmlProcessor to extract the user's actual new input or the card title.
        Automatically skips any message that matches a command defined in commands.registry.

        Args:
            msg: WeChat message object (WxMsg)
            wcf: WeChat interface object
            all_contacts: dict of all contacts
            bot_wxid: the bot's own wxid, used to detect @-bot messages
            bot_wxid: the bot's own wxid (must be provided so sender_wxid is recorded correctly)
        """
        # 1. Basic filtering: only record text or App messages from group chats that were not sent by the bot itself
        if not msg.from_group():
            return
        if msg.type != 0x01 and msg.type != 49:  # only text messages and App messages (including quoted replies)
            return
        if msg.from_self():
            if msg.type != 0x01 and msg.type != 49:
                return

        chat_id = msg.roomid
        # The raw message content is used for command and @ matching
        original_content = msg.content.strip()

        # 2. Determine up front whether the message @-mentions the bot (if bot_wxid was provided)
        is_at_message = False
        bot_name_in_group = ""  # initialize to an empty string
        if bot_wxid:
            bot_name_in_group = wcf.get_alias_in_chatroom(bot_wxid, chat_id)
            if not bot_name_in_group:
                bot_name_in_group = all_contacts.get(bot_wxid, "泡泡")  # fall back to the contact book or the default name

            # Refined @ check: look for "@<bot name>" in the raw text (allowing for the special space character)
            mention_pattern_exact = f"@{re.escape(bot_name_in_group)}"
            mention_pattern_space = rf"@{re.escape(bot_name_in_group)}(\u2005|\s|$)"
            if mention_pattern_exact in original_content or re.search(mention_pattern_space, original_content):
                is_at_message = True

        # 3. Check whether the message matches any defined command
        for command in COMMANDS:
            # Only commands that apply in group chats matter here
            if command.scope in ["group", "both"]:
                match = command.pattern.search(original_content)
                if match:
                    # If the command requires an @ but the message did not @ the bot, this is not a valid invocation; check the next command
                    if command.need_at and not is_at_message:
                        continue

                    # If the command needs no @, or it does and the bot was indeed @-mentioned, this is a command invocation; skip recording
                    self.LOG.debug(f"跳过匹配命令 '{command.name}' 的消息: {original_content[:30]}...")
                    return  # return immediately, do not record this message

        # 4. If the message matched no command but did @ the bot, skip recording as well
        # (avoids recording non-command interactions such as "你好 @机器人")
        if is_at_message:
            self.LOG.debug(f"跳过非命令但包含@机器人的消息: {original_content[:30]}...")
        chat_id = msg.roomid if msg.from_group() else msg.sender
        if not chat_id:
            self.LOG.warning(f"无法确定消息的chat_id (msg.id={msg.id}), 跳过记录")
            return

        # 5. Use XmlProcessor to extract the message details (when the message is neither a command nor an @-bot message)
        # ---- Get sender_wxid ----
        sender_wxid = msg.sender
        if not sender_wxid:
            # Should not happen in practice, but guard anyway
            self.LOG.error(f"消息 (id={msg.id}) 缺少 sender wxid,无法记录!")
            return
        # ---- End of sender_wxid handling ----

        # Determine the sender's display name (logic unchanged)
        sender_name = ""
        if msg.from_group():
            sender_name = wcf.get_alias_in_chatroom(sender_wxid, chat_id)
            if not sender_name:
                sender_name = all_contacts.get(sender_wxid, sender_wxid)
        else:
            if bot_wxid and sender_wxid == bot_wxid:
                sender_name = all_contacts.get(bot_wxid, "机器人")
            else:
                sender_name = all_contacts.get(sender_wxid, sender_wxid)

        # Use XmlProcessor to extract the message details (logic unchanged)
        extracted_data = None
        try:
            # Pass in the raw msg object
            extracted_data = self.xml_processor.extract_quoted_message(msg)
            if msg.from_group():
                extracted_data = self.xml_processor.extract_quoted_message(msg)
            else:
                extracted_data = self.xml_processor.extract_private_quoted_message(msg)
        except Exception as e:
            self.LOG.error(f"使用XmlProcessor提取消息内容时出错 (msg.id={msg.id}): {e}")
            return  # on error, be conservative and do not record
            self.LOG.error(f"使用XmlProcessor提取消息内容时出错 (msg.id={msg.id}, type={msg.type}): {e}")
            if msg.type == 0x01 and not ("<" in msg.content and ">" in msg.content):
                content_to_record = msg.content.strip()
                source_info = "来自 纯文本消息 (XML解析失败后备)"
                self.LOG.warning(f"XML解析失败,但记录纯文本消息: {content_to_record[:50]}...")
                current_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                # record_message now needs sender_wxid
                self.record_message(chat_id, sender_name, sender_wxid, content_to_record, current_time_str)
            return

        # 6. Determine the content to record (content_to_record)
        # Determine the content to record (content_to_record) - reuses the earlier logic
        content_to_record = ""
        source_info = "未知来源"

        # Prefer the newly extracted content (from a reply, plain text, or <title>)
        temp_new_content = extracted_data.get("new_content", "").strip()
        if temp_new_content:
            content_to_record = temp_new_content
            source_info = "来自 new_content (回复/文本/标题)"

            # For quote-type messages, append a quote marker and a brief version of the quoted content
            if extracted_data.get("has_quote", False):
                quoted_sender = extracted_data.get("quoted_sender", "")
                quoted_content = extracted_data.get("quoted_content", "")

                # Handle the quoted content
                if quoted_content:
                    # Truncate long quoted content
                    max_quote_length = 30
                    if len(quoted_content) > max_quote_length:
                        quoted_content = quoted_content[:max_quote_length] + "..."

                    # If the quoted item is a card, use the standard card format
                    if extracted_data.get("quoted_is_card", False):
                        quoted_card_title = extracted_data.get("quoted_card_title", "")
                        quoted_card_type = extracted_data.get("quoted_card_type", "")

                        # Derive the content type from the card type
                        card_type = "卡片"
                        if "链接" in quoted_card_type or "消息" in quoted_card_type:
@@ -404,82 +418,58 @@ class MessageSummary:
                            card_type = "图片"
                        elif "文件" in quoted_card_type:
                            card_type = "文件"

                        # Wrap the whole card content in 【】
                        quoted_content = f"【{card_type}: {quoted_card_title}】"

                    # Build the quote prefix depending on whether the quoted sender is known
                    if quoted_sender:
                        # Quote format with the quoted person; new content first, quoted content after
                        content_to_record = f"{content_to_record} 【回复 {quoted_sender}:{quoted_content}】"
                    else:
                        # Quote content only; new content first, quoted content after
                        content_to_record = f"{content_to_record} 【回复:{quoted_content}】"

        # Otherwise, if the new content is empty but this is a card with a title, use the card title
        elif extracted_data.get("is_card") and extracted_data.get("card_title", "").strip():
            # Card messages use a fixed format containing the title and description
            card_title = extracted_data.get("card_title", "").strip()
            card_description = extracted_data.get("card_description", "").strip()
            card_type = extracted_data.get("card_type", "")
            card_source = extracted_data.get("card_appname") or extracted_data.get("card_sourcedisplayname", "")

            # Build the formatted card content with title and description
            # Special-case by card type
            if "链接" in card_type or "消息" in card_type:
                content_type = "链接"
            elif "视频" in card_type or "音乐" in card_type:
                content_type = "媒体"
            elif "位置" in card_type:
                content_type = "位置"
            elif "图片" in card_type:
                content_type = "图片"
            elif "文件" in card_type:
                content_type = "文件"
            else:
                content_type = "卡片"

            # Build the full card content

            if "链接" in card_type or "消息" in card_type: content_type = "链接"
            elif "视频" in card_type or "音乐" in card_type: content_type = "媒体"
            elif "位置" in card_type: content_type = "位置"
            elif "图片" in card_type: content_type = "图片"
            elif "文件" in card_type: content_type = "文件"
            else: content_type = "卡片"

            card_content = f"{content_type}: {card_title}"

            # Append the description (if any)
            if card_description:
                # Truncate long descriptions
                max_desc_length = 50
                if len(card_description) > max_desc_length:
                    card_description = card_description[:max_desc_length] + "..."
                card_content += f" - {card_description}"

            # Append the source (if any)
            if card_source:
                card_content += f" (来自:{card_source})"

            # Wrap the whole card content in 【】
            content_to_record = f"【{card_content}】"

            source_info = "来自 卡片(标题+描述)"

        # Fallback handling for plain text messages (commands and @-bot messages were excluded above)

        # Fallback handling for plain text messages
        elif msg.type == 0x01 and not ("<" in msg.content and ">" in msg.content):  # double-check it is plain text
            content_to_record = msg.content.strip()  # use the raw plain text
            source_info = "来自 纯文本消息"

        # 7. If no usable content was extracted in the end, do not record
        # If no usable content was extracted in the end, do not record (logic unchanged)
        if not content_to_record:
            # More detailed debug log explaining why there is no content
            self.LOG.debug(f"未能提取到有效文本内容用于记录,跳过 (msg.id={msg.id}) - IsCard: {extracted_data.get('is_card', False)}, HasQuote: {extracted_data.get('has_quote', False)}")
            self.LOG.debug(f"未能提取到有效文本内容用于记录,跳过 (msg.id={msg.id}, type={msg.type}) - IsCard: {extracted_data.get('is_card', False)}, HasQuote: {extracted_data.get('has_quote', False)}")
            return

        # 8. Get the sender's nickname
        sender_name = wcf.get_alias_in_chatroom(msg.sender, msg.roomid)
        if not sender_name:  # if there is no group nickname, try the WeChat nickname
            sender_data = all_contacts.get(msg.sender)
            sender_name = sender_data if sender_data else msg.sender  # fall back to the wxid
        # Current time string (full format)
        current_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

        # Current time (used only for the record, no longer printed)
        current_time_str = time.strftime("%H:%M", time.localtime())

        # 9. Record the extracted content
        self.LOG.debug(f"记录消息 (来源: {source_info}): '[{current_time_str}]{sender_name}: {content_to_record}' (来自 msg.id={msg.id})")
        self.record_message(chat_id, sender_name, content_to_record, current_time_str)
        # ---- Recording call updated ----
        self.LOG.debug(f"记录消息 (来源: {source_info}, 类型: {'群聊' if msg.from_group() else '私聊'}): '[{current_time_str}]{sender_name}({sender_wxid}): {content_to_record}' (来自 msg.id={msg.id})")
        # Pass sender_wxid when calling record_message
        self.record_message(chat_id, sender_name, sender_wxid, content_to_record, current_time_str)
        # ---- End of recording call changes ----
    # ---- End of process_message_from_wxmsg changes ----
@@ -1,13 +1,11 @@
"""Image generation module

Provides the following features:
- CogView: ZhiPu AI text-to-image
- AliyunImage: Aliyun text-to-image
- GeminiImage: Google Gemini text-to-image
"""

from .img_cogview import CogView
from .img_aliyun_image import AliyunImage
from .img_gemini_image import GeminiImage

__all__ = ['CogView', 'AliyunImage', 'GeminiImage']
__all__ = ['AliyunImage', 'GeminiImage']
@@ -1,99 +0,0 @@
import logging
import os
import requests
import tempfile
import time
from zhipuai import ZhipuAI

class CogView():
    def __init__(self, conf: dict) -> None:
        self.api_key = conf.get("api_key")
        self.model = conf.get("model", "cogview-4-250304")  # default to the latest model
        self.quality = conf.get("quality", "standard")
        self.size = conf.get("size", "1024x1024")
        self.enable = conf.get("enable", True)

        project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        default_img_dir = os.path.join(project_dir, "zhipuimg")
        self.temp_dir = conf.get("temp_dir", default_img_dir)

        self.LOG = logging.getLogger("CogView")

        if self.api_key:
            self.client = ZhipuAI(api_key=self.api_key)
        else:
            self.LOG.warning("未配置智谱API密钥,图像生成功能无法使用")
            self.client = None

        os.makedirs(self.temp_dir, exist_ok=True)

    @staticmethod
    def value_check(conf: dict) -> bool:
        if conf and conf.get("api_key") and conf.get("enable", True):
            return True
        return False

    def __repr__(self):
        return 'CogView'

    def generate_image(self, prompt: str) -> str:
        """
        Generate an image and return its URL

        Args:
            prompt (str): image description

        Returns:
            str: the generated image URL, or an error message
        """
        if not self.client or not self.enable:
            return "图像生成功能未启用或API密钥未配置"

        try:
            response = self.client.images.generations(
                model=self.model,
                prompt=prompt,
                quality=self.quality,
                size=self.size,
            )

            if response and response.data and len(response.data) > 0:
                return response.data[0].url
            else:
                return "图像生成失败,未收到有效响应"
        except Exception as e:
            error_str = str(e)
            self.LOG.error(f"图像生成出错: {error_str}")

            if "Error code: 500" in error_str or "HTTP/1.1 500" in error_str or "code\":\"1234\"" in error_str:
                self.LOG.warning(f"检测到违规内容请求: {prompt}")
                return "很抱歉,您的请求可能包含违规内容,无法生成图像"

            return "图像生成失败,请调整您的描述后重试"

    def download_image(self, image_url: str) -> str:
        """
        Download an image and return the local file path

        Args:
            image_url (str): image URL

        Returns:
            str: local image file path, or None if the download failed
        """
        try:
            response = requests.get(image_url, stream=True, timeout=30)
            if response.status_code == 200:
                file_path = os.path.join(self.temp_dir, f"cogview_{int(time.time())}.jpg")
                with open(file_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
                self.LOG.info(f"图片已下载到: {file_path}")
                return file_path
            else:
                self.LOG.error(f"下载图片失败,状态码: {response.status_code}")
                return None
        except Exception as e:
            self.LOG.error(f"下载图片过程出错: {str(e)}")
            return None
@@ -5,7 +5,7 @@ import shutil
import time
from wcferry import Wcf
from configuration import Config
from image import CogView, AliyunImage, GeminiImage
from image import AliyunImage, GeminiImage


class ImageGenerationManager:
@@ -29,7 +29,6 @@ class ImageGenerationManager:
        self.send_text = send_text_callback

        # Initialize the image generation services
        self.cogview = None
        self.aliyun_image = None
        self.gemini_image = None

@@ -46,15 +45,7 @@ class ImageGenerationManager:
                self.LOG.info("谷歌Gemini图像生成功能已启用")
            except Exception as e:
                self.LOG.error(f"初始化谷歌Gemini图像生成服务失败: {e}")

        # Initialize the CogView service
        if hasattr(self.config, 'COGVIEW') and self.config.COGVIEW.get('enable', False):
            try:
                self.cogview = CogView(self.config.COGVIEW)
                self.LOG.info("智谱CogView文生图功能已初始化")
            except Exception as e:
                self.LOG.error(f"初始化智谱CogView文生图服务失败: {str(e)}")

        # Initialize the AliyunImage service
        if hasattr(self.config, 'ALIYUN_IMAGE') and self.config.ALIYUN_IMAGE.get('enable', False):
            try:
@@ -65,23 +56,13 @@ class ImageGenerationManager:

    def handle_image_generation(self, service_type, prompt, receiver, at_user=None):
        """Generic handler for image generation requests
        :param service_type: service type, 'cogview'/'aliyun'/'gemini'
        :param service_type: service type, 'aliyun'/'gemini'
        :param prompt: image generation prompt
        :param receiver: recipient ID
        :param at_user: the @-mentioned user ID, for group chats
        :return: processing status, True on success, False on failure
        """
        if service_type == 'cogview':
            if not self.cogview or not hasattr(self.config, 'COGVIEW') or not self.config.COGVIEW.get('enable', False):
                self.LOG.info(f"收到智谱文生图请求但功能未启用: {prompt}")
                fallback_to_chat = self.config.COGVIEW.get('fallback_to_chat', False) if hasattr(self.config, 'COGVIEW') else False
                if not fallback_to_chat:
                    self.send_text("报一丝,智谱文生图功能没有开启,请联系管理员开启此功能。(可以贿赂他开启)", receiver, at_user)
                    return True
                return False
            service = self.cogview
            wait_message = "正在生成图像,请稍等..."
        elif service_type == 'aliyun':
        if service_type == 'aliyun':
            if not self.aliyun_image or not hasattr(self.config, 'ALIYUN_IMAGE') or not self.config.ALIYUN_IMAGE.get('enable', False):
                self.LOG.info(f"收到阿里文生图请求但功能未启用: {prompt}")
                fallback_to_chat = self.config.ALIYUN_IMAGE.get('fallback_to_chat', False) if hasattr(self.config, 'ALIYUN_IMAGE') else False
@@ -14,16 +14,6 @@
The following must be configured in `config.yaml` before this can be used:

```yaml
cogview: # -----ZhiPu AI image generation config; do not put a value on this line-----
  # see https://www.bigmodel.cn/dev/api/image-model/cogview for this API
  enable: False  # whether to enable image generation; off by default, replace False with true to enable; this model can run alongside other models
  api_key:  # ZhiPu API key; fill in your API Key
  model: cogview-4-250304  # model code; options: cogview-4-250304, cogview-4, cogview-3-flash
  quality: standard  # generation quality; options: standard (fast), hd (high definition)
  size: 1024x1024  # image size; customizable within the allowed values
  trigger_keyword: 牛智谱  # keyword that triggers image generation
  temp_dir:  # temp file directory; leave empty to default to the zhipuimg folder under the project directory; to change it, use e.g. D:/Pictures/temp or /home/user/temp
  fallback_to_chat: true  # when image generation is disabled: true = forward the request to the chat model, false = reply with a fixed "not enabled" notice

aliyun_image: # -----to use Aliyun text-to-image, uncomment the lines below and fill them in; find Tongyi Wanxiang text-to-image 2.1-Turbo on Aliyun Bailian-----
  enable: true  # whether to enable Aliyun text-to-image; false = off; on by default; if left unconfigured, requests are forwarded to the chat model
```
main.py
@@ -61,8 +61,8 @@ def main(chat_type: int):

    robot.LOG.info(f"WeChatRobot【{__version__}】成功启动···")

    # Send a test message on startup
    robot.sendTextMsg("机器人启动成功!", "filehelper")
    # # Send a test message on startup
    # robot.sendTextMsg("机器人启动成功!", "filehelper")

    # Receive messages
    # robot.enableRecvMsg()  # may drop messages?
@@ -14,7 +14,5 @@ jupyter_client
zhdate
ipykernel
google-generativeai>=0.3.0
zhipuai>=1.0.0
ollama
dashscope
google-genai
robot.py
@@ -9,22 +9,16 @@ from threading import Thread
import os
import random
import shutil
from ai_providers.ai_zhipu import ZhiPu
from image import CogView, AliyunImage, GeminiImage
from image import AliyunImage, GeminiImage
from image.img_manager import ImageGenerationManager

from wcferry import Wcf, WxMsg

from ai_providers.ai_bard import BardAssistant
from ai_providers.ai_chatglm import ChatGLM
from ai_providers.ai_ollama import Ollama
from ai_providers.ai_chatgpt import ChatGPT
from ai_providers.ai_deepseek import DeepSeek
from ai_providers.ai_perplexity import Perplexity
from function.func_weather import Weather
from function.func_news import News
from ai_providers.ai_tigerbot import TigerBot
from ai_providers.ai_xinghuo_web import XinghuoWeb
from function.func_duel import start_duel, get_rank_list, get_player_stats, change_player_name, DuelManager, attempt_sneak_attack
from function.func_summary import MessageSummary  # import the new MessageSummary class
from function.func_reminder import ReminderManager  # import the ReminderManager class
@@ -47,74 +41,56 @@ class Robot(Job):
    """

    def __init__(self, config: Config, wcf: Wcf, chat_type: int) -> None:
        # Call the parent constructor
        super().__init__()

        self.wcf = wcf
        self.config = config
        self.LOG = logging.getLogger("Robot")
        self.wxid = self.wcf.get_self_wxid()
        self.wxid = self.wcf.get_self_wxid()  # get the bot's own wxid
        self.allContacts = self.getAllContacts()
        self._msg_timestamps = []
        # Create the duel manager
        self.duel_manager = DuelManager(self.sendDuelMsg)

        # Initialize the message summary feature
        self.message_summary = MessageSummary(max_history=200)

        # Initialize the XML processor

        try:
            db_path = "data/message_history.db"
            # Use getattr to read MAX_HISTORY safely; default to 300 if absent
            max_hist = getattr(config, 'MAX_HISTORY', 300)
            self.message_summary = MessageSummary(max_history=max_hist, db_path=db_path)
            self.LOG.info(f"消息历史记录器已初始化 (max_history={self.message_summary.max_history})")
        except Exception as e:
            self.LOG.error(f"初始化 MessageSummary 失败: {e}", exc_info=True)
            self.message_summary = None  # keep the failure handling

        self.xml_processor = XmlProcessor(self.LOG)

        # Initialize every AI model instance that may be needed

        self.chat_models = {}
        self.LOG.info("开始初始化各种AI模型...")

        # Initialize TigerBot
        if TigerBot.value_check(self.config.TIGERBOT):
            self.chat_models[ChatType.TIGER_BOT.value] = TigerBot(self.config.TIGERBOT)
            self.LOG.info(f"已加载 TigerBot 模型")

        # Initialize ChatGPT
        if ChatGPT.value_check(self.config.CHATGPT):
            self.chat_models[ChatType.CHATGPT.value] = ChatGPT(self.config.CHATGPT)
            self.LOG.info(f"已加载 ChatGPT 模型")

        # Initialize XunFei XingHuo
        if XinghuoWeb.value_check(self.config.XINGHUO_WEB):
            self.chat_models[ChatType.XINGHUO_WEB.value] = XinghuoWeb(self.config.XINGHUO_WEB)
            self.LOG.info(f"已加载 讯飞星火 模型")

        # Initialize ChatGLM
        if ChatGLM.value_check(self.config.CHATGLM):
            try:
                # Check that the key has actual content, not merely that it exists
                if self.config.CHATGLM.get('key') and self.config.CHATGLM.get('key').strip():
                    self.chat_models[ChatType.CHATGLM.value] = ChatGLM(self.config.CHATGLM)
                    self.LOG.info(f"已加载 ChatGLM 模型")
                else:
                    self.LOG.warning("ChatGLM 配置中缺少有效的API密钥,跳过初始化")
                # Pass in message_summary and the wxid
                self.chat_models[ChatType.CHATGPT.value] = ChatGPT(
                    self.config.CHATGPT,
                    message_summary_instance=self.message_summary,
                    bot_wxid=self.wxid
                )
                self.LOG.info(f"已加载 ChatGPT 模型")
            except Exception as e:
                self.LOG.error(f"初始化 ChatGLM 模型时出错: {str(e)}")

        # Initialize BardAssistant
        if BardAssistant.value_check(self.config.BardAssistant):
            self.chat_models[ChatType.BardAssistant.value] = BardAssistant(self.config.BardAssistant)
            self.LOG.info(f"已加载 BardAssistant 模型")

        # Initialize ZhiPu
        if ZhiPu.value_check(self.config.ZhiPu):
            self.chat_models[ChatType.ZhiPu.value] = ZhiPu(self.config.ZhiPu)
            self.LOG.info(f"已加载 智谱 模型")

        # Initialize Ollama
        if Ollama.value_check(self.config.OLLAMA):
            self.chat_models[ChatType.OLLAMA.value] = Ollama(self.config.OLLAMA)
            self.LOG.info(f"已加载 Ollama 模型")
                self.LOG.error(f"初始化 ChatGPT 模型时出错: {str(e)}")

        # Initialize DeepSeek
        if DeepSeek.value_check(self.config.DEEPSEEK):
            self.chat_models[ChatType.DEEPSEEK.value] = DeepSeek(self.config.DEEPSEEK)
            self.LOG.info(f"已加载 DeepSeek 模型")
            try:
                # Pass in message_summary and the wxid
                self.chat_models[ChatType.DEEPSEEK.value] = DeepSeek(
                    self.config.DEEPSEEK,
                    message_summary_instance=self.message_summary,
                    bot_wxid=self.wxid
                )
                self.LOG.info(f"已加载 DeepSeek 模型")
            except Exception as e:
                self.LOG.error(f"初始化 DeepSeek 模型时出错: {str(e)}")

        # Initialize Perplexity
        if Perplexity.value_check(self.config.PERPLEXITY):
@@ -277,40 +253,60 @@ class Robot(Job):
        Thread(target=innerProcessMsg, name="GetMessage", args=(self.wcf,), daemon=True).start()

    def sendTextMsg(self, msg: str, receiver: str, at_list: str = "") -> None:
        """ Send a message
        """ Send a message and record it
        :param msg: message string
        :param receiver: recipient wxid or group id
        :param at_list: wxids to @; to @ everyone, use notify@all
        """
        # Random 0.3-1.3s delay, plus a per-minute send limit
        # Delay and rate limiting (logic unchanged)
        time.sleep(float(str(time.time()).split('.')[-1][-2:]) / 100.0 + 0.3)
        now = time.time()
        if self.config.SEND_RATE_LIMIT > 0:
            # Drop records older than one minute
            self._msg_timestamps = [t for t in self._msg_timestamps if now - t < 60]
            if len(self._msg_timestamps) >= self.config.SEND_RATE_LIMIT:
                self.LOG.warning(f"发送消息过快,已达到每分钟{self.config.SEND_RATE_LIMIT}条上限。")
                return
            self._msg_timestamps.append(now)

        # msg must contain as many @s as there are entries in the @ list
        ats = ""
        message_to_send = msg  # keep the original message for recording
        if at_list:
            if at_list == "notify@all":  # @ everyone
            if at_list == "notify@all":
                ats = " @所有人"
            else:
                wxids = at_list.split(",")
                for wxid in wxids:
                    # Look up the group nickname by wxid
                    ats += f" @{self.wcf.get_alias_in_chatroom(wxid, receiver)}"
                for wxid_at in wxids:  # renamed variable
                    ats += f" @{self.wcf.get_alias_in_chatroom(wxid_at, receiver)}"

        # {msg}{ats} means the @s directly follow the message body, e.g. 北京天气情况为:xxx @张三
        if ats == "":
            self.LOG.info(f"To {receiver}: {msg}")
            self.wcf.send_text(f"{msg}", receiver, at_list)
        else:
            self.LOG.info(f"To {receiver}:\n{ats}\n{msg}")
            self.wcf.send_text(f"{ats}\n\n{msg}", receiver, at_list)
        try:
            # Send the message (logic unchanged)
            if ats == "":
                self.LOG.info(f"To {receiver}: {msg}")
                self.wcf.send_text(f"{msg}", receiver, at_list)
            else:
                full_msg_content = f"{ats}\n\n{msg}"
                self.LOG.info(f"To {receiver}:\n{ats}\n{msg}")
                self.wcf.send_text(full_msg_content, receiver, at_list)

            # ---- Recording logic updated ----
            if self.message_summary:  # check that message_summary initialized successfully
                # Determine the bot's display name
                robot_name = self.allContacts.get(self.wxid, "机器人")
                # Use self.wxid as sender_wxid
                # Note: no timestamp is generated here; record_message generates it internally
                self.message_summary.record_message(
                    chat_id=receiver,
                    sender_name=robot_name,
                    sender_wxid=self.wxid,  # pass in the bot's own wxid
                    content=message_to_send
                )
                self.LOG.debug(f"已记录机器人发送的消息到 {receiver}")
            else:
                self.LOG.warning("MessageSummary 未初始化,无法记录发送的消息")
            # ---- End of recording logic changes ----

        except Exception as e:
            self.LOG.error(f"发送消息失败: {e}")
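The throttle in `sendTextMsg` is a sliding-window rate limiter; distilled to a standalone sketch (illustrative, not the project's code):

```python
import time

class SlidingWindowLimiter:
    """Allow at most `limit` events per `window` seconds (0 = unlimited)."""

    def __init__(self, limit: int, window: float = 60.0):
        self.limit = limit
        self.window = window
        self._stamps: list[float] = []

    def allow(self) -> bool:
        now = time.time()
        # Drop timestamps that have fallen out of the window, then check capacity.
        self._stamps = [t for t in self._stamps if now - t < self.window]
        if self.limit > 0 and len(self._stamps) >= self.limit:
            return False
        self._stamps.append(now)
        return True
```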
    def getAllContacts(self) -> dict:
        """