diff --git a/app.py b/app.py index 221a83e..072fc7c 100644 --- a/app.py +++ b/app.py @@ -8,8 +8,9 @@ from multiprocessing import Pool # 启动通道 def start_process(channel_type): + # 若为多进程启动,子进程无法直接访问主进程的内存空间,重新创建config类 + config.load_config() model_type = config.conf().get("model").get("type") - # load config log.info("[INIT] Start up: {} on {}", model_type, channel_type) # create channel @@ -18,7 +19,6 @@ def start_process(channel_type): # startup channel channel.startup() - if __name__ == '__main__': try: # load config diff --git a/channel/http/http_channel.py b/channel/http/http_channel.py index edb5a84..2464f61 100644 --- a/channel/http/http_channel.py +++ b/channel/http/http_channel.py @@ -5,8 +5,11 @@ from channel.http import auth from flask import Flask, request, render_template, make_response from datetime import timedelta from common import const +from common import functions from config import channel_conf +from config import channel_conf_val from channel.channel import Channel + http_app = Flask(__name__,) # 自动重载模板文件 http_app.jinja_env.auto_reload = True @@ -38,9 +41,9 @@ def index(): @http_app.route("/login", methods=['POST', 'GET']) def login(): - response = make_response("",301) - response.headers.add_header('content-type','text/plain') - response.headers.add_header('location','./') + response = make_response("", 301) + response.headers.add_header('content-type', 'text/plain') + response.headers.add_header('location', './') if (auth.identify(request) == True): return response else: @@ -51,16 +54,29 @@ def login(): return response else: return render_template('login.html') - response.headers.set('location','./login?err=登录失败') + response.headers.set('location', './login?err=登录失败') return response + class HttpChannel(Channel): def startup(self): http_app.run(host='0.0.0.0', port=channel_conf(const.HTTP).get('port')) def handle(self, data): context = dict() + img_match_prefix = functions.check_prefix( + data["msg"], channel_conf_val(const.HTTP, 'image_create_prefix')) + if img_match_prefix: + data["msg"] = data["msg"].split(img_match_prefix, 1)[1].strip() + context['type'] = 'IMAGE_CREATE' id = data["id"] context['from_user_id'] = str(id) - return super().build_reply_content(data["msg"], context) - + reply = super().build_reply_content(data["msg"], context) + if img_match_prefix: + if not isinstance(reply, list): + return reply + images = "" + for url in reply: + images += f"[!['IMAGE_CREATE']({url})]({url})\n" + reply = images + return reply diff --git a/channel/telegram/telegram_channel.py b/channel/telegram/telegram_channel.py index 795a409..b7863d6 100644 --- a/channel/telegram/telegram_channel.py +++ b/channel/telegram/telegram_channel.py @@ -49,20 +49,23 @@ class TelegramChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(msg.text, context) - if not img_url: + img_urls = super().build_reply_content(msg.text, context) + if not img_urls: return - + if not isinstance(img_urls, list): + bot.reply_to(msg,img_urls) + return + for url in img_urls: # 图片下载 - pic_res = requests.get(img_url, stream=True) - image_storage = io.BytesIO() - for block in pic_res.iter_content(1024): - image_storage.write(block) - image_storage.seek(0) + pic_res = requests.get(url, stream=True) + image_storage = io.BytesIO() + for block in pic_res.iter_content(1024): + image_storage.write(block) + image_storage.seek(0) - # 图片发送 - logger.info('[Telegrame] sendImage, receiver={}'.format(reply_user_id)) - bot.send_photo(msg.chat.id,image_storage) + # 图片发送 + logger.info('[Telegrame] sendImage, receiver={}'.format(reply_user_id)) + bot.send_photo(msg.chat.id,image_storage) except Exception as e: logger.exception(e) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index b5b5023..4277aff 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -168,20 +168,23 @@ class WechatChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) - if not img_url: + img_urls = super().build_reply_content(query, context) + if not img_urls: return - + if not isinstance(img_urls, list): + self.send(channel_conf_val(const.WECHAT, "single_chat_reply_prefix") + img_urls, reply_user_id) + return + for url in img_urls: # 图片下载 - pic_res = requests.get(img_url, stream=True) - image_storage = io.BytesIO() - for block in pic_res.iter_content(1024): - image_storage.write(block) - image_storage.seek(0) + pic_res = requests.get(url, stream=True) + image_storage = io.BytesIO() + for block in pic_res.iter_content(1024): + image_storage.write(block) + image_storage.seek(0) - # 图片发送 - logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) - itchat.send_image(image_storage, reply_user_id) + # 图片发送 + logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) + itchat.send_image(image_storage, reply_user_id) except Exception as e: logger.exception(e) diff --git a/common/functions.py b/common/functions.py new file mode 100644 index 0000000..5560890 --- /dev/null +++ b/common/functions.py @@ -0,0 +1,17 @@ +import re + + +def contain_chinese(str): + """ + 判断一个字符串中是否含有中文 + """ + pattern = re.compile('[\u4e00-\u9fa5]') + match = pattern.search(str) + return match != None + + +def check_prefix(content, prefix_list): + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None \ No newline at end of file diff --git a/config-template.json b/config-template.json index 24c2a08..481b905 100644 --- a/config-template.json +++ b/config-template.json @@ -52,6 +52,7 @@ }, "http": { + "image_create_prefix": ["画", "draw", "Draw"], "http_auth_secret_key": "6d25a684-9558-11e9-aa94-efccd7a0659b", "http_auth_password": "6.67428e-11", "port": "80" diff --git a/model/bing/new_bing_model.py b/model/bing/new_bing_model.py index 975a8d9..c062158 100644 --- a/model/bing/new_bing_model.py +++ b/model/bing/new_bing_model.py @@ -4,6 +4,8 @@ from model.model import Model from config import model_conf_val from common import log from EdgeGPT import Chatbot, ConversationStyle +from ImageGen import ImageGen +from common import functions user_session = dict() suggestion_session = dict() @@ -14,71 +16,91 @@ class BingModel(Model): style = ConversationStyle.creative bot: Chatbot = None + cookies: list = None def __init__(self): try: - self.bot = Chatbot(cookies=model_conf_val("bing", "cookies")) + self.cookies = model_conf_val("bing", "cookies") + self.bot = Chatbot(cookies=self.cookies) except Exception as e: log.exception(e) def reply(self, query: str, context=None) -> tuple[str, dict]: - bot = user_session.get(context['from_user_id'], None) - if (bot == None): - bot = self.bot - else: - if (len(query) == 1 and query.isdigit() and query != "0"): - suggestion_dict = suggestion_session[context['from_user_id']] - if (suggestion_dict != None): - query = suggestion_dict[int(query)-1] - if (query == None): - return "输入的序号不在建议列表范围中" - else: - query = "在上面的基础上,"+query - log.info("[NewBing] query={}".format(query)) - task = bot.ask(query, conversation_style=self.style) - answer = asyncio.run(task) - - # 最新一条回复 - reply = answer["item"]["messages"][-1] - reply_text = reply["text"] - reference = "" - if "sourceAttributions" in reply: - for i, attribution in enumerate(reply["sourceAttributions"]): - display_name = attribution["providerDisplayName"] - url = attribution["seeMoreUrl"] - reference += f"{i+1}、[{display_name}]({url})\n\n" - - if len(reference) > 0: - reference = "***\n"+reference - - suggestion = "" - if "suggestedResponses" in reply: - suggestion_dict = dict() - for i, attribution in enumerate(reply["suggestedResponses"]): - suggestion_dict[i] = attribution["text"] - suggestion += f">{i+1}、{attribution['text']}\n\n" - suggestion_session[context['from_user_id']] = suggestion_dict - - if len(suggestion) > 0: - suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion - - throttling = answer["item"]["throttling"] - throttling_str = "" - - if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]: - self.reset_chat(context['from_user_id']) - throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)" + if not context or not context.get('type') or context.get('type') == 'TEXT': + bot = user_session.get(context['from_user_id'], None) + if (bot == None): + bot = self.bot else: - throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n" + if (len(query) == 1 and query.isdigit() and query != "0"): + suggestion_dict = suggestion_session[context['from_user_id']] + if (suggestion_dict != None): + query = suggestion_dict[int(query)-1] + if (query == None): + return "输入的序号不在建议列表范围中" + else: + query = "在上面的基础上,"+query + log.info("[NewBing] query={}".format(query)) + task = bot.ask(query, conversation_style=self.style) + answer = asyncio.run(task) - response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}" - log.info("[NewBing] reply={}", response) - user_session[context['from_user_id']] = bot - return response - else: - self.reset_chat(context['from_user_id']) - log.warn("[NewBing] reply={}", answer) - return "对话被接口拒绝,已开启新的一轮对话。" + # 最新一条回复 + reply = answer["item"]["messages"][-1] + reply_text = reply["text"] + reference = "" + if "sourceAttributions" in reply: + for i, attribution in enumerate(reply["sourceAttributions"]): + display_name = attribution["providerDisplayName"] + url = attribution["seeMoreUrl"] + reference += f"{i+1}、[{display_name}]({url})\n\n" + + if len(reference) > 0: + reference = "***\n"+reference + + suggestion = "" + if "suggestedResponses" in reply: + suggestion_dict = dict() + for i, attribution in enumerate(reply["suggestedResponses"]): + suggestion_dict[i] = attribution["text"] + suggestion += f">{i+1}、{attribution['text']}\n\n" + suggestion_session[context['from_user_id'] + ] = suggestion_dict + + if len(suggestion) > 0: + suggestion = "***\n你可以通过输入序号快速追问我以下建议问题:\n\n"+suggestion + + throttling = answer["item"]["throttling"] + throttling_str = "" + + if throttling["numUserMessagesInConversation"] == throttling["maxNumUserMessagesInConversation"]: + self.reset_chat(context['from_user_id']) + throttling_str = "(对话轮次已达上限,本次聊天已结束,将开启新的对话)" + else: + throttling_str = f"对话轮次: {throttling['numUserMessagesInConversation']}/{throttling['maxNumUserMessagesInConversation']}\n" + + response = f"{reply_text}\n{reference}\n{suggestion}\n***\n{throttling_str}" + log.info("[NewBing] reply={}", response) + user_session[context['from_user_id']] = bot + return response + else: + self.reset_chat(context['from_user_id']) + log.warn("[NewBing] reply={}", answer) + return "对话被接口拒绝,已开启新的一轮对话。" + elif context.get('type', None) == 'IMAGE_CREATE': + if functions.contain_chinese(query): + return "ImageGen目前仅支持使用英文关键词生成图片" + return self.create_img(query) + + def create_img(self, query): + try: + log.info("[NewBing] image_query={}".format(query)) + cookie_value = self.cookies[0]["value"] + image_generator = ImageGen(cookie_value) + img_list = image_generator.get_images(query) + log.info("[NewBing] image_list={}".format(img_list)) + return img_list + except Exception as e: + log.exception(e) + return None def reset_chat(self, from_user_id): asyncio.run(user_session.get(from_user_id, None).reset()) diff --git a/model/openai/chatgpt_model.py b/model/openai/chatgpt_model.py index 609e9f3..2bae160 100644 --- a/model/openai/chatgpt_model.py +++ b/model/openai/chatgpt_model.py @@ -153,7 +153,7 @@ class ChatGPTModel(Model): ) image_url = response['data'][0]['url'] log.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url + return [image_url] except openai.error.RateLimitError as e: log.warn(e) if retry_count < 1: diff --git a/model/openai/open_ai_model.py b/model/openai/open_ai_model.py index 36dc3c4..e9819dc 100644 --- a/model/openai/open_ai_model.py +++ b/model/openai/open_ai_model.py @@ -134,7 +134,7 @@ class OpenAIModel(Model): ) image_url = response['data'][0]['url'] log.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url + return [image_url] except openai.error.RateLimitError as e: log.warn(e) if retry_count < 1: