mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-01-19 01:21:01 +08:00
avoid repeatedly instantiating bot
This commit is contained in:
@@ -7,16 +7,13 @@ from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import time
|
||||
|
||||
if conf().get('expires_in_seconds'):
|
||||
all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
all_sessions = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
proxy = conf().get('proxy')
|
||||
self.sessions=SessionManager()
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
@@ -26,16 +23,16 @@ class ChatGPTBot(Bot):
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context.get('session_id') or context.get('from_user_id')
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(session_id)
|
||||
self.sessions.clear_session(session_id)
|
||||
return '记忆已清除'
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
self.sessions.clear_all_session()
|
||||
return '所有人记忆已清除'
|
||||
elif query == '#更新配置':
|
||||
load_config()
|
||||
return '配置已更新'
|
||||
|
||||
session = Session.build_session_query(query, session_id)
|
||||
session = self.sessions.build_session_query(query, session_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(session))
|
||||
|
||||
# if context.get('stream'):
|
||||
@@ -45,7 +42,7 @@ class ChatGPTBot(Bot):
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
if reply_content["completion_tokens"] > 0:
|
||||
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
return reply_content["content"]
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
@@ -94,7 +91,7 @@ class ChatGPTBot(Bot):
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(session_id)
|
||||
self.sessions.clear_session(session_id)
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
@@ -119,10 +116,11 @@ class ChatGPTBot(Bot):
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, session_id):
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
self.sessions = {}
|
||||
def build_session_query(self,query, session_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. [
|
||||
@@ -135,36 +133,33 @@ class Session(object):
|
||||
:param session_id: session id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
session = all_sessions.get(session_id, [])
|
||||
session = self.sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
all_sessions[session_id] = session
|
||||
self.sessions[session_id] = session
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
session.append(user_item)
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def save_session(answer, session_id, total_tokens):
|
||||
def save_session(self, answer, session_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens=int(max_tokens)
|
||||
|
||||
session = all_sessions.get(session_id)
|
||||
session = self.sessions.get(session_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
self.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens, total_tokens):
|
||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
|
||||
dec_tokens = int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
@@ -176,10 +171,8 @@ class Session(object):
|
||||
break
|
||||
dec_tokens = dec_tokens - max_tokens
|
||||
|
||||
@staticmethod
|
||||
def clear_session(session_id):
|
||||
all_sessions[session_id] = []
|
||||
def clear_session(self,session_id):
|
||||
self.sessions[session_id] = []
|
||||
|
||||
@staticmethod
|
||||
def clear_all_session():
|
||||
all_sessions.clear()
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
|
||||
@@ -1,16 +1,25 @@
|
||||
from bot import bot_factory
|
||||
from common.singleton import singleton
|
||||
from voice import voice_factory
|
||||
|
||||
|
||||
@singleton
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
self.bots = {
|
||||
"chat": bot_factory.create_bot("chatGPT"),
|
||||
"voice_to_text": voice_factory.create_voice("openai"),
|
||||
# "text_to_voice": voice_factory.create_voice("baidu")
|
||||
}
|
||||
try:
|
||||
self.bots["text_to_voice"] = voice_factory.create_voice("baidu")
|
||||
except ModuleNotFoundError as e:
|
||||
print(e)
|
||||
|
||||
def fetch_reply_content(self, query, context):
|
||||
return bot_factory.create_bot("chatGPT").reply(query, context)
|
||||
return self.bots["chat"].reply(query, context)
|
||||
|
||||
def fetch_voice_to_text(self, voiceFile):
|
||||
return voice_factory.create_voice("openai").voiceToText(voiceFile)
|
||||
return self.bots["voice_to_text"].voiceToText(voiceFile)
|
||||
|
||||
def fetch_text_to_voice(self, text):
|
||||
return voice_factory.create_voice("baidu").textToVoice(text)
|
||||
return self.bots["text_to_voice"].textToVoice(text)
|
||||
9
common/singleton.py
Normal file
9
common/singleton.py
Normal file
@@ -0,0 +1,9 @@
|
||||
def singleton(cls):
|
||||
instances = {}
|
||||
|
||||
def get_instance(*args, **kwargs):
|
||||
if cls not in instances:
|
||||
instances[cls] = cls(*args, **kwargs)
|
||||
return instances[cls]
|
||||
|
||||
return get_instance
|
||||
Reference in New Issue
Block a user