mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-01-19 01:21:01 +08:00
feat: image input and session optimize
This commit is contained in:
@@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, pconf
|
||||
import threading
|
||||
from common import memory, utils
|
||||
import base64
|
||||
|
||||
|
||||
class LinkAIBot(Bot):
|
||||
# authentication failed
|
||||
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {}
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
@@ -61,17 +64,25 @@ class LinkAIBot(Bot):
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
session_message = self.sessions.session_msg_query(query, session_id)
|
||||
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
|
||||
|
||||
# image process
|
||||
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
|
||||
if img_cache:
|
||||
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
|
||||
if messages:
|
||||
session_message = messages
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
model = conf().get("model")
|
||||
# remove system message
|
||||
if session.messages[0].get("role") == "system":
|
||||
if session_message[0].get("role") == "system":
|
||||
if app_code or model == "wenxin":
|
||||
session.messages.pop(0)
|
||||
session_message.pop(0)
|
||||
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session.messages,
|
||||
"messages": session_message,
|
||||
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
@@ -94,7 +105,7 @@ class LinkAIBot(Bot):
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
@@ -130,6 +141,54 @@ class LinkAIBot(Bot):
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
|
||||
try:
|
||||
enable_image_input = False
|
||||
app_info = self._fetch_app_info(app_code)
|
||||
if not app_info:
|
||||
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
|
||||
return None
|
||||
plugins = app_info.get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
|
||||
enable_image_input = True
|
||||
if not enable_image_input:
|
||||
return
|
||||
msg = img_cache.get("msg")
|
||||
path = img_cache.get("path")
|
||||
msg.prepare()
|
||||
logger.info(f"[LinkAI] query with images, path={path}")
|
||||
messages = self._build_vision_msg(query, path)
|
||||
memory.USER_IMAGE_CACHE[session_id] = None
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _build_vision_msg(self, query: str, path: str):
|
||||
try:
|
||||
suffix = utils.get_path_suffix(path)
|
||||
with open(path, "rb") as file:
|
||||
base64_str = base64.b64encode(file.read()).decode('utf-8')
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": query
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{suffix};base64,{base64_str}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
@@ -195,6 +254,16 @@ class LinkAIBot(Bot):
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
return res.json()
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
try:
|
||||
@@ -239,6 +308,7 @@ class LinkAIBot(Bot):
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _fetch_agent_suffix(self, response):
|
||||
try:
|
||||
plugin_list = []
|
||||
@@ -275,4 +345,44 @@ class LinkAIBot(Bot):
|
||||
reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(e)
|
||||
|
||||
|
||||
class LinkAISessionManager(SessionManager):
|
||||
def session_msg_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
messages = session.messages + [{"role": "user", "content": query}]
|
||||
return messages
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None, query=None):
|
||||
session = self.build_session(session_id)
|
||||
if query:
|
||||
session.add_query(query)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 2500)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
|
||||
class LinkAISession(ChatGPTSession):
|
||||
def calc_tokens(self):
|
||||
try:
|
||||
cur_tokens = super().calc_tokens()
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
cur_tokens = len(str(self.messages))
|
||||
return cur_tokens
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
cur_tokens = self.calc_tokens()
|
||||
if cur_tokens > max_tokens:
|
||||
for i in range(0, len(self.messages)):
|
||||
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
|
||||
self.messages.pop(i)
|
||||
self.messages.pop(i - 1)
|
||||
return self.calc_tokens()
|
||||
return cur_tokens
|
||||
|
||||
@@ -69,7 +69,7 @@ class SessionManager(object):
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
@@ -80,7 +80,7 @@ class SessionManager(object):
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
|
||||
@@ -9,8 +9,7 @@ from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
@@ -205,14 +204,16 @@ class ChatChannel(Channel):
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
cmsg = context["msg"]
|
||||
cmsg.prepare()
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown context type: {}".format(context.type))
|
||||
logger.warning("[WX] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
|
||||
3
common/memory.py
Normal file
3
common/memory.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
|
||||
USER_IMAGE_CACHE = ExpiredDict(60 * 3)
|
||||
@@ -1,6 +1,6 @@
|
||||
import io
|
||||
import os
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
result.append(encoded[start:end].decode("utf-8"))
|
||||
start = end
|
||||
return result
|
||||
|
||||
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
|
||||
@@ -91,5 +91,4 @@ class LinkSummary:
|
||||
for support_url in support_list:
|
||||
if url.strip().startswith(support_url):
|
||||
return True
|
||||
logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user