feat: support terminal channel

This commit is contained in:
zhayujie
2023-02-18 16:58:04 +08:00
parent bce946a1ef
commit 2ce3643237
11 changed files with 190 additions and 40 deletions

View File

@@ -3,7 +3,7 @@
from model.model import Model
from config import model_conf
from common import const
from common.log import logger
from common import log
import openai
import time
@@ -18,17 +18,21 @@ class OpenAIModel(Model):
def reply(self, query, context=None):
# acquire reply content
if not context or not context.get('type') or context.get('type') == 'TEXT':
logger.info("[OPEN_AI] query={}".format(query))
log.info("[OPEN_AI] query={}".format(query))
from_user_id = context['from_user_id']
if query == '#清除记忆':
Session.clear_session(from_user_id)
return '记忆已清除'
new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))
log.debug("[OPEN_AI] session query={}".format(new_query))
if context.get('stream'):
# reply in stream
return self.reply_text_stream(query, new_query, from_user_id)
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
log.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
return reply_content
@@ -49,45 +53,98 @@ class OpenAIModel(Model):
stop=["\n\n\n"]
)
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
logger.info("[OPEN_AI] reply={}".format(res_content))
log.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
log.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
# unknown exception
logger.exception(e)
log.exception(e)
Session.clear_session(user_id)
return "请再问我一次吧"
def reply_text_stream(self, query, new_query, user_id, retry_count=0):
try:
res = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
prompt=new_query,
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
max_tokens=1200, # 回复最大的字符数
top_p=1,
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
stop=["\n\n\n"],
stream=True
)
return self._process_reply_stream(query, res, user_id)
except openai.error.RateLimitError as e:
# rate limit exception
log.warn(e)
if retry_count < 1:
time.sleep(5)
log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
# unknown exception
log.exception(e)
Session.clear_session(user_id)
return "请再问我一次吧"
def _process_reply_stream(
self,
query: str,
reply: dict,
user_id: str
) -> str:
full_response = ""
for response in reply:
if response.get("choices") is None or len(response["choices"]) == 0:
raise Exception("OpenAI API returned no choices")
if response["choices"][0].get("finish_details") is not None:
break
if response["choices"][0].get("text") is None:
raise Exception("OpenAI API returned no text")
if response["choices"][0]["text"] == "<|endoftext|>":
break
yield response["choices"][0]["text"]
full_response += response["choices"][0]["text"]
if query and full_response:
Session.save_session(query, full_response, user_id)
def create_img(self, query, retry_count=0):
try:
logger.info("[OPEN_AI] image_query={}".format(query))
log.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
prompt=query, #图片描述
n=1, #每次生成图片的数量
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
log.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
log.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
log.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
log.exception(e)
return None
@@ -125,8 +182,8 @@ class Session(object):
conversation["question"] = query
conversation["answer"] = answer
session = user_session.get(user_id)
logger.debug(conversation)
logger.debug(session)
log.debug(conversation)
log.debug(session)
if session:
# append conversation
session.append(conversation)