init: add wechat and wechat_mp channel

This commit is contained in:
zhayujie
2023-02-13 22:43:38 +08:00
commit b13627ad46
18 changed files with 664 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
.DS_Store
.idea
__pycache__/
venv*
*.pyc
config.json
QR.png

19
LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2023 zhayujie
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

62
README.md Normal file
View File

@@ -0,0 +1,62 @@
# 简介
**AI模型** 接入各类 **消息应用**,开发者通过轻量配置即可在二者之间选择一条连线,运行起一个智能对话机器人,在一个项目中轻松完成多条链路的切换。该架构扩展性强,每接入一个应用可复用已有的算法能力,同样每接入一个模型也可作用于所有应用之上。
**模型:**
- [x] ChatGPT
- ...
**应用:**
- [ ] 终端
- [ ] Web
- [x] 个人微信
- [x] 公众号
- [ ] 企业微信
- [ ] Telegram
- [ ] QQ
- [ ] 钉钉
- ...
# 快速开始
## 一、准备
### 1.运行环境
支持 Linux、MacOS、Windows 系统(Linux服务器上可长期运行)。同时需安装 Python(建议Python版本在 3.7.1~3.10 之间)。
项目代码克隆:
```bash
git clone https://github.com/zhayujie/bot-on-anything
cd bot-on-anything/
```
> 或在 Release 直接手动下载源码。
### 2.配置说明
核心配置文件为 `config.json`
## 二、选择模型
### 1.ChatGPT
## 三、选择应用
### 1.微信
### 2.公众号
## 四、运行

21
app.py Normal file
View File

@@ -0,0 +1,21 @@
# encoding:utf-8
import config
from channel import channel_factory
from common.log import logger
if __name__ == '__main__':
    try:
        # Read config.json into the global config before anything else.
        config.load_config()
        cfg = config.conf()
        logger.info("[INIT] load config: {}".format(cfg))

        # Build the channel named in the config and hand control to it;
        # startup() blocks for the lifetime of the process.
        channel_name = cfg.get("channel")
        channel_factory.create_channel(channel_name).startup()
    except Exception as startup_err:
        logger.error("App startup failed!")
        logger.exception(startup_err)

9
bridge/bridge.py Normal file
View File

@@ -0,0 +1,9 @@
from model import model_factory
import config
class Bridge(object):
    """Thin indirection between message channels and the configured model."""

    def __init__(self):
        pass

    def fetch_reply_content(self, query, context):
        """Create the bot named by the "model" config key and ask it for a reply.

        :param query: user message text
        :param context: extra request info passed through to the model
        :return: the model's reply
        """
        bot = model_factory.create_bot(config.conf().get("model"))
        return bot.reply(query, context)

31
channel/channel.py Normal file
View File

@@ -0,0 +1,31 @@
"""
Message sending channel abstract class
"""
from bridge.bridge import Bridge
class Channel(object):
    """Abstract base class for message-sending channels."""

    def startup(self):
        """Initialize and start the channel; subclasses must override."""
        raise NotImplementedError

    def handle(self, msg):
        """Process one received message object; subclasses must override.

        :param msg: message object
        """
        raise NotImplementedError

    def send(self, msg, receiver):
        """Deliver a message to a user; subclasses must override.

        :param msg: message content
        :param receiver: receiver channel account
        :return:
        """
        raise NotImplementedError

    def build_reply_content(self, query, context=None):
        """Route *query* through the Bridge to the configured model bot."""
        return Bridge().fetch_reply_content(query, context)

View File

@@ -0,0 +1,21 @@
"""
channel factory
"""
from common import const
def create_channel(channel_type):
    """
    create a channel instance
    :param channel_type: channel type code (see common.const)
    :return: channel instance
    :raises RuntimeError: if channel_type is not a supported channel code
    """
    if channel_type == const.WECHAT:
        from channel.wechat.wechat_channel import WechatChannel
        return WechatChannel()
    elif channel_type == const.WECHAT_MP:
        from channel.wechat.wechat_mp_channel import WechatPublicAccount
        return WechatPublicAccount()
    # fail loudly with the offending value instead of a bare RuntimeError
    raise RuntimeError("unknown channel_type: {}".format(channel_type))

View File

View File

@@ -0,0 +1,165 @@
# encoding:utf-8
"""
wechat channel
"""
import itchat
import json
from itchat.content import *
from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor
from common.log import logger
from config import conf
import requests
import io
# shared worker pool: reply generation is submitted here so the itchat
# message callbacks return immediately
thread_pool = ThreadPoolExecutor(max_workers=8)
@itchat.msg_register(TEXT)
def handler_single_msg(msg):
    """itchat hook: forward a private TEXT message to the wechat channel."""
    channel = WechatChannel()
    channel.handle(msg)
    return None
