mirror of
https://github.com/zhayujie/bot-on-anything.git
synced 2026-01-19 01:21:06 +08:00
init: add wechat and wechat_mp channel
This commit is contained in:
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
config.json
|
||||
QR.png
|
||||
19
LICENSE
Normal file
19
LICENSE
Normal file
@@ -0,0 +1,19 @@
|
||||
Copyright (c) 2023 zhayujie
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
62
README.md
Normal file
62
README.md
Normal file
@@ -0,0 +1,62 @@
|
||||
# 简介
|
||||
|
||||
将 **AI模型** 接入各类 **消息应用**,开发者通过轻量配置即可在二者之间选择一条连线,运行起一个智能对话机器人,在一个项目中轻松完成多条链路的切换。该架构扩展性强,每接入一个应用可复用已有的算法能力,同样每接入一个模型也可作用于所有应用之上。
|
||||
|
||||
**模型:**
|
||||
|
||||
- [x] ChatGPT
|
||||
- ...
|
||||
|
||||
**应用:**
|
||||
|
||||
- [ ] 终端
|
||||
- [ ] Web
|
||||
- [x] 个人微信
|
||||
- [x] 公众号
|
||||
- [ ] 企业微信
|
||||
- [ ] Telegram
|
||||
- [ ] QQ
|
||||
- [ ] 钉钉
|
||||
- ...
|
||||
|
||||
|
||||
|
||||
# 快速开始
|
||||
|
||||
## 一、准备
|
||||
|
||||
### 1.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(Linux服务器上可长期运行)。同时需安装 Python,建议Python版本在 3.7.1~3.10 之间。
|
||||
|
||||
项目代码克隆:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/bot-on-anything
|
||||
cd bot-on-anything/
|
||||
```
|
||||
> 或在 Realase 直接手动下载源码。
|
||||
|
||||
### 2.配置说明
|
||||
|
||||
核心配置文件为 `config.json`,
|
||||
|
||||
|
||||
|
||||
## 二、选择模型
|
||||
|
||||
### 1.ChatGPT
|
||||
|
||||
|
||||
## 三、选择应用
|
||||
|
||||
### 1.微信
|
||||
|
||||
|
||||
### 2.公众号
|
||||
|
||||
|
||||
|
||||
|
||||
## 四、运行
|
||||
|
||||
21
app.py
Normal file
21
app.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
# load config
|
||||
config.load_config()
|
||||
logger.info("[INIT] load config: {}".format(config.conf()))
|
||||
|
||||
# create channel
|
||||
channel = channel_factory.create_channel(config.conf().get("channel"))
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
9
bridge/bridge.py
Normal file
9
bridge/bridge.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from model import model_factory
|
||||
import config
|
||||
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def fetch_reply_content(self, query, context):
|
||||
return model_factory.create_bot(config.conf().get("model")).reply(query, context)
|
||||
31
channel/channel.py
Normal file
31
channel/channel.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Message sending channel abstract class
|
||||
"""
|
||||
|
||||
from bridge.bridge import Bridge
|
||||
|
||||
class Channel(object):
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def handle(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
:param msg: message object
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send(self, msg, receiver):
|
||||
"""
|
||||
send message to user
|
||||
:param msg: message content
|
||||
:param receiver: receiver channel account
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context=None):
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
21
channel/channel_factory.py
Normal file
21
channel/channel_factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
|
||||
def create_channel(channel_type):
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
if channel_type == const.WECHAT:
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
return WechatChannel()
|
||||
|
||||
elif channel_type == const.WECHAT_MP:
|
||||
from channel.wechat.wechat_mp_channel import WechatPublicAccount
|
||||
return WechatPublicAccount()
|
||||
|
||||
else:
|
||||
raise RuntimeError
|
||||
0
channel/terminal/terminal_channel.py
Normal file
0
channel/terminal/terminal_channel.py
Normal file
165
channel/wechat/wechat_channel.py
Normal file
165
channel/wechat/wechat_channel.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel
|
||||
"""
|
||||
import itchat
|
||||
import json
|
||||
from itchat.content import *
|
||||
from channel.channel import Channel
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
import requests
|
||||
import io
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
def handler_single_msg(msg):
|
||||
WechatChannel().handle(msg)
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT, isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
WechatChannel().handle_group(msg)
|
||||
return None
|
||||
|
||||
|
||||
class WechatChannel(Channel):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def startup(self):
|
||||
# login by scan QRCode
|
||||
itchat.auto_login(enableCmdQR=2)
|
||||
|
||||
# start message listener
|
||||
itchat.run()
|
||||
|
||||
def handle(self, msg):
|
||||
logger.debug("[WX]receive msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
from_user_id = msg['FromUserName']
|
||||
to_user_id = msg['ToUserName'] # 接收人id
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
content = msg['Text']
|
||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if from_user_id == other_user_id and match_prefix is not None:
|
||||
# 好友向自己发送消息
|
||||
if match_prefix != '':
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, from_user_id)
|
||||
else:
|
||||
thread_pool.submit(self._do_send, content, from_user_id)
|
||||
|
||||
elif to_user_id == other_user_id and match_prefix:
|
||||
# 自己给好友发送消息
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, to_user_id)
|
||||
else:
|
||||
thread_pool.submit(self._do_send, content, to_user_id)
|
||||
|
||||
|
||||
def handle_group(self, msg):
|
||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
group_name = msg['User'].get('NickName', None)
|
||||
group_id = msg['User'].get('UserName', None)
|
||||
if not group_name:
|
||||
return ""
|
||||
origin_content = msg['Content']
|
||||
content = msg['Content']
|
||||
content_list = content.split(' ', 1)
|
||||
context_special_list = content.split('\u2005', 1)
|
||||
if len(context_special_list) == 2:
|
||||
content = context_special_list[1]
|
||||
elif len(content_list) == 2:
|
||||
content = content_list[1]
|
||||
|
||||
config = conf()
|
||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(origin_content, config.get('group_chat_keyword'))
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, group_id)
|
||||
else:
|
||||
thread_pool.submit(self._do_send_group, content, msg)
|
||||
|
||||
def send(self, msg, receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
|
||||
itchat.send(msg, toUserName=receiver)
|
||||
|
||||
def _do_send(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['from_user_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
if reply_text:
|
||||
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _do_send_img(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(query, context)
|
||||
if not img_url:
|
||||
return
|
||||
|
||||
# 图片下载
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
itchat.send_image(image_storage, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _do_send_group(self, query, msg):
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['from_user_id'] = msg['ActualUserName']
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
if reply_text:
|
||||
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
|
||||
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])
|
||||
|
||||
|
||||
def check_prefix(self, content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(self, content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
58
channel/wechat/wechat_mp_channel.py
Normal file
58
channel/wechat/wechat_mp_channel.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import werobot
|
||||
import time
|
||||
import config
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from channel.channel import Channel
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
robot = werobot.WeRoBot(token=config.fetch(const.WECHAT_MP).get('token'))
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
cache = {}
|
||||
|
||||
@robot.text
|
||||
def hello_world(msg):
|
||||
logger.info('[WX_Public] receive public msg: {}, userId: {}'.format(msg.content, msg.source))
|
||||
key = msg.content + '|' + msg.source
|
||||
if cache.get(key):
|
||||
cache.get(key)['req_times'] += 1
|
||||
return WechatPublicAccount().handle(msg)
|
||||
|
||||
|
||||
class WechatPublicAccount(Channel):
|
||||
def startup(self):
|
||||
logger.info('[WX_Public] Wechat Public account service start!')
|
||||
robot.config['PORT'] = config.fetch(const.WECHAT_MP).get('port')
|
||||
robot.run()
|
||||
|
||||
def handle(self, msg, count=0):
|
||||
context = dict()
|
||||
context['from_user_id'] = msg.source
|
||||
key = msg.content + '|' + msg.source
|
||||
res = cache.get(key)
|
||||
if not res:
|
||||
thread_pool.submit(self._do_send, msg.content, context)
|
||||
temp = {'flag': True, 'req_times': 1}
|
||||
cache[key] = temp
|
||||
if count < 10:
|
||||
time.sleep(2)
|
||||
return self.handle(msg, count+1)
|
||||
|
||||
elif res.get('flag', False) and res.get('data', None):
|
||||
cache.pop(key)
|
||||
return res['data']
|
||||
|
||||
elif res.get('flag', False) and not res.get('data', None):
|
||||
if res.get('req_times') == 3 and count == 9:
|
||||
return '不好意思我的CPU烧了,请再问我一次吧~'
|
||||
if count < 10:
|
||||
time.sleep(0.5)
|
||||
return self.handle(msg, count+1)
|
||||
return "请再说一次"
|
||||
|
||||
|
||||
def _do_send(self, query, context):
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
logger.info('[WX_Public] reply content: {}'.format(reply_text))
|
||||
key = query + '|' + context['from_user_id']
|
||||
cache[key] = {'flag': True, 'data': reply_text}
|
||||
6
common/const.py
Normal file
6
common/const.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# channel
|
||||
WECHAT = "wechat"
|
||||
WECHAT_MP = "wechat_mp"
|
||||
|
||||
# model
|
||||
OPEN_AI = "openai"
|
||||
18
common/log.py
Normal file
18
common/log.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger('log')
|
||||
log.setLevel(logging.INFO)
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
log.addHandler(console_handle)
|
||||
return log
|
||||
|
||||
|
||||
# 日志句柄
|
||||
logger = _get_logger()
|
||||
23
config-template.json
Normal file
23
config-template.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"channel": "wechat",
|
||||
"bot": "openai",
|
||||
|
||||
"openai": {
|
||||
"api_key": "YOUR API KEY",
|
||||
"conversation_max_tokens": 1000,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
|
||||
},
|
||||
|
||||
"wechat": {
|
||||
"single_chat_prefix": ["bot", "@bot"],
|
||||
"single_chat_reply_prefix": "[bot] ",
|
||||
"group_chat_prefix": ["@bot"],
|
||||
"group_name_white_list": ["ALL_GROUP"],
|
||||
"image_create_prefix": ["画", "看", "找一张"]
|
||||
},
|
||||
|
||||
"wechat_mp": {
|
||||
"token": "YOUR TOKEN",
|
||||
"port": "8088"
|
||||
}
|
||||
}
|
||||
32
config.py
Normal file
32
config.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config_path = "config.json"
|
||||
if not os.path.exists(config_path):
|
||||
raise Exception('配置文件不存在,请根据config-template.json模板创建config.json文件')
|
||||
|
||||
config_str = read_file(config_path)
|
||||
# 将json字符串反序列化为dict类型
|
||||
config = json.loads(config_str)
|
||||
|
||||
def get_root():
|
||||
return os.path.dirname(os.path.abspath( __file__ ))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(path, mode='r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def conf():
|
||||
return config
|
||||
|
||||
def fetch(model):
|
||||
return config.get(model)
|
||||
13
model/model.py
Normal file
13
model/model.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Auto-replay chat robot abstract class
|
||||
"""
|
||||
|
||||
|
||||
class Model(object):
|
||||
def reply(self, query, context=None):
|
||||
"""
|
||||
model auto-reply content
|
||||
:param req: received message
|
||||
:return: reply content
|
||||
"""
|
||||
raise NotImplementedError
|
||||
19
model/model_factory.py
Normal file
19
model/model_factory.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
|
||||
from common import const
|
||||
|
||||
def create_bot(model_type):
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
|
||||
if model_type == const.OPEN_AI:
|
||||
# OpenAI 官方对话模型API
|
||||
from model.openai.open_ai_model import OpenAIModel
|
||||
return OpenAIModel()
|
||||
|
||||
raise RuntimeError
|
||||
BIN
model/openai/.DS_Store
vendored
Normal file
BIN
model/openai/.DS_Store
vendored
Normal file
Binary file not shown.
160
model/openai/open_ai_model.py
Normal file
160
model/openai/open_ai_model.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from model.model import Model
|
||||
from config import fetch
|
||||
from common import const
|
||||
from common.log import logger
|
||||
import openai
|
||||
import time
|
||||
|
||||
user_session = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIModel(Model):
|
||||
def __init__(self):
|
||||
openai.api_key = fetch(const.OPEN_AI).get('api_key')
|
||||
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context['from_user_id']
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
return '记忆已清除'
|
||||
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
return reply_content
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model="text-davinci-003", # 对话模型的名称
|
||||
prompt=query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return res_content
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, user_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
:param query: query content
|
||||
:param user_id: from user id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
prompt = fetch(const.OPEN_AI).get("character_desc", "")
|
||||
if prompt:
|
||||
prompt += "<|endoftext|>\n\n\n"
|
||||
session = user_session.get(user_id, None)
|
||||
if session:
|
||||
for conversation in session:
|
||||
prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
|
||||
prompt += "Q: " + query + "\nA: "
|
||||
return prompt
|
||||
else:
|
||||
return prompt + "Q: " + query + "\nA: "
|
||||
|
||||
@staticmethod
|
||||
def save_session(query, answer, user_id):
|
||||
max_tokens = fetch(const.OPEN_AI).get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
conversation = dict()
|
||||
conversation["question"] = query
|
||||
conversation["answer"] = answer
|
||||
session = user_session.get(user_id)
|
||||
logger.debug(conversation)
|
||||
logger.debug(session)
|
||||
if session:
|
||||
# append conversation
|
||||
session.append(conversation)
|
||||
else:
|
||||
# create session
|
||||
queue = list()
|
||||
queue.append(conversation)
|
||||
user_session[user_id] = queue
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens):
|
||||
count = 0
|
||||
count_list = list()
|
||||
for i in range(len(session)-1, -1, -1):
|
||||
# count tokens of conversation list
|
||||
history_conv = session[i]
|
||||
count += len(history_conv["question"]) + len(history_conv["answer"])
|
||||
count_list.append(count)
|
||||
|
||||
for c in count_list:
|
||||
if c > max_tokens:
|
||||
# pop first conversation
|
||||
session.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(user_id):
|
||||
user_session[user_id] = []
|
||||
Reference in New Issue
Block a user