mirror of
https://github.com/zhayujie/bot-on-anything.git
synced 2026-03-02 08:10:07 +08:00
新增流式对话功能
This commit is contained in:
@@ -41,17 +41,13 @@ class SydneyBot(Chatbot):
|
||||
break
|
||||
ordered_messages.insert(0, message)
|
||||
current_message_id = message.get('parentMessageId')
|
||||
|
||||
return ordered_messages
|
||||
|
||||
def pop_last_conversation(self):
|
||||
self.conversations_cache[self.conversation_key]["messages"].pop()
|
||||
|
||||
async def ask(
|
||||
async def ask_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_style: EdgeGPT.CONVERSATION_STYLE_TYPE = None,
|
||||
message_id: str = None,
|
||||
message_id: str = None
|
||||
) -> dict:
|
||||
# 开启新对话
|
||||
self.chat_hub = SydneyHub(Conversation(
|
||||
@@ -88,11 +84,32 @@ class SydneyBot(Chatbot):
|
||||
async for final, response in self.chat_hub.ask_stream(
|
||||
prompt=prompt,
|
||||
conversation_style=conversation_style
|
||||
):
|
||||
if final:
|
||||
try:
|
||||
self.update_reply_cache(response["item"]["messages"][-1])
|
||||
except Exception as e:
|
||||
self.conversations_cache[self.conversation_key]["messages"].pop()
|
||||
yield True, f"AI生成内容被微软内容过滤器拦截,已删除最后一次提问的记忆,请尝试使用其他文字描述问题,若AI依然无法正常回复,请清除全部记忆后再次尝试"
|
||||
yield final, response
|
||||
|
||||
async def ask(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_style: EdgeGPT.CONVERSATION_STYLE_TYPE = None,
|
||||
message_id: str = None
|
||||
) -> dict:
|
||||
if self.chat_hub.wss:
|
||||
if not self.chat_hub.wss.closed:
|
||||
await self.chat_hub.wss.close()
|
||||
async for final, response in self.ask_stream(
|
||||
prompt=prompt,
|
||||
conversation_style=conversation_style,
|
||||
message_id=message_id
|
||||
):
|
||||
if final:
|
||||
self.update_reply_cache(response["item"]["messages"][-1])
|
||||
return response
|
||||
self.chat_hub.wss.close()
|
||||
|
||||
def update_reply_cache(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# encoding:utf-8
|
||||
import asyncio
|
||||
from model.model import Model
|
||||
from config import model_conf_val,common_conf_val
|
||||
from config import model_conf_val, common_conf_val
|
||||
from common import log
|
||||
from EdgeGPT import Chatbot, ConversationStyle
|
||||
from ImageGen import ImageGen
|
||||
@@ -23,87 +23,85 @@ class BingModel(Model):
|
||||
try:
|
||||
self.cookies = model_conf_val("bing", "cookies")
|
||||
self.jailbreak = model_conf_val("bing", "jailbreak")
|
||||
self.bot = SydneyBot(cookies=self.cookies,options={}) if(self.jailbreak) else Chatbot(cookies=self.cookies)
|
||||
self.bot = SydneyBot(cookies=self.cookies, options={}) if (
|
||||
self.jailbreak) else Chatbot(cookies=self.cookies)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.warn(e)
|
||||
|
||||
async def reply_text_stream(self, query: str, context=None) -> dict:
|
||||
async def handle_answer(final, answer):
|
||||
if final:
|
||||
try:
|
||||
reply = self.build_source_attributions(answer, context)
|
||||
log.info("[NewBing] reply:{}",reply)
|
||||
yield True, reply
|
||||
except Exception as e:
|
||||
log.warn(answer)
|
||||
log.warn(e)
|
||||
await user_session.get(context['from_user_id'], None).reset()
|
||||
yield True, answer
|
||||
else:
|
||||
try:
|
||||
yield False, answer
|
||||
except Exception as e:
|
||||
log.warn(answer)
|
||||
log.warn(e)
|
||||
await user_session.get(context['from_user_id'], None).reset()
|
||||
yield True, answer
|
||||
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
clear_memory_commands = common_conf_val(
|
||||
'clear_memory_commands', ['#清除记忆'])
|
||||
if query in clear_memory_commands:
|
||||
user_session[context['from_user_id']] = None
|
||||
yield True, '记忆已清除'
|
||||
bot = user_session.get(context['from_user_id'], None)
|
||||
if not bot:
|
||||
bot = self.bot
|
||||
else:
|
||||
query = self.get_quick_ask_query(query, context)
|
||||
user_session[context['from_user_id']] = bot
|
||||
log.info("[NewBing] query={}".format(query))
|
||||
if self.jailbreak:
|
||||
async for final, answer in bot.ask_stream(query, conversation_style=self.style, message_id=bot.user_message_id):
|
||||
async for result in handle_answer(final, answer):
|
||||
yield result
|
||||
else:
|
||||
async for final, answer in bot.ask_stream(query, conversation_style=self.style):
|
||||
async for result in handle_answer(final, answer):
|
||||
yield result
|
||||
|
||||
def reply(self, query: str, context=None) -> tuple[str, dict]:
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
clear_memory_commands = common_conf_val('clear_memory_commands', ['#清除记忆'])
|
||||
clear_memory_commands = common_conf_val(
|
||||
'clear_memory_commands', ['#清除记忆'])
|
||||
if query in clear_memory_commands:
|
||||
user_session[context['from_user_id']]=None
|
||||
user_session[context['from_user_id']] = None
|
||||
return '记忆已清除'
|
||||
bot = user_session.get(context['from_user_id'], None)
|
||||
if (bot == None):
|
||||
bot = self.bot
|
||||
else:
|
||||
if (len(query) == 1 and query.isdigit() and query != "0"):
|
||||
suggestion_dict = suggestion_session[context['from_user_id']]
|
||||
if (suggestion_dict != None):
|
||||
query = suggestion_dict[int(query)-1]
|
||||
if (query == None):
|
||||
return "输入的序号不在建议列表范围中"
|
||||
else:
|
||||
query = "在上面的基础上,"+query
|
||||
query = self.get_quick_ask_query(query, context)
|
||||
|
||||
user_session[context['from_user_id']] = bot
|
||||
log.info("[NewBing] query={}".format(query))
|
||||
if(self.jailbreak):
|
||||
task = bot.ask(query, conversation_style=self.style,message_id=bot.user_message_id)
|
||||
if (self.jailbreak):
|
||||
task = bot.ask(query, conversation_style=self.style,
|
||||
message_id=bot.user_message_id)
|
||||
else:
|
||||
task = bot.ask(query, conversation_style=self.style)
|
||||
|
||||
try:
|
||||
answer = asyncio.run(task)
|
||||
except Exception as e:
|
||||
bot.pop_last_conversation()
|
||||
log.exception(answer)
|
||||
return f"AI生成内容被微软内容过滤器拦截,已删除最后一次提问的记忆,请尝试使用其他文字描述问题,若AI依然无法正常回复,请使用{clear_memory_commands[0]}命令清除全部记忆"
|
||||
# 最新一条回复
|
||||
|
||||
answer = asyncio.run(task)
|
||||
if isinstance(answer, str):
|
||||
return answer
|
||||
try:
|
||||
reply = answer["item"]["messages"][-1]
|
||||
except Exception as e:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
log.exception(answer)
|
||||
user_session.get(context['from_user_id'], None).reset()
|
||||
log.warn(answer)
|
||||
return "本轮对话已超时,已开启新的一轮对话,请重新提问。"
|
||||
reply_text = reply["text"]
|
||||
reference = ""
|
||||
if "sourceAttributions" in reply:
|
||||
for i, attribution in enumerate(reply["sourceAttributions"]):
|
||||
display_name = attribution["providerDisplayName"]
|
||||
url = attribution["seeMoreUrl"]
|
||||
reference += f"{i+1}、[{display_name}]({url})\n\n"
|
||||
|
||||
if len(reference) > 0:
|
||||
reference = "***\n"+reference
|
||||
|
||||
suggestion = ""
|
||||
if "suggestedResponses" in reply:
|
||||
suggestion_dict = dict()
|
||||
for i, attribution in enumerate(reply["suggestedResponses"]):
|
||||
suggestion_dict[i] = attribution["text"]
|
||||
suggestion += f">{i+1}、{attribution['text']}\n\n"
|
||||
suggestion_session[context['from_user_id']
|
||||
] = suggestion_dict
|
||||
|
||||
if len(suggestion) > 0:
|
||||
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
|
||||
|
||||
throttling = answer["item"]["throttling"]
|
||||
throttling_str = ""
|
||||
|
||||
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
|
||||
else:
|
||||
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
|
||||
|
||||
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
|
||||
log.info("[NewBing] reply={}", response)
|
||||
user_session[context['from_user_id']] = bot
|
||||
return response
|
||||
else:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
log.warn("[NewBing] reply={}", answer)
|
||||
return "对话被接口拒绝,已开启新的一轮对话。"
|
||||
return self.build_source_attributions(answer, context)
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
if functions.contain_chinese(query):
|
||||
return "ImageGen目前仅支持使用英文关键词生成图片"
|
||||
@@ -118,8 +116,58 @@ class BingModel(Model):
|
||||
log.info("[NewBing] image_list={}".format(img_list))
|
||||
return img_list
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.warn(e)
|
||||
return "输入的内容可能违反微软的图片生成内容策略。过多的策略冲突可能会导致你被暂停访问。"
|
||||
|
||||
def reset_chat(self, from_user_id):
|
||||
asyncio.run(user_session.get(from_user_id, None).reset())
|
||||
def get_quick_ask_query(self, query, context):
|
||||
if (len(query) == 1 and query.isdigit() and query != "0"):
|
||||
suggestion_dict = suggestion_session[context['from_user_id']]
|
||||
if (suggestion_dict != None):
|
||||
query = suggestion_dict[int(query)-1]
|
||||
if (query == None):
|
||||
return "输入的序号不在建议列表范围中"
|
||||
else:
|
||||
query = "在上面的基础上,"+query
|
||||
return query
|
||||
|
||||
def build_source_attributions(self, answer, context):
|
||||
reference = ""
|
||||
reply = answer["item"]["messages"][-1]
|
||||
reply_text = reply["text"]
|
||||
if "sourceAttributions" in reply:
|
||||
for i, attribution in enumerate(reply["sourceAttributions"]):
|
||||
display_name = attribution["providerDisplayName"]
|
||||
url = attribution["seeMoreUrl"]
|
||||
reference += f"{i+1}、[{display_name}]({url})\n\n"
|
||||
|
||||
if len(reference) > 0:
|
||||
reference = "***\n"+reference
|
||||
|
||||
suggestion = ""
|
||||
if "suggestedResponses" in reply:
|
||||
suggestion_dict = dict()
|
||||
for i, attribution in enumerate(reply["suggestedResponses"]):
|
||||
suggestion_dict[i] = attribution["text"]
|
||||
suggestion += f">{i+1}、{attribution['text']}\n\n"
|
||||
suggestion_session[context['from_user_id']
|
||||
] = suggestion_dict
|
||||
|
||||
if len(suggestion) > 0:
|
||||
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
|
||||
|
||||
throttling = answer["item"]["throttling"]
|
||||
throttling_str = ""
|
||||
|
||||
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
|
||||
user_session.get(context['from_user_id'], None).reset()
|
||||
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
|
||||
else:
|
||||
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
|
||||
|
||||
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
|
||||
log.info("[NewBing] reply={}", response)
|
||||
return response
|
||||
else:
|
||||
user_session.get(context['from_user_id'], None).reset()
|
||||
log.warn("[NewBing] reply={}", answer)
|
||||
return "对话被接口拒绝,已开启新的一轮对话。"
|
||||
|
||||
@@ -83,20 +83,32 @@ class ChatGPTModel(Model):
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def reply_text_stream(self, query, new_query, user_id, retry_count=0):
|
||||
async def reply_text_stream(self, query, context, retry_count=0):
|
||||
try:
|
||||
res = openai.Completion.create(
|
||||
model="text-davinci-003", # 对话模型的名称
|
||||
prompt=new_query,
|
||||
user_id=context['from_user_id']
|
||||
new_query = Session.build_session_query(query, user_id)
|
||||
res = openai.ChatCompletion.create(
|
||||
model= model_conf(const.OPEN_AI).get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
messages=new_query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
#max_tokens=4096, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"],
|
||||
stream=True
|
||||
)
|
||||
return self._process_reply_stream(query, res, user_id)
|
||||
full_response = ""
|
||||
for chunk in res:
|
||||
log.debug(chunk)
|
||||
if (chunk["choices"][0]["finish_reason"]=="stop"):
|
||||
break
|
||||
chunk_message = chunk['choices'][0]['delta'].get("content")
|
||||
if(chunk_message):
|
||||
full_response+=chunk_message
|
||||
yield False,full_response
|
||||
Session.save_session(query, full_response, user_id)
|
||||
log.info("[chatgpt]: reply={}", full_response)
|
||||
yield True,full_response
|
||||
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
@@ -104,45 +116,22 @@ class ChatGPTModel(Model):
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
log.warn("[CHATGPT] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text_stream(query, user_id, retry_count+1)
|
||||
yield True, self.reply_text_stream(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
yield True, "提问太快啦,请休息一下再问我吧"
|
||||
except openai.error.APIConnectionError as e:
|
||||
log.warn(e)
|
||||
log.warn("[CHATGPT] APIConnection failed")
|
||||
return "我连接不到网络,请稍后重试"
|
||||
yield True, "我连接不到网络,请稍后重试"
|
||||
except openai.error.Timeout as e:
|
||||
log.warn(e)
|
||||
log.warn("[CHATGPT] Timeout")
|
||||
return "我没有收到消息,请稍后重试"
|
||||
yield True, "我没有收到消息,请稍后重试"
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
log.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def _process_reply_stream(
|
||||
self,
|
||||
query: str,
|
||||
reply: dict,
|
||||
user_id: str
|
||||
) -> str:
|
||||
full_response = ""
|
||||
for response in reply:
|
||||
if response.get("choices") is None or len(response["choices"]) == 0:
|
||||
raise Exception("OpenAI API returned no choices")
|
||||
if response["choices"][0].get("finish_details") is not None:
|
||||
break
|
||||
if response["choices"][0].get("text") is None:
|
||||
raise Exception("OpenAI API returned no text")
|
||||
if response["choices"][0]["text"] == "<|endoftext|>":
|
||||
break
|
||||
yield response["choices"][0]["text"]
|
||||
full_response += response["choices"][0]["text"]
|
||||
if query and full_response:
|
||||
Session.save_session(query, full_response, user_id)
|
||||
|
||||
yield True, "请再问我一次吧"
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
|
||||
@@ -13,7 +13,9 @@ user_session = dict()
|
||||
class OpenAIModel(Model):
|
||||
def __init__(self):
|
||||
openai.api_key = model_conf(const.OPEN_AI).get('api_key')
|
||||
|
||||
proxy = model_conf(const.OPEN_AI).get('proxy')
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
@@ -72,36 +74,55 @@ class OpenAIModel(Model):
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def reply_text_stream(self, query, new_query, user_id, retry_count=0):
|
||||
async def reply_text_stream(self, query, context, retry_count=0):
|
||||
try:
|
||||
user_id=context['from_user_id']
|
||||
new_query = Session.build_session_query(query, user_id)
|
||||
res = openai.Completion.create(
|
||||
model="text-davinci-003", # 对话模型的名称
|
||||
model= "text-davinci-003", # 对话模型的名称
|
||||
prompt=new_query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
#max_tokens=4096, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"],
|
||||
stream=True
|
||||
)
|
||||
return self._process_reply_stream(query, res, user_id)
|
||||
full_response = ""
|
||||
for chunk in res:
|
||||
log.debug(chunk)
|
||||
if (chunk["choices"][0]["finish_reason"]=="stop"):
|
||||
break
|
||||
chunk_message = chunk['choices'][0].get("text")
|
||||
if(chunk_message):
|
||||
full_response+=chunk_message
|
||||
yield False,full_response
|
||||
Session.save_session(query, full_response, user_id)
|
||||
log.info("[chatgpt]: reply={}", full_response)
|
||||
yield True,full_response
|
||||
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
log.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
log.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
log.warn("[CHATGPT] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
yield True, self.reply_text_stream(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
yield True, "提问太快啦,请休息一下再问我吧"
|
||||
except openai.error.APIConnectionError as e:
|
||||
log.warn(e)
|
||||
log.warn("[CHATGPT] APIConnection failed")
|
||||
yield True, "我连接不到网络,请稍后重试"
|
||||
except openai.error.Timeout as e:
|
||||
log.warn(e)
|
||||
log.warn("[CHATGPT] Timeout")
|
||||
yield True, "我没有收到消息,请稍后重试"
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
log.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
|
||||
yield True, "请再问我一次吧"
|
||||
|
||||
def _process_reply_stream(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user