@itchat.msg_register(TEXT, isGroupChat=True)
def handler_group_msg(msg):
    """itchat hook: forward a group TEXT message to the wechat channel."""
    channel = WechatChannel()
    channel.handle_group(msg)
    return None
class WechatChannel(Channel):
    """Channel implementation for personal WeChat, driven by itchat."""

    def __init__(self):
        pass

    def startup(self):
        """Log in by scanning a QR code, then block on the itchat message loop."""
        # login by scan QRCode
        itchat.auto_login(enableCmdQR=2)
        # start message listener
        itchat.run()

    def handle(self, msg):
        """Handle a private text message.

        A reply is produced only when a configured single_chat_prefix matches;
        a further image_create_prefix switches to image generation. Work is
        submitted to the thread pool so this callback returns quickly.
        """
        logger.debug("[WX]receive msg: " + json.dumps(msg, ensure_ascii=False))
        from_user_id = msg['FromUserName']
        to_user_id = msg['ToUserName']  # receiver id
        other_user_id = msg['User']['UserName']  # id of the chat peer
        content = msg['Text']
        match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
        if from_user_id == other_user_id and match_prefix is not None:
            # a friend sent a message to me
            if match_prefix != '':
                # strip the trigger prefix from the question
                str_list = content.split(match_prefix, 1)
                if len(str_list) == 2:
                    content = str_list[1].strip()
            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.split(img_match_prefix, 1)[1].strip()
                thread_pool.submit(self._do_send_img, content, from_user_id)
            else:
                thread_pool.submit(self._do_send, content, from_user_id)
        elif to_user_id == other_user_id and match_prefix:
            # I sent a prefixed message to a friend; reply goes to that friend
            str_list = content.split(match_prefix, 1)
            if len(str_list) == 2:
                content = str_list[1].strip()
            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.split(img_match_prefix, 1)[1].strip()
                thread_pool.submit(self._do_send_img, content, to_user_id)
            else:
                thread_pool.submit(self._do_send, content, to_user_id)

    def handle_group(self, msg):
        """Handle a group text message.

        Replies only in whitelisted groups, and only when the bot is
        @-mentioned or a configured group prefix/keyword matches.
        """
        logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
        group_name = msg['User'].get('NickName', None)
        group_id = msg['User'].get('UserName', None)
        if not group_name:
            return ""
        origin_content = msg['Content']
        content = msg['Content']
        # drop the leading "@bot" token; '\u2005' is presumably the narrow
        # space WeChat inserts after an @mention — TODO confirm
        content_list = content.split(' ', 1)
        context_special_list = content.split('\u2005', 1)
        if len(context_special_list) == 2:
            content = context_special_list[1]
        elif len(content_list) == 2:
            content = content_list[1]
        config = conf()
        # trigger on @mention (unless group_at_off), configured prefix, or keyword
        match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
            or self.check_contain(origin_content, config.get('group_chat_keyword'))
        if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
            img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.split(img_match_prefix, 1)[1].strip()
                thread_pool.submit(self._do_send_img, content, group_id)
            else:
                thread_pool.submit(self._do_send_group, content, msg)

    def send(self, msg, receiver):
        """Send a text message to the given user/group id via itchat."""
        logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
        itchat.send(msg, toUserName=receiver)

    def _do_send(self, query, reply_user_id):
        """Worker: fetch a model reply for a single chat and send it back."""
        try:
            if not query:
                return
            context = dict()
            context['from_user_id'] = reply_user_id
            reply_text = super().build_reply_content(query, context)
            if reply_text:
                self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
        except Exception as e:
            logger.exception(e)

    def _do_send_img(self, query, reply_user_id):
        """Worker: generate an image for *query*, download it, and send it."""
        try:
            if not query:
                return
            context = dict()
            context['type'] = 'IMAGE_CREATE'
            img_url = super().build_reply_content(query, context)
            if not img_url:
                return
            # download the image
            pic_res = requests.get(img_url, stream=True)
            image_storage = io.BytesIO()
            for block in pic_res.iter_content(1024):
                image_storage.write(block)
            image_storage.seek(0)
            # send the image
            logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
            itchat.send_image(image_storage, reply_user_id)
        except Exception as e:
            logger.exception(e)

    def _do_send_group(self, query, msg):
        """Worker: fetch a model reply for a group and send it @-mentioning the asker."""
        if not query:
            return
        context = dict()
        context['from_user_id'] = msg['ActualUserName']
        reply_text = super().build_reply_content(query, context)
        if reply_text:
            reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
            self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])

    def check_prefix(self, content, prefix_list):
        """Return the first prefix in *prefix_list* that *content* starts with, else None."""
        for prefix in prefix_list:
            if content.startswith(prefix):
                return prefix
        return None

    def check_contain(self, content, keyword_list):
        """Return True if any keyword occurs in *content*; None otherwise."""
        if not keyword_list:
            return None
        for ky in keyword_list:
            if content.find(ky) != -1:
                return True
        return None

