mirror of
https://github.com/zhayujie/bot-on-anything.git
synced 2026-02-08 17:52:02 +08:00
4
app.py
4
app.py
@@ -8,8 +8,9 @@ from multiprocessing import Pool
|
||||
|
||||
# 启动通道
|
||||
def start_process(channel_type):
|
||||
# 若为多进程启动,子进程无法直接访问主进程的内存空间,重新创建config类
|
||||
config.load_config()
|
||||
model_type = config.conf().get("model").get("type")
|
||||
# load config
|
||||
log.info("[INIT] Start up: {} on {}", model_type, channel_type)
|
||||
|
||||
# create channel
|
||||
@@ -18,7 +19,6 @@ def start_process(channel_type):
|
||||
# startup channel
|
||||
channel.startup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
# load config
|
||||
|
||||
@@ -5,8 +5,11 @@ from channel.http import auth
|
||||
from flask import Flask, request, render_template, make_response
|
||||
from datetime import timedelta
|
||||
from common import const
|
||||
from common import functions
|
||||
from config import channel_conf
|
||||
from config import channel_conf_val
|
||||
from channel.channel import Channel
|
||||
|
||||
http_app = Flask(__name__,)
|
||||
# 自动重载模板文件
|
||||
http_app.jinja_env.auto_reload = True
|
||||
@@ -38,9 +41,9 @@ def index():
|
||||
|
||||
@http_app.route("/login", methods=['POST', 'GET'])
|
||||
def login():
|
||||
response = make_response("<html></html>",301)
|
||||
response.headers.add_header('content-type','text/plain')
|
||||
response.headers.add_header('location','./')
|
||||
response = make_response("<html></html>", 301)
|
||||
response.headers.add_header('content-type', 'text/plain')
|
||||
response.headers.add_header('location', './')
|
||||
if (auth.identify(request) == True):
|
||||
return response
|
||||
else:
|
||||
@@ -51,16 +54,29 @@ def login():
|
||||
return response
|
||||
else:
|
||||
return render_template('login.html')
|
||||
response.headers.set('location','./login?err=登录失败')
|
||||
response.headers.set('location', './login?err=登录失败')
|
||||
return response
|
||||
|
||||
|
||||
class HttpChannel(Channel):
|
||||
def startup(self):
|
||||
http_app.run(host='0.0.0.0', port=channel_conf(const.HTTP).get('port'))
|
||||
|
||||
def handle(self, data):
|
||||
context = dict()
|
||||
img_match_prefix = functions.check_prefix(
|
||||
data["msg"], channel_conf_val(const.HTTP, 'image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
data["msg"] = data["msg"].split(img_match_prefix, 1)[1].strip()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
id = data["id"]
|
||||
context['from_user_id'] = str(id)
|
||||
return super().build_reply_content(data["msg"], context)
|
||||
|
||||
reply = super().build_reply_content(data["msg"], context)
|
||||
if img_match_prefix:
|
||||
if not isinstance(reply, list):
|
||||
return reply
|
||||
images = ""
|
||||
for url in reply:
|
||||
images += f"[]({url})\n"
|
||||
reply = images
|
||||
return reply
|
||||
|
||||
@@ -49,20 +49,23 @@ class TelegramChannel(Channel):
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(msg.text, context)
|
||||
if not img_url:
|
||||
img_urls = super().build_reply_content(msg.text, context)
|
||||
if not img_urls:
|
||||
return
|
||||
|
||||
if not isinstance(img_urls, list):
|
||||
bot.reply_to(msg,img_urls)
|
||||
return
|
||||
for url in img_urls:
|
||||
# 图片下载
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[Telegrame] sendImage, receiver={}'.format(reply_user_id))
|
||||
bot.send_photo(msg.chat.id,image_storage)
|
||||
# 图片发送
|
||||
logger.info('[Telegrame] sendImage, receiver={}'.format(reply_user_id))
|
||||
bot.send_photo(msg.chat.id,image_storage)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
@@ -168,20 +168,23 @@ class WechatChannel(Channel):
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(query, context)
|
||||
if not img_url:
|
||||
img_urls = super().build_reply_content(query, context)
|
||||
if not img_urls:
|
||||
return
|
||||
|
||||
if not isinstance(img_urls, list):
|
||||
self.send(channel_conf_val(const.WECHAT, "single_chat_reply_prefix") + img_urls, reply_user_id)
|
||||
return
|
||||
for url in img_urls:
|
||||
# 图片下载
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
itchat.send_image(image_storage, reply_user_id)
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
itchat.send_image(image_storage, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
17
common/functions.py
Normal file
17
common/functions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import re
|
||||
|
||||
|
||||
def contain_chinese(str):
|
||||
"""
|
||||
判断一个字符串中是否含有中文
|
||||
"""
|
||||
pattern = re.compile('[\u4e00-\u9fa5]')
|
||||
match = pattern.search(str)
|
||||
return match != None
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
@@ -52,6 +52,7 @@
|
||||
},
|
||||
|
||||
"http": {
|
||||
"image_create_prefix": ["画", "draw", "Draw"],
|
||||
"http_auth_secret_key": "6d25a684-9558-11e9-aa94-efccd7a0659b",
|
||||
"http_auth_password": "6.67428e-11",
|
||||
"port": "80"
|
||||
|
||||
@@ -4,6 +4,8 @@ from model.model import Model
|
||||
from config import model_conf_val
|
||||
from common import log
|
||||
from EdgeGPT import Chatbot, ConversationStyle
|
||||
from ImageGen import ImageGen
|
||||
from common import functions
|
||||
|
||||
user_session = dict()
|
||||
suggestion_session = dict()
|
||||
@@ -14,71 +16,91 @@ class BingModel(Model):
|
||||
|
||||
style = ConversationStyle.creative
|
||||
bot: Chatbot = None
|
||||
cookies: list = None
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.bot = Chatbot(cookies=model_conf_val("bing", "cookies"))
|
||||
self.cookies = model_conf_val("bing", "cookies")
|
||||
self.bot = Chatbot(cookies=self.cookies)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
def reply(self, query: str, context=None) -> tuple[str, dict]:
|
||||
bot = user_session.get(context['from_user_id'], None)
|
||||
if (bot == None):
|
||||
bot = self.bot
|
||||
else:
|
||||
if (len(query) == 1 and query.isdigit() and query != "0"):
|
||||
suggestion_dict = suggestion_session[context['from_user_id']]
|
||||
if (suggestion_dict != None):
|
||||
query = suggestion_dict[int(query)-1]
|
||||
if (query == None):
|
||||
return "输入的序号不在建议列表范围中"
|
||||
else:
|
||||
query = "在上面的基础上,"+query
|
||||
log.info("[NewBing] query={}".format(query))
|
||||
task = bot.ask(query, conversation_style=self.style)
|
||||
answer = asyncio.run(task)
|
||||
|
||||
# 最新一条回复
|
||||
reply = answer["item"]["messages"][-1]
|
||||
reply_text = reply["text"]
|
||||
reference = ""
|
||||
if "sourceAttributions" in reply:
|
||||
for i, attribution in enumerate(reply["sourceAttributions"]):
|
||||
display_name = attribution["providerDisplayName"]
|
||||
url = attribution["seeMoreUrl"]
|
||||
reference += f"{i+1}、[{display_name}]({url})\n\n"
|
||||
|
||||
if len(reference) > 0:
|
||||
reference = "***\n"+reference
|
||||
|
||||
suggestion = ""
|
||||
if "suggestedResponses" in reply:
|
||||
suggestion_dict = dict()
|
||||
for i, attribution in enumerate(reply["suggestedResponses"]):
|
||||
suggestion_dict[i] = attribution["text"]
|
||||
suggestion += f">{i+1}、{attribution['text']}\n\n"
|
||||
suggestion_session[context['from_user_id']] = suggestion_dict
|
||||
|
||||
if len(suggestion) > 0:
|
||||
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
|
||||
|
||||
throttling = answer["item"]["throttling"]
|
||||
throttling_str = ""
|
||||
|
||||
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
bot = user_session.get(context['from_user_id'], None)
|
||||
if (bot == None):
|
||||
bot = self.bot
|
||||
else:
|
||||
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
|
||||
if (len(query) == 1 and query.isdigit() and query != "0"):
|
||||
suggestion_dict = suggestion_session[context['from_user_id']]
|
||||
if (suggestion_dict != None):
|
||||
query = suggestion_dict[int(query)-1]
|
||||
if (query == None):
|
||||
return "输入的序号不在建议列表范围中"
|
||||
else:
|
||||
query = "在上面的基础上,"+query
|
||||
log.info("[NewBing] query={}".format(query))
|
||||
task = bot.ask(query, conversation_style=self.style)
|
||||
answer = asyncio.run(task)
|
||||
|
||||
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
|
||||
log.info("[NewBing] reply={}", response)
|
||||
user_session[context['from_user_id']] = bot
|
||||
return response
|
||||
else:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
log.warn("[NewBing] reply={}", answer)
|
||||
return "对话被接口拒绝,已开启新的一轮对话。"
|
||||
# 最新一条回复
|
||||
reply = answer["item"]["messages"][-1]
|
||||
reply_text = reply["text"]
|
||||
reference = ""
|
||||
if "sourceAttributions" in reply:
|
||||
for i, attribution in enumerate(reply["sourceAttributions"]):
|
||||
display_name = attribution["providerDisplayName"]
|
||||
url = attribution["seeMoreUrl"]
|
||||
reference += f"{i+1}、[{display_name}]({url})\n\n"
|
||||
|
||||
if len(reference) > 0:
|
||||
reference = "***\n"+reference
|
||||
|
||||
suggestion = ""
|
||||
if "suggestedResponses" in reply:
|
||||
suggestion_dict = dict()
|
||||
for i, attribution in enumerate(reply["suggestedResponses"]):
|
||||
suggestion_dict[i] = attribution["text"]
|
||||
suggestion += f">{i+1}、{attribution['text']}\n\n"
|
||||
suggestion_session[context['from_user_id']
|
||||
] = suggestion_dict
|
||||
|
||||
if len(suggestion) > 0:
|
||||
suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion
|
||||
|
||||
throttling = answer["item"]["throttling"]
|
||||
throttling_str = ""
|
||||
|
||||
if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)"
|
||||
else:
|
||||
throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n"
|
||||
|
||||
response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}"
|
||||
log.info("[NewBing] reply={}", response)
|
||||
user_session[context['from_user_id']] = bot
|
||||
return response
|
||||
else:
|
||||
self.reset_chat(context['from_user_id'])
|
||||
log.warn("[NewBing] reply={}", answer)
|
||||
return "对话被接口拒绝,已开启新的一轮对话。"
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
if functions.contain_chinese(query):
|
||||
return "ImageGen目前仅支持使用英文关键词生成图片"
|
||||
return self.create_img(query)
|
||||
|
||||
def create_img(self, query):
|
||||
try:
|
||||
log.info("[NewBing] image_query={}".format(query))
|
||||
cookie_value = self.cookies[0]["value"]
|
||||
image_generator = ImageGen(cookie_value)
|
||||
img_list = image_generator.get_images(query)
|
||||
log.info("[NewBing] image_list={}".format(img_list))
|
||||
return img_list
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def reset_chat(self, from_user_id):
|
||||
asyncio.run(user_session.get(from_user_id, None).reset())
|
||||
|
||||
@@ -153,7 +153,7 @@ class ChatGPTModel(Model):
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
log.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
return [image_url]
|
||||
except openai.error.RateLimitError as e:
|
||||
log.warn(e)
|
||||
if retry_count < 1:
|
||||
|
||||
@@ -134,7 +134,7 @@ class OpenAIModel(Model):
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
log.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
return [image_url]
|
||||
except openai.error.RateLimitError as e:
|
||||
log.warn(e)
|
||||
if retry_count < 1:
|
||||
|
||||
Reference in New Issue
Block a user