新增流式对话功能

This commit is contained in:
shehuiqiang
2023-03-30 01:29:54 +08:00
parent ad99c5909c
commit 6c21ddea5b
11 changed files with 788 additions and 214 deletions

View File

@@ -41,17 +41,13 @@ class SydneyBot(Chatbot):
break
ordered_messages.insert(0, message)
current_message_id = message.get('parentMessageId')
return ordered_messages
def pop_last_conversation(self):
    """Discard the most recently cached message of the current conversation."""
    history = self.conversations_cache[self.conversation_key]["messages"]
    history.pop()
async def ask(
async def ask_stream(
self,
prompt: str,
conversation_style: EdgeGPT.CONVERSATION_STYLE_TYPE = None,
message_id: str = None,
message_id: str = None
) -> dict:
# 开启新对话
self.chat_hub = SydneyHub(Conversation(
@@ -88,11 +84,32 @@ class SydneyBot(Chatbot):
async for final, response in self.chat_hub.ask_stream(
prompt=prompt,
conversation_style=conversation_style
):
if final:
try:
self.update_reply_cache(response["item"]["messages"][-1])
except Exception as e:
self.conversations_cache[self.conversation_key]["messages"].pop()
yield True, f"AI生成内容被微软内容过滤器拦截,已删除最后一次提问的记忆,请尝试使用其他文字描述问题,若AI依然无法正常回复,请清除全部记忆后再次尝试"
yield final, response
async def ask(
    self,
    prompt: str,
    conversation_style: EdgeGPT.CONVERSATION_STYLE_TYPE = None,
    message_id: str = None
) -> dict:
    """Non-streaming ask: drain ask_stream() and return the final response dict.

    Args:
        prompt: user message to send.
        conversation_style: forwarded to ask_stream.
        message_id: forwarded to ask_stream.

    Returns:
        The final response payload yielded by ask_stream.
    """
    # Close any websocket left open by a previous (possibly aborted) call
    # before starting a new exchange.
    if self.chat_hub.wss:
        if not self.chat_hub.wss.closed:
            await self.chat_hub.wss.close()
    async for final, response in self.ask_stream(
        prompt=prompt,
        conversation_style=conversation_style,
        message_id=message_id
    ):
        if final:
            # Record the assistant's last message, then return the whole payload.
            self.update_reply_cache(response["item"]["messages"][-1])
            return response
    # Only reached when the stream ends without a final chunk.
    # NOTE(review): wss.close() is a coroutine but is not awaited here, and
    # wss may already be closed/None at this point — confirm intent.
    self.chat_hub.wss.close()
def update_reply_cache(
self,

View File

@@ -1,7 +1,7 @@
# encoding:utf-8
import asyncio
from model.model import Model
from config import model_conf_val,common_conf_val
from config import model_conf_val, common_conf_val
from common import log
from EdgeGPT import Chatbot, ConversationStyle
from ImageGen import ImageGen
@@ -23,87 +23,85 @@ class BingModel(Model):
try:
self.cookies = model_conf_val("bing", "cookies")
self.jailbreak = model_conf_val("bing", "jailbreak")
self.bot = SydneyBot(cookies=self.cookies,options={}) if(self.jailbreak) else Chatbot(cookies=self.cookies)
self.bot = SydneyBot(cookies=self.cookies, options={}) if (
self.jailbreak) else Chatbot(cookies=self.cookies)
except Exception as e:
log.exception(e)
log.warn(e)
async def reply_text_stream(self, query: str, context=None) -> dict:
    """Streamed counterpart of reply(): yields (final, answer) tuples.

    Non-final chunks carry the raw streamed answer; the final chunk is
    formatted via build_source_attributions. On formatting failure the
    user's session is reset and the raw answer is surfaced instead.

    Args:
        query: user input text (or a quick-ask digit).
        context: request context; must contain 'from_user_id' and may
            contain 'type'.

    Yields:
        (final: bool, answer) tuples.
    """
    async def handle_answer(final, answer):
        # Post-process one chunk from the bot: format the final answer,
        # pass intermediate chunks through unchanged.
        if final:
            try:
                reply = self.build_source_attributions(answer, context)
                log.info("[NewBing] reply:{}", reply)
                yield True, reply
            except Exception as e:
                log.warn(answer)
                log.warn(e)
                # NOTE(review): .get(..., None) may return None here, which
                # would raise AttributeError on .reset() — confirm a session
                # always exists at this point.
                await user_session.get(context['from_user_id'], None).reset()
                yield True, answer
        else:
            try:
                yield False, answer
            except Exception as e:
                log.warn(answer)
                log.warn(e)
                await user_session.get(context['from_user_id'], None).reset()
                yield True, answer

    if not context or not context.get('type') or context.get('type') == 'TEXT':
        clear_memory_commands = common_conf_val(
            'clear_memory_commands', ['#清除记忆'])
        if query in clear_memory_commands:
            user_session[context['from_user_id']] = None
            yield True, '记忆已清除'
            # BUG FIX: without this return the clear-memory command fell
            # through and was sent to the bot as a regular query.
            return
        bot = user_session.get(context['from_user_id'], None)
        if not bot:
            bot = self.bot
        else:
            # Existing session: translate a quick-ask digit into the
            # corresponding suggested question.
            query = self.get_quick_ask_query(query, context)
        user_session[context['from_user_id']] = bot
        log.info("[NewBing] query={}".format(query))
        if self.jailbreak:
            # Jailbroken (Sydney) bots also need the tracked message id.
            async for final, answer in bot.ask_stream(query, conversation_style=self.style, message_id=bot.user_message_id):
                async for result in handle_answer(final, answer):
                    yield result
        else:
            async for final, answer in bot.ask_stream(query, conversation_style=self.style):
                async for result in handle_answer(final, answer):
                    yield result
def reply(self, query: str, context=None) -> tuple[str, dict]:
if not context or not context.get('type') or context.get('type') == 'TEXT':
clear_memory_commands = common_conf_val('clear_memory_commands', ['#清除记忆'])
clear_memory_commands = common_conf_val(
'clear_memory_commands', ['#清除记忆'])
if query in clear_memory_commands:
user_session[context['from_user_id']]=None
user_session[context['from_user_id']] = None
return '记忆已清除'
bot = user_session.get(context['from_user_id'], None)
if (bot == None):
bot = self.bot
else:
if (len(query) == 1 and query.isdigit() and query != "0"):
suggestion_dict = suggestion_session[context['from_user_id']]
if (suggestion_dict != None):
query = suggestion_dict[int(query)-1]
if (query == None):
return "输入的序号不在建议列表范围中"
else:
query = "在上面的基础上,"+query
query = self.get_quick_ask_query(query, context)
user_session[context['from_user_id']] = bot
log.info("[NewBing] query={}".format(query))
if(self.jailbreak):
task = bot.ask(query, conversation_style=self.style,message_id=bot.user_message_id)
if (self.jailbreak):
task = bot.ask(query, conversation_style=self.style,
message_id=bot.user_message_id)
else:
task = bot.ask(query, conversation_style=self.style)
try:
answer = asyncio.run(task)
except Exception as e:
bot.pop_last_conversation()
log.exception(answer)
return f"AI生成内容被微软内容过滤器拦截,已删除最后一次提问的记忆,请尝试使用其他文字描述问题,若AI依然无法正常回复,请使用{clear_memory_commands[0]}命令清除全部记忆"
# 最新一条回复
answer = asyncio.run(task)
if isinstance(answer, str):
return answer
try:
reply = answer["item"]["messages"][-1]
except Exception as e:
self.reset_chat(context['from_user_id'])
log.exception(answer)
user_session.get(context['from_user_id'], None).reset()
log.warn(answer)
return "本轮对话已超时,已开启新的一轮对话,请重新提问。"
reply_text = reply["text"]
reference = ""
if "sourceAttributions" in reply:
for i, attribution in enumerate(reply["sourceAttributions"]):
display_name = attribution["providerDisplayName"]
url = attribution["seeMoreUrl"]
reference += f"{i+1}、[{display_name}]({url})\n\n"
if len(reference) > 0:
reference = "***\n"+reference
suggestion = ""
if "suggestedResponses" in reply:
suggestion_dict = dict()
for i, attribution in enumerate(reply["suggestedResponses"]):
suggestion_dict[i] = attribution["text"]
suggestion += f">{i+1}{attribution['text']}\n\n"
suggestion_session[context['from_user_id']
] = suggestion_dict
if len(suggestion) > 0:
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
throttling = answer["item"]["throttling"]
throttling_str = ""
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
self.reset_chat(context['from_user_id'])
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
else:
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
log.info("[NewBing] reply={}", response)
user_session[context['from_user_id']] = bot
return response
else:
self.reset_chat(context['from_user_id'])
log.warn("[NewBing] reply={}", answer)
return "对话被接口拒绝,已开启新的一轮对话。"
return self.build_source_attributions(answer, context)
elif context.get('type', None) == 'IMAGE_CREATE':
if functions.contain_chinese(query):
return "ImageGen目前仅支持使用英文关键词生成图片"
@@ -118,8 +116,58 @@ class BingModel(Model):
log.info("[NewBing] image_list={}".format(img_list))
return img_list
except Exception as e:
log.exception(e)
log.warn(e)
return "输入的内容可能违反微软的图片生成内容策略。过多的策略冲突可能会导致你被暂停访问。"
def reset_chat(self, from_user_id):
    """Reset the cached NewBing session for *from_user_id*, if one exists.

    The original unconditionally called ``.reset()`` on the result of
    ``user_session.get(from_user_id, None)``, which raises AttributeError
    when no session is cached; guard against that.
    """
    session = user_session.get(from_user_id, None)
    if session is not None:
        # bot.reset() is a coroutine; drive it to completion synchronously.
        # NOTE(review): asyncio.run() raises if invoked from inside a running
        # event loop — confirm reset_chat is only called from sync paths.
        asyncio.run(session.reset())
def get_quick_ask_query(self, query, context):
    """Translate a single-digit quick-reply into the matching suggested question.

    A query of exactly one digit in 1-9 selects entry ``digit-1`` from the
    user's cached suggestion dict; any other query is returned unchanged.

    Args:
        query: raw user input.
        context: request context containing 'from_user_id'.

    Returns:
        The expanded follow-up question, the original query, or an
        out-of-range notice string.
    """
    if len(query) == 1 and query.isdigit() and query != "0":
        # Use .get() so a missing user entry or out-of-range index yields the
        # intended notice instead of raising KeyError (original used [] access).
        suggestion_dict = suggestion_session.get(context['from_user_id'])
        if suggestion_dict is not None:
            query = suggestion_dict.get(int(query) - 1)
            if query is None:
                return "输入的序号不在建议列表范围中"
            # Prefix so the bot treats it as a follow-up to the prior answer.
            query = "在上面的基础上," + query
    return query
def build_source_attributions(self, answer, context):
reference = ""
reply = answer["item"]["messages"][-1]
reply_text = reply["text"]
if "sourceAttributions" in reply:
for i, attribution in enumerate(reply["sourceAttributions"]):
display_name = attribution["providerDisplayName"]
url = attribution["seeMoreUrl"]
reference += f"{i+1}、[{display_name}]({url})\n\n"
if len(reference) > 0:
reference = "***\n"+reference
suggestion = ""
if "suggestedResponses" in reply:
suggestion_dict = dict()
for i, attribution in enumerate(reply["suggestedResponses"]):
suggestion_dict[i] = attribution["text"]
suggestion += f">{i+1}{attribution['text']}\n\n"
suggestion_session[context['from_user_id']
] = suggestion_dict
if len(suggestion) > 0:
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
throttling = answer["item"]["throttling"]
throttling_str = ""
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
user_session.get(context['from_user_id'], None).reset()
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
else:
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
log.info("[NewBing] reply={}", response)
return response
else:
user_session.get(context['from_user_id'], None).reset()
log.warn("[NewBing] reply={}", answer)
return "对话被接口拒绝,已开启新的一轮对话。"

View File

@@ -83,20 +83,32 @@ class ChatGPTModel(Model):
return "请再问我一次吧"
def reply_text_stream(self, query, new_query, user_id, retry_count=0):
async def reply_text_stream(self, query, context, retry_count=0):
try:
res = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
prompt=new_query,
user_id=context['from_user_id']
new_query = Session.build_session_query(query, user_id)
res = openai.ChatCompletion.create(
model= model_conf(const.OPEN_AI).get("model") or "gpt-3.5-turbo", # 对话模型的名称
messages=new_query,
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
#max_tokens=4096, # 回复最大的字符数
top_p=1,
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
stop=["\n\n\n"],
stream=True
)
return self._process_reply_stream(query, res, user_id)
full_response = ""
for chunk in res:
log.debug(chunk)
if (chunk["choices"][0]["finish_reason"]=="stop"):
break
chunk_message = chunk['choices'][0]['delta'].get("content")
if(chunk_message):
full_response+=chunk_message
yield False,full_response
Session.save_session(query, full_response, user_id)
log.info("[chatgpt]: reply={}", full_response)
yield True,full_response
except openai.error.RateLimitError as e:
# rate limit exception
@@ -104,45 +116,22 @@ class ChatGPTModel(Model):
if retry_count < 1:
time.sleep(5)
log.warn("[CHATGPT] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text_stream(query, user_id, retry_count+1)
yield True, self.reply_text_stream(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
yield True, "提问太快啦,请休息一下再问我吧"
except openai.error.APIConnectionError as e:
log.warn(e)
log.warn("[CHATGPT] APIConnection failed")
return "我连接不到网络,请稍后重试"
yield True, "我连接不到网络,请稍后重试"
except openai.error.Timeout as e:
log.warn(e)
log.warn("[CHATGPT] Timeout")
return "我没有收到消息,请稍后重试"
yield True, "我没有收到消息,请稍后重试"
except Exception as e:
# unknown exception
log.exception(e)
Session.clear_session(user_id)
return "请再问我一次吧"
def _process_reply_stream(
self,
query: str,
reply: dict,
user_id: str
) -> str:
full_response = ""
for response in reply:
if response.get("choices") is None or len(response["choices"]) == 0:
raise Exception("OpenAI API returned no choices")
if response["choices"][0].get("finish_details") is not None:
break
if response["choices"][0].get("text") is None:
raise Exception("OpenAI API returned no text")
if response["choices"][0]["text"] == "<|endoftext|>":
break
yield response["choices"][0]["text"]
full_response += response["choices"][0]["text"]
if query and full_response:
Session.save_session(query, full_response, user_id)
yield True, "请再问我一次吧"
def create_img(self, query, retry_count=0):
try:

View File

@@ -13,7 +13,9 @@ user_session = dict()
class OpenAIModel(Model):
def __init__(self):
openai.api_key = model_conf(const.OPEN_AI).get('api_key')
proxy = model_conf(const.OPEN_AI).get('proxy')
if proxy:
openai.proxy = proxy
def reply(self, query, context=None):
# acquire reply content
@@ -72,36 +74,55 @@ class OpenAIModel(Model):
return "请再问我一次吧"
def reply_text_stream(self, query, new_query, user_id, retry_count=0):
async def reply_text_stream(self, query, context, retry_count=0):
try:
user_id=context['from_user_id']
new_query = Session.build_session_query(query, user_id)
res = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
model= "text-davinci-003", # 对话模型的名称
prompt=new_query,
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
max_tokens=1200, # 回复最大的字符数
#max_tokens=4096, # 回复最大的字符数
top_p=1,
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
stop=["\n\n\n"],
stream=True
)
return self._process_reply_stream(query, res, user_id)
full_response = ""
for chunk in res:
log.debug(chunk)
if (chunk["choices"][0]["finish_reason"]=="stop"):
break
chunk_message = chunk['choices'][0].get("text")
if(chunk_message):
full_response+=chunk_message
yield False,full_response
Session.save_session(query, full_response, user_id)
log.info("[chatgpt]: reply={}", full_response)
yield True,full_response
except openai.error.RateLimitError as e:
# rate limit exception
log.warn(e)
if retry_count < 1:
time.sleep(5)
log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
log.warn("[CHATGPT] RateLimit exceed, 第{}次重试".format(retry_count+1))
yield True, self.reply_text_stream(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
yield True, "提问太快啦,请休息一下再问我吧"
except openai.error.APIConnectionError as e:
log.warn(e)
log.warn("[CHATGPT] APIConnection failed")
yield True, "我连接不到网络,请稍后重试"
except openai.error.Timeout as e:
log.warn(e)
log.warn("[CHATGPT] Timeout")
yield True, "我没有收到消息,请稍后重试"
except Exception as e:
# unknown exception
log.exception(e)
Session.clear_session(user_id)
return "请再问我一次吧"
yield True, "请再问我一次吧"
def _process_reply_stream(
self,