View File

@@ -0,0 +1,58 @@
import werobot
import time
import config
from common import const
from common.log import logger
from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor
# WeRoBot server bound to the configured official-account token
robot = werobot.WeRoBot(token=config.fetch(const.WECHAT_MP).get('token'))
# worker pool for model calls so the HTTP handler can poll instead of block
thread_pool = ThreadPoolExecutor(max_workers=8)
# in-flight request cache: "<content>|<source>" -> {'flag', 'req_times', 'data'}
cache = {}
@robot.text
def hello_world(msg):
    """WeRoBot hook: count WeChat's delivery retries, then answer the message."""
    logger.info('[WX_Public] receive public msg: {}, userId: {}'.format(msg.content, msg.source))
    key = msg.content + '|' + msg.source
    entry = cache.get(key)
    if entry:
        # WeChat re-posts the same message when we time out; track retries
        entry['req_times'] += 1
    return WechatPublicAccount().handle(msg)
class WechatPublicAccount(Channel):
    """Channel for a WeChat Official Account, served via werobot.

    The model call runs on a thread pool while handle() polls the shared
    cache, because WeChat MP expects an HTTP response within its timeout.
    """

    def startup(self):
        """Start the werobot HTTP server on the configured port (blocks)."""
        logger.info('[WX_Public] Wechat Public account service start!')
        robot.config['PORT'] = config.fetch(const.WECHAT_MP).get('port')
        robot.run()

    def handle(self, msg, count=0):
        """Answer *msg*, polling the cache until the async reply arrives.

        :param msg: werobot message (uses .content and .source)
        :param count: internal recursion counter, capped at 10 polls
        :return: reply text for WeChat to deliver
        """
        context = dict()
        context['from_user_id'] = msg.source
        key = msg.content + '|' + msg.source
        res = cache.get(key)
        if not res:
            # first sight of this question: start the model call and mark
            # the request as in flight
            thread_pool.submit(self._do_send, msg.content, context)
            temp = {'flag': True, 'req_times': 1}
            cache[key] = temp
            if count < 10:
                time.sleep(2)
                return self.handle(msg, count+1)
        elif res.get('flag', False) and res.get('data', None):
            # reply ready: consume the cache entry and return it
            cache.pop(key)
            return res['data']
        elif res.get('flag', False) and not res.get('data', None):
            # still waiting; on WeChat's third delivery of this message and
            # the final poll, give up with an apology
            if res.get('req_times') == 3 and count == 9:
                return '不好意思我的CPU烧了请再问我一次吧~'
            if count < 10:
                time.sleep(0.5)
                return self.handle(msg, count+1)
        return "请再说一次"

    def _do_send(self, query, context):
        """Worker: fetch the model reply and publish it into the cache."""
        reply_text = super().build_reply_content(query, context)
        logger.info('[WX_Public] reply content: {}'.format(reply_text))
        key = query + '|' + context['from_user_id']
        # NOTE(review): this replaces the whole entry, dropping 'req_times'
        # set by hello_world — verify that is intentional
        cache[key] = {'flag': True, 'data': reply_text}

6
common/const.py Normal file
View File

@@ -0,0 +1,6 @@
# channel type codes (value of the top-level "channel" config key)
WECHAT = "wechat"
WECHAT_MP = "wechat_mp"
# model type codes (value of the top-level "model" config key)
OPEN_AI = "openai"

18
common/log.py Normal file
View File

@@ -0,0 +1,18 @@
# encoding:utf-8
import logging
import sys
def _get_logger():
    """Create (or return) the shared application logger.

    Guards against attaching a second handler so repeated calls or imports
    do not cause every record to be printed multiple times.

    :return: the configured 'log' logger at INFO level
    """
    log = logging.getLogger('log')
    log.setLevel(logging.INFO)
    if not log.handlers:  # fix: original added a new handler on every call
        console_handle = logging.StreamHandler(sys.stdout)
        console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
                                                      datefmt='%Y-%m-%d %H:%M:%S'))
        log.addHandler(console_handle)
    return log

# shared log handle
logger = _get_logger()

23
config-template.json Normal file
View File

@@ -0,0 +1,23 @@
{
"channel": "wechat",
"bot": "openai",
"openai": {
"api_key": "YOUR API KEY",
"conversation_max_tokens": 1000,
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
},
"wechat": {
"single_chat_prefix": ["bot", "@bot"],
"single_chat_reply_prefix": "[bot] ",
"group_chat_prefix": ["@bot"],
"group_name_white_list": ["ALL_GROUP"],
"image_create_prefix": ["画", "看", "找一张"]
},
"wechat_mp": {
"token": "YOUR TOKEN",
"port": "8088"
}
}

32
config.py Normal file
View File

@@ -0,0 +1,32 @@
# encoding:utf-8
import json
import os
# global in-memory configuration, populated by load_config()
config = {}


def load_config():
    """Load config.json from the working directory into the global dict.

    :raises Exception: when config.json does not exist, so startup fails fast
    """
    global config
    config_path = "config.json"
    if not os.path.exists(config_path):
        raise Exception('配置文件不存在请根据config-template.json模板创建config.json文件')
    # parse the JSON text into a dict
    with open(config_path, mode='r', encoding='utf-8') as f:
        config = json.load(f)
def get_root():
    """Return the absolute path of the directory containing this file."""
    here = os.path.abspath(__file__)
    return os.path.dirname(here)
def read_file(path):
    """Read a UTF-8 text file and return its entire contents as a string."""
    with open(path, mode='r', encoding='utf-8') as fp:
        return fp.read()
def conf():
    """Accessor for the global configuration dictionary."""
    return config
def fetch(model):
    """Return the configuration section stored under the given top-level key."""
    return config.get(model)

13
model/model.py Normal file
View File

@@ -0,0 +1,13 @@
"""
Auto-reply chat robot abstract class
"""
class Model(object):
    """Abstract base class for auto-reply chat bots."""

    def reply(self, query, context=None):
        """
        model auto-reply content
        :param query: received message text
        :param context: optional dict with extra request info
        :return: reply content
        """
        raise NotImplementedError

19
model/model_factory.py Normal file
View File

@@ -0,0 +1,19 @@
"""
model factory
"""
from common import const
def create_bot(model_type):
    """
    create a model bot instance
    :param model_type: model type code (see common.const)
    :return: model instance
    :raises RuntimeError: if model_type is not a supported model code
    """
    if model_type == const.OPEN_AI:
        # OpenAI official completion model API
        from model.openai.open_ai_model import OpenAIModel
        return OpenAIModel()
    # fail loudly with the offending value instead of a bare RuntimeError
    raise RuntimeError("unknown model_type: {}".format(model_type))

BIN
model/openai/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -0,0 +1,160 @@
# encoding:utf-8
from model.model import Model
from config import fetch
from common import const
from common.log import logger
import openai
import time
# per-user conversation history: user_id -> list of {"question", "answer"} dicts
user_session = dict()
# OpenAI completion model API (working)
class OpenAIModel(Model):
    """Bot backed by the OpenAI text completion API (text-davinci-003)."""

    def __init__(self):
        openai.api_key = fetch(const.OPEN_AI).get('api_key')

    def reply(self, query, context=None):
        """Dispatch *query* to text completion or image creation.

        :param query: user message text
        :param context: optional dict; context['type'] selects 'TEXT'
            (default) or 'IMAGE_CREATE'; context['from_user_id'] keys the
            per-user conversation session
        :return: reply text, an image URL, or None for unknown types
        """
        context = context or {}  # fix: original crashed on context=None below
        if not context.get('type') or context.get('type') == 'TEXT':
            logger.info("[OPEN_AI] query={}".format(query))
            from_user_id = context.get('from_user_id')
            if query == '#清除记忆':
                Session.clear_session(from_user_id)
                return '记忆已清除'
            new_query = Session.build_session_query(query, from_user_id)
            logger.debug("[OPEN_AI] session query={}".format(new_query))
            reply_content = self.reply_text(new_query, from_user_id, 0)
            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
            if reply_content and query:
                Session.save_session(query, reply_content, from_user_id)
            return reply_content
        elif context.get('type', None) == 'IMAGE_CREATE':
            return self.create_img(query, 0)

    def reply_text(self, query, user_id, retry_count=0):
        """Call the completion endpoint; retry once on rate limiting.

        :param query: full prompt including conversation history
        :param user_id: session key, used to clear history on hard failure
        :param retry_count: internal retry counter
        :return: reply text (a canned apology string on failure)
        """
        try:
            response = openai.Completion.create(
                model="text-davinci-003",  # completion model name
                prompt=query,
                temperature=0.9,  # in [0,1]; higher means less deterministic replies
                max_tokens=1200,  # maximum reply length
                top_p=1,
                frequency_penalty=0.0,  # in [-2,2]; higher discourages repetition
                presence_penalty=0.0,  # in [-2,2]; higher encourages new topics
                stop=["\n\n\n"]
            )
            res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
            logger.info("[OPEN_AI] reply={}".format(res_content))
            return res_content
        except openai.error.RateLimitError as e:
            # rate limit: wait briefly and retry at most once
            logger.warning(e)  # logger.warn is deprecated
            if retry_count < 1:
                time.sleep(5)
                logger.warning("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count + 1))
                return self.reply_text(query, user_id, retry_count + 1)
            else:
                return "提问太快啦,请休息一下再问我吧"
        except Exception as e:
            # unknown exception: drop the (possibly oversized) session and apologize
            logger.exception(e)
            Session.clear_session(user_id)
            return "请再问我一次吧"

    def create_img(self, query, retry_count=0):
        """Generate one 256x256 image for *query* and return its URL.

        :return: image URL, a canned apology on repeated rate limiting,
            or None on any other failure
        """
        try:
            logger.info("[OPEN_AI] image_query={}".format(query))
            response = openai.Image.create(
                prompt=query,  # image description
                n=1,  # number of images per request
                size="256x256"  # one of 256x256, 512x512, 1024x1024
            )
            image_url = response['data'][0]['url']
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return image_url
        except openai.error.RateLimitError as e:
            logger.warning(e)
            if retry_count < 1:
                time.sleep(5)
                logger.warning("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                # fix: original retried via reply_text(query, retry_count+1),
                # passing the retry count as user_id and never retrying the image
                return self.create_img(query, retry_count + 1)
            else:
                return "提问太快啦,请休息一下再问我吧"
        except Exception as e:
            logger.exception(e)
            return None
class Session(object):
    """Static helpers managing per-user conversation history (user_session)."""

    @staticmethod
    def build_session_query(query, user_id):
        '''
        build query with conversation history
        e.g. Q: xxx
             A: xxx
             Q: xxx
        :param query: query content
        :param user_id: from user id
        :return: query content with conversation history prepended
        '''
        # optional persona prompt from config, terminated like a completed turn
        prompt = fetch(const.OPEN_AI).get("character_desc", "")
        if prompt:
            prompt += "<|endoftext|>\n\n\n"
        session = user_session.get(user_id, None)
        if session:
            for conversation in session:
                prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
            prompt += "Q: " + query + "\nA: "
            return prompt
        else:
            return prompt + "Q: " + query + "\nA: "

    @staticmethod
    def save_session(query, answer, user_id):
        """Append one Q/A pair to the user's session, then trim to the budget."""
        max_tokens = fetch(const.OPEN_AI).get("conversation_max_tokens")
        if not max_tokens:
            # default 1000 when not configured
            max_tokens = 1000
        conversation = dict()
        conversation["question"] = query
        conversation["answer"] = answer
        session = user_session.get(user_id)
        logger.debug(conversation)
        logger.debug(session)
        if session:
            # append conversation
            session.append(conversation)
        else:
            # create session
            queue = list()
            queue.append(conversation)
            user_session[user_id] = queue
        # discard exceed limit conversation
        Session.discard_exceed_conversation(user_session[user_id], max_tokens)

    @staticmethod
    def discard_exceed_conversation(session, max_tokens):
        """Drop oldest conversations while the history exceeds max_tokens.

        Character counts stand in for tokens: totals are accumulated from the
        newest entry backwards, and one oldest entry is popped for every
        suffix total that exceeds the budget.
        """
        count = 0
        count_list = list()
        for i in range(len(session)-1, -1, -1):
            # count tokens of conversation list (characters as a rough proxy)
            history_conv = session[i]
            count += len(history_conv["question"]) + len(history_conv["answer"])
            count_list.append(count)
        for c in count_list:
            if c > max_tokens:
                # pop first conversation
                session.pop(0)

    @staticmethod
    def clear_session(user_id):
        """Reset the conversation history for *user_id*."""
        user_session[user_id] = []