mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-01-19 01:21:01 +08:00
formatting code
This commit is contained in:
2
.github/ISSUE_TEMPLATE.md
vendored
2
.github/ISSUE_TEMPLATE.md
vendored
@@ -27,5 +27,5 @@
|
|||||||
### 环境
|
### 环境
|
||||||
|
|
||||||
- 操作系统类型 (Mac/Windows/Linux):
|
- 操作系统类型 (Mac/Windows/Linux):
|
||||||
- Python版本 ( 执行 `python3 -V` ):
|
- Python版本 ( 执行 `python3 -V` ):
|
||||||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
||||||
|
|||||||
4
.github/workflows/deploy-image.yml
vendored
4
.github/workflows/deploy-image.yml
vendored
@@ -49,9 +49,9 @@ jobs:
|
|||||||
file: ./docker/Dockerfile.latest
|
file: ./docker/Dockerfile.latest
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
|
||||||
- uses: actions/delete-package-versions@v4
|
- uses: actions/delete-package-versions@v4
|
||||||
with:
|
with:
|
||||||
package-name: 'chatgpt-on-wechat'
|
package-name: 'chatgpt-on-wechat'
|
||||||
package-type: 'container'
|
package-type: 'container'
|
||||||
min-versions-to-keep: 10
|
min-versions-to-keep: 10
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# config.json文件内容示例
|
# config.json文件内容示例
|
||||||
{
|
{
|
||||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||||
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech
|
|||||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||||
"speech_recognition": false, # 是否开启语音识别
|
"speech_recognition": false, # 是否开启语音识别
|
||||||
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech
|
|||||||
**4.其他配置**
|
**4.其他配置**
|
||||||
|
|
||||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
||||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
||||||
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech
|
|||||||
```bash
|
```bash
|
||||||
python3 app.py
|
python3 app.py
|
||||||
```
|
```
|
||||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||||
|
|
||||||
|
|
||||||
### 2.服务器部署
|
### 2.服务器部署
|
||||||
@@ -189,7 +189,7 @@ python3 app.py
|
|||||||
使用nohup命令在后台运行程序:
|
使用nohup命令在后台运行程序:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
touch nohup.out # 首次运行需要新建日志文件
|
touch nohup.out # 首次运行需要新建日志文件
|
||||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||||
```
|
```
|
||||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||||
|
|||||||
30
app.py
30
app.py
@@ -1,23 +1,28 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from config import conf, load_config
|
|
||||||
from channel import channel_factory
|
|
||||||
from common.log import logger
|
|
||||||
from plugins import *
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from channel import channel_factory
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf, load_config
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
def sigterm_handler_wrap(_signo):
|
def sigterm_handler_wrap(_signo):
|
||||||
old_handler = signal.getsignal(_signo)
|
old_handler = signal.getsignal(_signo)
|
||||||
|
|
||||||
def func(_signo, _stack_frame):
|
def func(_signo, _stack_frame):
|
||||||
logger.info("signal {} received, exiting...".format(_signo))
|
logger.info("signal {} received, exiting...".format(_signo))
|
||||||
conf().save_user_datas()
|
conf().save_user_datas()
|
||||||
if callable(old_handler): # check old_handler
|
if callable(old_handler): # check old_handler
|
||||||
return old_handler(_signo, _stack_frame)
|
return old_handler(_signo, _stack_frame)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
signal.signal(_signo, func)
|
signal.signal(_signo, func)
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
try:
|
try:
|
||||||
# load config
|
# load config
|
||||||
@@ -28,17 +33,17 @@ def run():
|
|||||||
sigterm_handler_wrap(signal.SIGTERM)
|
sigterm_handler_wrap(signal.SIGTERM)
|
||||||
|
|
||||||
# create channel
|
# create channel
|
||||||
channel_name=conf().get('channel_type', 'wx')
|
channel_name = conf().get("channel_type", "wx")
|
||||||
|
|
||||||
if "--cmd" in sys.argv:
|
if "--cmd" in sys.argv:
|
||||||
channel_name = 'terminal'
|
channel_name = "terminal"
|
||||||
|
|
||||||
if channel_name == 'wxy':
|
if channel_name == "wxy":
|
||||||
os.environ['WECHATY_LOG']="warn"
|
os.environ["WECHATY_LOG"] = "warn"
|
||||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||||
|
|
||||||
channel = channel_factory.create_channel(channel_name)
|
channel = channel_factory.create_channel(channel_name)
|
||||||
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
|
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]:
|
||||||
PluginManager().load_plugins()
|
PluginManager().load_plugins()
|
||||||
|
|
||||||
# startup channel
|
# startup channel
|
||||||
@@ -47,5 +52,6 @@ def run():
|
|||||||
logger.error("App startup failed!")
|
logger.error("App startup failed!")
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run()
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
|
|
||||||
@@ -9,20 +10,35 @@ from bridge.reply import Reply, ReplyType
|
|||||||
class BaiduUnitBot(Bot):
|
class BaiduUnitBot(Bot):
|
||||||
def reply(self, query, context=None):
|
def reply(self, query, context=None):
|
||||||
token = self.get_token()
|
token = self.get_token()
|
||||||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
|
url = (
|
||||||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
|
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
|
||||||
|
+ token
|
||||||
|
)
|
||||||
|
post_data = (
|
||||||
|
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
|
||||||
|
+ query
|
||||||
|
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
|
||||||
|
)
|
||||||
print(post_data)
|
print(post_data)
|
||||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||||
if response:
|
if response:
|
||||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
reply = Reply(
|
||||||
|
ReplyType.TEXT,
|
||||||
|
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
|
||||||
|
)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def get_token(self):
|
def get_token(self):
|
||||||
access_key = 'YOUR_ACCESS_KEY'
|
access_key = "YOUR_ACCESS_KEY"
|
||||||
secret_key = 'YOUR_SECRET_KEY'
|
secret_key = "YOUR_SECRET_KEY"
|
||||||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
|
host = (
|
||||||
|
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
|
||||||
|
+ access_key
|
||||||
|
+ "&client_secret="
|
||||||
|
+ secret_key
|
||||||
|
)
|
||||||
response = requests.get(host)
|
response = requests.get(host)
|
||||||
if response:
|
if response:
|
||||||
print(response.json())
|
print(response.json())
|
||||||
return response.json()['access_token']
|
return response.json()["access_token"]
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from bridge.reply import Reply
|
|||||||
|
|
||||||
|
|
||||||
class Bot(object):
|
class Bot(object):
|
||||||
def reply(self, query, context : Context =None) -> Reply:
|
def reply(self, query, context: Context = None) -> Reply:
|
||||||
"""
|
"""
|
||||||
bot auto-reply content
|
bot auto-reply content
|
||||||
:param req: received message
|
:param req: received message
|
||||||
|
|||||||
@@ -13,20 +13,24 @@ def create_bot(bot_type):
|
|||||||
if bot_type == const.BAIDU:
|
if bot_type == const.BAIDU:
|
||||||
# Baidu Unit对话接口
|
# Baidu Unit对话接口
|
||||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||||
|
|
||||||
return BaiduUnitBot()
|
return BaiduUnitBot()
|
||||||
|
|
||||||
elif bot_type == const.CHATGPT:
|
elif bot_type == const.CHATGPT:
|
||||||
# ChatGPT 网页端web接口
|
# ChatGPT 网页端web接口
|
||||||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
||||||
|
|
||||||
return ChatGPTBot()
|
return ChatGPTBot()
|
||||||
|
|
||||||
elif bot_type == const.OPEN_AI:
|
elif bot_type == const.OPEN_AI:
|
||||||
# OpenAI 官方对话模型API
|
# OpenAI 官方对话模型API
|
||||||
from bot.openai.open_ai_bot import OpenAIBot
|
from bot.openai.open_ai_bot import OpenAIBot
|
||||||
|
|
||||||
return OpenAIBot()
|
return OpenAIBot()
|
||||||
|
|
||||||
elif bot_type == const.CHATGPTONAZURE:
|
elif bot_type == const.CHATGPTONAZURE:
|
||||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||||
|
|
||||||
return AzureChatGPTBot()
|
return AzureChatGPTBot()
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|||||||
@@ -1,42 +1,53 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import openai.error
|
||||||
|
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from config import conf, load_config
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.token_bucket import TokenBucket
|
from common.token_bucket import TokenBucket
|
||||||
import openai
|
from config import conf, load_config
|
||||||
import openai.error
|
|
||||||
import time
|
|
||||||
|
|
||||||
# OpenAI对话模型API (可用)
|
# OpenAI对话模型API (可用)
|
||||||
class ChatGPTBot(Bot,OpenAIImage):
|
class ChatGPTBot(Bot, OpenAIImage):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# set the default api_key
|
# set the default api_key
|
||||||
openai.api_key = conf().get('open_ai_api_key')
|
openai.api_key = conf().get("open_ai_api_key")
|
||||||
if conf().get('open_ai_api_base'):
|
if conf().get("open_ai_api_base"):
|
||||||
openai.api_base = conf().get('open_ai_api_base')
|
openai.api_base = conf().get("open_ai_api_base")
|
||||||
proxy = conf().get('proxy')
|
proxy = conf().get("proxy")
|
||||||
if proxy:
|
if proxy:
|
||||||
openai.proxy = proxy
|
openai.proxy = proxy
|
||||||
if conf().get('rate_limit_chatgpt'):
|
if conf().get("rate_limit_chatgpt"):
|
||||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||||
|
|
||||||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
|
self.sessions = SessionManager(
|
||||||
self.args ={
|
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
self.args = {
|
||||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||||
# "max_tokens":4096, # 回复最大的字符数
|
# "max_tokens":4096, # 回复最大的字符数
|
||||||
"top_p":1,
|
"top_p": 1,
|
||||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
"frequency_penalty": conf().get(
|
||||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
"frequency_penalty", 0.0
|
||||||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
|
"presence_penalty": conf().get(
|
||||||
|
"presence_penalty", 0.0
|
||||||
|
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||||
|
"request_timeout": conf().get(
|
||||||
|
"request_timeout", None
|
||||||
|
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||||
|
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||||
}
|
}
|
||||||
|
|
||||||
def reply(self, query, context=None):
|
def reply(self, query, context=None):
|
||||||
@@ -44,39 +55,50 @@ class ChatGPTBot(Bot,OpenAIImage):
|
|||||||
if context.type == ContextType.TEXT:
|
if context.type == ContextType.TEXT:
|
||||||
logger.info("[CHATGPT] query={}".format(query))
|
logger.info("[CHATGPT] query={}".format(query))
|
||||||
|
|
||||||
|
session_id = context["session_id"]
|
||||||
session_id = context['session_id']
|
|
||||||
reply = None
|
reply = None
|
||||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||||
if query in clear_memory_commands:
|
if query in clear_memory_commands:
|
||||||
self.sessions.clear_session(session_id)
|
self.sessions.clear_session(session_id)
|
||||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||||
elif query == '#清除所有':
|
elif query == "#清除所有":
|
||||||
self.sessions.clear_all_session()
|
self.sessions.clear_all_session()
|
||||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||||
elif query == '#更新配置':
|
elif query == "#更新配置":
|
||||||
load_config()
|
load_config()
|
||||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||||
if reply:
|
if reply:
|
||||||
return reply
|
return reply
|
||||||
session = self.sessions.session_query(query, session_id)
|
session = self.sessions.session_query(query, session_id)
|
||||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||||
|
|
||||||
api_key = context.get('openai_api_key')
|
api_key = context.get("openai_api_key")
|
||||||
|
|
||||||
# if context.get('stream'):
|
# if context.get('stream'):
|
||||||
# # reply in stream
|
# # reply in stream
|
||||||
# return self.reply_text_stream(query, new_query, session_id)
|
# return self.reply_text_stream(query, new_query, session_id)
|
||||||
|
|
||||||
reply_content = self.reply_text(session, api_key)
|
reply_content = self.reply_text(session, api_key)
|
||||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
logger.debug(
|
||||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
session.messages,
|
||||||
|
session_id,
|
||||||
|
reply_content["content"],
|
||||||
|
reply_content["completion_tokens"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
reply_content["completion_tokens"] == 0
|
||||||
|
and len(reply_content["content"]) > 0
|
||||||
|
):
|
||||||
|
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||||
elif reply_content["completion_tokens"] > 0:
|
elif reply_content["completion_tokens"] > 0:
|
||||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
self.sessions.session_reply(
|
||||||
|
reply_content["content"], session_id, reply_content["total_tokens"]
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||||
else:
|
else:
|
||||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
@@ -89,53 +111,55 @@ class ChatGPTBot(Bot,OpenAIImage):
|
|||||||
reply = Reply(ReplyType.ERROR, retstring)
|
reply = Reply(ReplyType.ERROR, retstring)
|
||||||
return reply
|
return reply
|
||||||
else:
|
else:
|
||||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
|
def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
|
||||||
'''
|
"""
|
||||||
call openai's ChatCompletion to get the answer
|
call openai's ChatCompletion to get the answer
|
||||||
:param session: a conversation session
|
:param session: a conversation session
|
||||||
:param session_id: session id
|
:param session_id: session id
|
||||||
:param retry_count: retry count
|
:param retry_count: retry count
|
||||||
:return: {}
|
:return: {}
|
||||||
'''
|
"""
|
||||||
try:
|
try:
|
||||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||||
# if api_key == None, the default openai.api_key will be used
|
# if api_key == None, the default openai.api_key will be used
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
api_key=api_key, messages=session.messages, **self.args
|
api_key=api_key, messages=session.messages, **self.args
|
||||||
)
|
)
|
||||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||||
return {"total_tokens": response["usage"]["total_tokens"],
|
return {
|
||||||
"completion_tokens": response["usage"]["completion_tokens"],
|
"total_tokens": response["usage"]["total_tokens"],
|
||||||
"content": response.choices[0]['message']['content']}
|
"completion_tokens": response["usage"]["completion_tokens"],
|
||||||
|
"content": response.choices[0]["message"]["content"],
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
need_retry = retry_count < 2
|
need_retry = retry_count < 2
|
||||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||||
if isinstance(e, openai.error.RateLimitError):
|
if isinstance(e, openai.error.RateLimitError):
|
||||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
elif isinstance(e, openai.error.Timeout):
|
elif isinstance(e, openai.error.Timeout):
|
||||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||||
result['content'] = "我没有收到你的消息"
|
result["content"] = "我没有收到你的消息"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
elif isinstance(e, openai.error.APIConnectionError):
|
elif isinstance(e, openai.error.APIConnectionError):
|
||||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||||
need_retry = False
|
need_retry = False
|
||||||
result['content'] = "我连接不到你的网络"
|
result["content"] = "我连接不到你的网络"
|
||||||
else:
|
else:
|
||||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||||
need_retry = False
|
need_retry = False
|
||||||
self.sessions.clear_session(session.session_id)
|
self.sessions.clear_session(session.session_id)
|
||||||
|
|
||||||
if need_retry:
|
if need_retry:
|
||||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
||||||
return self.reply_text(session, api_key, retry_count+1)
|
return self.reply_text(session, api_key, retry_count + 1)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -145,4 +169,4 @@ class AzureChatGPTBot(ChatGPTBot):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
openai.api_type = "azure"
|
openai.api_type = "azure"
|
||||||
openai.api_version = "2023-03-15-preview"
|
openai.api_version = "2023-03-15-preview"
|
||||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||||
|
|||||||
@@ -1,20 +1,23 @@
|
|||||||
from bot.session_manager import Session
|
from bot.session_manager import Session
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
'''
|
|
||||||
|
"""
|
||||||
e.g. [
|
e.g. [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||||
{"role": "user", "content": "Where was it played?"}
|
{"role": "user", "content": "Where was it played?"}
|
||||||
]
|
]
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTSession(Session):
|
class ChatGPTSession(Session):
|
||||||
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
|
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||||
super().__init__(session_id, system_prompt)
|
super().__init__(session_id, system_prompt)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||||
precise = True
|
precise = True
|
||||||
try:
|
try:
|
||||||
cur_tokens = self.calc_tokens()
|
cur_tokens = self.calc_tokens()
|
||||||
@@ -22,7 +25,9 @@ class ChatGPTSession(Session):
|
|||||||
precise = False
|
precise = False
|
||||||
if cur_tokens is None:
|
if cur_tokens is None:
|
||||||
raise e
|
raise e
|
||||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
logger.debug(
|
||||||
|
"Exception when counting tokens precisely for query: {}".format(e)
|
||||||
|
)
|
||||||
while cur_tokens > max_tokens:
|
while cur_tokens > max_tokens:
|
||||||
if len(self.messages) > 2:
|
if len(self.messages) > 2:
|
||||||
self.messages.pop(1)
|
self.messages.pop(1)
|
||||||
@@ -34,25 +39,32 @@ class ChatGPTSession(Session):
|
|||||||
cur_tokens = cur_tokens - max_tokens
|
cur_tokens = cur_tokens - max_tokens
|
||||||
break
|
break
|
||||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
logger.warn(
|
||||||
|
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
logger.debug(
|
||||||
|
"max_tokens={}, total_tokens={}, len(messages)={}".format(
|
||||||
|
max_tokens, cur_tokens, len(self.messages)
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if precise:
|
if precise:
|
||||||
cur_tokens = self.calc_tokens()
|
cur_tokens = self.calc_tokens()
|
||||||
else:
|
else:
|
||||||
cur_tokens = cur_tokens - max_tokens
|
cur_tokens = cur_tokens - max_tokens
|
||||||
return cur_tokens
|
return cur_tokens
|
||||||
|
|
||||||
def calc_tokens(self):
|
def calc_tokens(self):
|
||||||
return num_tokens_from_messages(self.messages, self.model)
|
return num_tokens_from_messages(self.messages, self.model)
|
||||||
|
|
||||||
|
|
||||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
def num_tokens_from_messages(messages, model):
|
def num_tokens_from_messages(messages, model):
|
||||||
"""Returns the number of tokens used by a list of messages."""
|
"""Returns the number of tokens used by a list of messages."""
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model):
|
|||||||
elif model == "gpt-4":
|
elif model == "gpt-4":
|
||||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||||
elif model == "gpt-3.5-turbo-0301":
|
elif model == "gpt-3.5-turbo-0301":
|
||||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
tokens_per_message = (
|
||||||
|
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
)
|
||||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||||
elif model == "gpt-4-0314":
|
elif model == "gpt-4-0314":
|
||||||
tokens_per_message = 3
|
tokens_per_message = 3
|
||||||
tokens_per_name = 1
|
tokens_per_name = 1
|
||||||
else:
|
else:
|
||||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
logger.warn(
|
||||||
|
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
|
||||||
|
)
|
||||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|||||||
@@ -1,41 +1,52 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import openai.error
|
||||||
|
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
from bot.openai.open_ai_session import OpenAISession
|
from bot.openai.open_ai_session import OpenAISession
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from config import conf
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
import openai
|
from config import conf
|
||||||
import openai.error
|
|
||||||
import time
|
|
||||||
|
|
||||||
user_session = dict()
|
user_session = dict()
|
||||||
|
|
||||||
|
|
||||||
# OpenAI对话模型API (可用)
|
# OpenAI对话模型API (可用)
|
||||||
class OpenAIBot(Bot, OpenAIImage):
|
class OpenAIBot(Bot, OpenAIImage):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
openai.api_key = conf().get('open_ai_api_key')
|
openai.api_key = conf().get("open_ai_api_key")
|
||||||
if conf().get('open_ai_api_base'):
|
if conf().get("open_ai_api_base"):
|
||||||
openai.api_base = conf().get('open_ai_api_base')
|
openai.api_base = conf().get("open_ai_api_base")
|
||||||
proxy = conf().get('proxy')
|
proxy = conf().get("proxy")
|
||||||
if proxy:
|
if proxy:
|
||||||
openai.proxy = proxy
|
openai.proxy = proxy
|
||||||
|
|
||||||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
self.sessions = SessionManager(
|
||||||
|
OpenAISession, model=conf().get("model") or "text-davinci-003"
|
||||||
|
)
|
||||||
self.args = {
|
self.args = {
|
||||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||||
"max_tokens":1200, # 回复最大的字符数
|
"max_tokens": 1200, # 回复最大的字符数
|
||||||
"top_p":1,
|
"top_p": 1,
|
||||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
"frequency_penalty": conf().get(
|
||||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
"frequency_penalty", 0.0
|
||||||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
|
"presence_penalty": conf().get(
|
||||||
"stop":["\n\n\n"]
|
"presence_penalty", 0.0
|
||||||
|
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||||
|
"request_timeout": conf().get(
|
||||||
|
"request_timeout", None
|
||||||
|
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||||
|
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||||
|
"stop": ["\n\n\n"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def reply(self, query, context=None):
|
def reply(self, query, context=None):
|
||||||
@@ -43,24 +54,34 @@ class OpenAIBot(Bot, OpenAIImage):
|
|||||||
if context and context.type:
|
if context and context.type:
|
||||||
if context.type == ContextType.TEXT:
|
if context.type == ContextType.TEXT:
|
||||||
logger.info("[OPEN_AI] query={}".format(query))
|
logger.info("[OPEN_AI] query={}".format(query))
|
||||||
session_id = context['session_id']
|
session_id = context["session_id"]
|
||||||
reply = None
|
reply = None
|
||||||
if query == '#清除记忆':
|
if query == "#清除记忆":
|
||||||
self.sessions.clear_session(session_id)
|
self.sessions.clear_session(session_id)
|
||||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||||
elif query == '#清除所有':
|
elif query == "#清除所有":
|
||||||
self.sessions.clear_all_session()
|
self.sessions.clear_all_session()
|
||||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||||
else:
|
else:
|
||||||
session = self.sessions.session_query(query, session_id)
|
session = self.sessions.session_query(query, session_id)
|
||||||
result = self.reply_text(session)
|
result = self.reply_text(session)
|
||||||
total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
|
total_tokens, completion_tokens, reply_content = (
|
||||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
|
result["total_tokens"],
|
||||||
|
result["completion_tokens"],
|
||||||
|
result["content"],
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||||
|
str(session), session_id, reply_content, completion_tokens
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if total_tokens == 0 :
|
if total_tokens == 0:
|
||||||
reply = Reply(ReplyType.ERROR, reply_content)
|
reply = Reply(ReplyType.ERROR, reply_content)
|
||||||
else:
|
else:
|
||||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
self.sessions.session_reply(
|
||||||
|
reply_content, session_id, total_tokens
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.TEXT, reply_content)
|
reply = Reply(ReplyType.TEXT, reply_content)
|
||||||
return reply
|
return reply
|
||||||
elif context.type == ContextType.IMAGE_CREATE:
|
elif context.type == ContextType.IMAGE_CREATE:
|
||||||
@@ -72,42 +93,44 @@ class OpenAIBot(Bot, OpenAIImage):
|
|||||||
reply = Reply(ReplyType.ERROR, retstring)
|
reply = Reply(ReplyType.ERROR, retstring)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def reply_text(self, session:OpenAISession, retry_count=0):
|
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||||
try:
|
try:
|
||||||
response = openai.Completion.create(
|
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||||
prompt=str(session), **self.args
|
res_content = (
|
||||||
|
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||||
)
|
)
|
||||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
|
||||||
total_tokens = response["usage"]["total_tokens"]
|
total_tokens = response["usage"]["total_tokens"]
|
||||||
completion_tokens = response["usage"]["completion_tokens"]
|
completion_tokens = response["usage"]["completion_tokens"]
|
||||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||||
return {"total_tokens": total_tokens,
|
return {
|
||||||
"completion_tokens": completion_tokens,
|
"total_tokens": total_tokens,
|
||||||
"content": res_content}
|
"completion_tokens": completion_tokens,
|
||||||
|
"content": res_content,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
need_retry = retry_count < 2
|
need_retry = retry_count < 2
|
||||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||||
if isinstance(e, openai.error.RateLimitError):
|
if isinstance(e, openai.error.RateLimitError):
|
||||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
elif isinstance(e, openai.error.Timeout):
|
elif isinstance(e, openai.error.Timeout):
|
||||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||||
result['content'] = "我没有收到你的消息"
|
result["content"] = "我没有收到你的消息"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
elif isinstance(e, openai.error.APIConnectionError):
|
elif isinstance(e, openai.error.APIConnectionError):
|
||||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||||
need_retry = False
|
need_retry = False
|
||||||
result['content'] = "我连接不到你的网络"
|
result["content"] = "我连接不到你的网络"
|
||||||
else:
|
else:
|
||||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||||
need_retry = False
|
need_retry = False
|
||||||
self.sessions.clear_session(session.session_id)
|
self.sessions.clear_session(session.session_id)
|
||||||
|
|
||||||
if need_retry:
|
if need_retry:
|
||||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
|
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
|
||||||
return self.reply_text(session, retry_count+1)
|
return self.reply_text(session, retry_count + 1)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,38 +1,45 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import openai.error
|
import openai.error
|
||||||
from common.token_bucket import TokenBucket
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from common.token_bucket import TokenBucket
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
# OPENAI提供的画图接口
|
# OPENAI提供的画图接口
|
||||||
class OpenAIImage(object):
|
class OpenAIImage(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
openai.api_key = conf().get('open_ai_api_key')
|
openai.api_key = conf().get("open_ai_api_key")
|
||||||
if conf().get('rate_limit_dalle'):
|
if conf().get("rate_limit_dalle"):
|
||||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
||||||
|
|
||||||
def create_img(self, query, retry_count=0):
|
def create_img(self, query, retry_count=0):
|
||||||
try:
|
try:
|
||||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
||||||
return False, "请求太快了,请休息一下再问我吧"
|
return False, "请求太快了,请休息一下再问我吧"
|
||||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||||
response = openai.Image.create(
|
response = openai.Image.create(
|
||||||
prompt=query, #图片描述
|
prompt=query, # 图片描述
|
||||||
n=1, #每次生成图片的数量
|
n=1, # 每次生成图片的数量
|
||||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
size="256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||||
)
|
)
|
||||||
image_url = response['data'][0]['url']
|
image_url = response["data"][0]["url"]
|
||||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||||
return True, image_url
|
return True, image_url
|
||||||
except openai.error.RateLimitError as e:
|
except openai.error.RateLimitError as e:
|
||||||
logger.warn(e)
|
logger.warn(e)
|
||||||
if retry_count < 1:
|
if retry_count < 1:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
logger.warn(
|
||||||
return self.create_img(query, retry_count+1)
|
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
|
||||||
|
retry_count + 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self.create_img(query, retry_count + 1)
|
||||||
else:
|
else:
|
||||||
return False, "提问太快啦,请休息一下再问我吧"
|
return False, "提问太快啦,请休息一下再问我吧"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|||||||
@@ -1,32 +1,34 @@
|
|||||||
from bot.session_manager import Session
|
from bot.session_manager import Session
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
class OpenAISession(Session):
|
class OpenAISession(Session):
|
||||||
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
|
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
|
||||||
super().__init__(session_id, system_prompt)
|
super().__init__(session_id, system_prompt)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
# 构造对话模型的输入
|
# 构造对话模型的输入
|
||||||
'''
|
"""
|
||||||
e.g. Q: xxx
|
e.g. Q: xxx
|
||||||
A: xxx
|
A: xxx
|
||||||
Q: xxx
|
Q: xxx
|
||||||
'''
|
"""
|
||||||
prompt = ""
|
prompt = ""
|
||||||
for item in self.messages:
|
for item in self.messages:
|
||||||
if item['role'] == 'system':
|
if item["role"] == "system":
|
||||||
prompt += item['content'] + "<|endoftext|>\n\n\n"
|
prompt += item["content"] + "<|endoftext|>\n\n\n"
|
||||||
elif item['role'] == 'user':
|
elif item["role"] == "user":
|
||||||
prompt += "Q: " + item['content'] + "\n"
|
prompt += "Q: " + item["content"] + "\n"
|
||||||
elif item['role'] == 'assistant':
|
elif item["role"] == "assistant":
|
||||||
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
|
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
|
||||||
|
|
||||||
if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
|
if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
|
||||||
prompt += "A: "
|
prompt += "A: "
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||||
precise = True
|
precise = True
|
||||||
try:
|
try:
|
||||||
cur_tokens = self.calc_tokens()
|
cur_tokens = self.calc_tokens()
|
||||||
@@ -34,7 +36,9 @@ class OpenAISession(Session):
|
|||||||
precise = False
|
precise = False
|
||||||
if cur_tokens is None:
|
if cur_tokens is None:
|
||||||
raise e
|
raise e
|
||||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
logger.debug(
|
||||||
|
"Exception when counting tokens precisely for query: {}".format(e)
|
||||||
|
)
|
||||||
while cur_tokens > max_tokens:
|
while cur_tokens > max_tokens:
|
||||||
if len(self.messages) > 1:
|
if len(self.messages) > 1:
|
||||||
self.messages.pop(0)
|
self.messages.pop(0)
|
||||||
@@ -46,24 +50,34 @@ class OpenAISession(Session):
|
|||||||
cur_tokens = len(str(self))
|
cur_tokens = len(str(self))
|
||||||
break
|
break
|
||||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
||||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
logger.warn(
|
||||||
|
"user question exceed max_tokens. total_tokens={}".format(
|
||||||
|
cur_tokens
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
logger.debug(
|
||||||
|
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
|
||||||
|
max_tokens, cur_tokens, len(self.messages)
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if precise:
|
if precise:
|
||||||
cur_tokens = self.calc_tokens()
|
cur_tokens = self.calc_tokens()
|
||||||
else:
|
else:
|
||||||
cur_tokens = len(str(self))
|
cur_tokens = len(str(self))
|
||||||
return cur_tokens
|
return cur_tokens
|
||||||
|
|
||||||
def calc_tokens(self):
|
def calc_tokens(self):
|
||||||
return num_tokens_from_string(str(self), self.model)
|
return num_tokens_from_string(str(self), self.model)
|
||||||
|
|
||||||
|
|
||||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
def num_tokens_from_string(string: str, model: str) -> int:
|
def num_tokens_from_string(string: str, model: str) -> int:
|
||||||
"""Returns the number of tokens in a text string."""
|
"""Returns the number of tokens in a text string."""
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
num_tokens = len(encoding.encode(string,disallowed_special=()))
|
num_tokens = len(encoding.encode(string, disallowed_special=()))
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
|
|||||||
from common.log import logger
|
from common.log import logger
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
class Session(object):
|
class Session(object):
|
||||||
def __init__(self, session_id, system_prompt=None):
|
def __init__(self, session_id, system_prompt=None):
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
@@ -13,7 +14,7 @@ class Session(object):
|
|||||||
|
|
||||||
# 重置会话
|
# 重置会话
|
||||||
def reset(self):
|
def reset(self):
|
||||||
system_item = {'role': 'system', 'content': self.system_prompt}
|
system_item = {"role": "system", "content": self.system_prompt}
|
||||||
self.messages = [system_item]
|
self.messages = [system_item]
|
||||||
|
|
||||||
def set_system_prompt(self, system_prompt):
|
def set_system_prompt(self, system_prompt):
|
||||||
@@ -21,13 +22,13 @@ class Session(object):
|
|||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def add_query(self, query):
|
def add_query(self, query):
|
||||||
user_item = {'role': 'user', 'content': query}
|
user_item = {"role": "user", "content": query}
|
||||||
self.messages.append(user_item)
|
self.messages.append(user_item)
|
||||||
|
|
||||||
def add_reply(self, reply):
|
def add_reply(self, reply):
|
||||||
assistant_item = {'role': 'assistant', 'content': reply}
|
assistant_item = {"role": "assistant", "content": reply}
|
||||||
self.messages.append(assistant_item)
|
self.messages.append(assistant_item)
|
||||||
|
|
||||||
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
|
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -37,8 +38,8 @@ class Session(object):
|
|||||||
|
|
||||||
class SessionManager(object):
|
class SessionManager(object):
|
||||||
def __init__(self, sessioncls, **session_args):
|
def __init__(self, sessioncls, **session_args):
|
||||||
if conf().get('expires_in_seconds'):
|
if conf().get("expires_in_seconds"):
|
||||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
sessions = ExpiredDict(conf().get("expires_in_seconds"))
|
||||||
else:
|
else:
|
||||||
sessions = dict()
|
sessions = dict()
|
||||||
self.sessions = sessions
|
self.sessions = sessions
|
||||||
@@ -46,20 +47,22 @@ class SessionManager(object):
|
|||||||
self.session_args = session_args
|
self.session_args = session_args
|
||||||
|
|
||||||
def build_session(self, session_id, system_prompt=None):
|
def build_session(self, session_id, system_prompt=None):
|
||||||
'''
|
"""
|
||||||
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
||||||
如果system_prompt不会空,会更新session的system_prompt并重置session
|
如果system_prompt不会空,会更新session的system_prompt并重置session
|
||||||
'''
|
"""
|
||||||
if session_id is None:
|
if session_id is None:
|
||||||
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||||
|
|
||||||
if session_id not in self.sessions:
|
if session_id not in self.sessions:
|
||||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
self.sessions[session_id] = self.sessioncls(
|
||||||
|
session_id, system_prompt, **self.session_args
|
||||||
|
)
|
||||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||||
session = self.sessions[session_id]
|
session = self.sessions[session_id]
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def session_query(self, query, session_id):
|
def session_query(self, query, session_id):
|
||||||
session = self.build_session(session_id)
|
session = self.build_session(session_id)
|
||||||
session.add_query(query)
|
session.add_query(query)
|
||||||
@@ -68,23 +71,33 @@ class SessionManager(object):
|
|||||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
logger.debug(
|
||||||
|
"Exception when counting tokens precisely for prompt: {}".format(str(e))
|
||||||
|
)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def session_reply(self, reply, session_id, total_tokens = None):
|
def session_reply(self, reply, session_id, total_tokens=None):
|
||||||
session = self.build_session(session_id)
|
session = self.build_session(session_id)
|
||||||
session.add_reply(reply)
|
session.add_reply(reply)
|
||||||
try:
|
try:
|
||||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
logger.debug(
|
||||||
|
"raw total_tokens={}, savesession tokens={}".format(
|
||||||
|
total_tokens, tokens_cnt
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
logger.debug(
|
||||||
|
"Exception when counting tokens precisely for session: {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def clear_session(self, session_id):
|
def clear_session(self, session_id):
|
||||||
if session_id in self.sessions:
|
if session_id in self.sessions:
|
||||||
del(self.sessions[session_id])
|
del self.sessions[session_id]
|
||||||
|
|
||||||
def clear_all_session(self):
|
def clear_all_session(self):
|
||||||
self.sessions.clear()
|
self.sessions.clear()
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
|
from bot import bot_factory
|
||||||
from bridge.context import Context
|
from bridge.context import Context
|
||||||
from bridge.reply import Reply
|
from bridge.reply import Reply
|
||||||
from common.log import logger
|
|
||||||
from bot import bot_factory
|
|
||||||
from common.singleton import singleton
|
|
||||||
from voice import voice_factory
|
|
||||||
from config import conf
|
|
||||||
from common import const
|
from common import const
|
||||||
|
from common.log import logger
|
||||||
|
from common.singleton import singleton
|
||||||
|
from config import conf
|
||||||
|
from voice import voice_factory
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class Bridge(object):
|
class Bridge(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.btype={
|
self.btype = {
|
||||||
"chat": const.CHATGPT,
|
"chat": const.CHATGPT,
|
||||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||||
"text_to_voice": conf().get("text_to_voice", "google")
|
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||||
}
|
}
|
||||||
model_type = conf().get("model")
|
model_type = conf().get("model")
|
||||||
if model_type in ["text-davinci-003"]:
|
if model_type in ["text-davinci-003"]:
|
||||||
self.btype['chat'] = const.OPEN_AI
|
self.btype["chat"] = const.OPEN_AI
|
||||||
if conf().get("use_azure_chatgpt", False):
|
if conf().get("use_azure_chatgpt", False):
|
||||||
self.btype['chat'] = const.CHATGPTONAZURE
|
self.btype["chat"] = const.CHATGPTONAZURE
|
||||||
self.bots={}
|
self.bots = {}
|
||||||
|
|
||||||
def get_bot(self,typename):
|
def get_bot(self, typename):
|
||||||
if self.bots.get(typename) is None:
|
if self.bots.get(typename) is None:
|
||||||
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||||
if typename == "text_to_voice":
|
if typename == "text_to_voice":
|
||||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||||
elif typename == "voice_to_text":
|
elif typename == "voice_to_text":
|
||||||
@@ -33,18 +33,15 @@ class Bridge(object):
|
|||||||
elif typename == "chat":
|
elif typename == "chat":
|
||||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
||||||
return self.bots[typename]
|
return self.bots[typename]
|
||||||
|
|
||||||
def get_bot_type(self,typename):
|
def get_bot_type(self, typename):
|
||||||
return self.btype[typename]
|
return self.btype[typename]
|
||||||
|
|
||||||
|
def fetch_reply_content(self, query, context: Context) -> Reply:
|
||||||
def fetch_reply_content(self, query, context : Context) -> Reply:
|
|
||||||
return self.get_bot("chat").reply(query, context)
|
return self.get_bot("chat").reply(query, context)
|
||||||
|
|
||||||
|
|
||||||
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
||||||
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
||||||
|
|
||||||
def fetch_text_to_voice(self, text) -> Reply:
|
def fetch_text_to_voice(self, text) -> Reply:
|
||||||
return self.get_bot("text_to_voice").textToVoice(text)
|
return self.get_bot("text_to_voice").textToVoice(text)
|
||||||
|
|
||||||
|
|||||||
@@ -2,36 +2,39 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
class ContextType (Enum):
|
|
||||||
TEXT = 1 # 文本消息
|
class ContextType(Enum):
|
||||||
VOICE = 2 # 音频消息
|
TEXT = 1 # 文本消息
|
||||||
IMAGE = 3 # 图片消息
|
VOICE = 2 # 音频消息
|
||||||
IMAGE_CREATE = 10 # 创建图片命令
|
IMAGE = 3 # 图片消息
|
||||||
|
IMAGE_CREATE = 10 # 创建图片命令
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
|
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
||||||
self.type = type
|
self.type = type
|
||||||
self.content = content
|
self.content = content
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
if key == 'type':
|
if key == "type":
|
||||||
return self.type is not None
|
return self.type is not None
|
||||||
elif key == 'content':
|
elif key == "content":
|
||||||
return self.content is not None
|
return self.content is not None
|
||||||
else:
|
else:
|
||||||
return key in self.kwargs
|
return key in self.kwargs
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if key == 'type':
|
if key == "type":
|
||||||
return self.type
|
return self.type
|
||||||
elif key == 'content':
|
elif key == "content":
|
||||||
return self.content
|
return self.content
|
||||||
else:
|
else:
|
||||||
return self.kwargs[key]
|
return self.kwargs[key]
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
try:
|
try:
|
||||||
return self[key]
|
return self[key]
|
||||||
@@ -39,20 +42,22 @@ class Context:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
if key == 'type':
|
if key == "type":
|
||||||
self.type = value
|
self.type = value
|
||||||
elif key == 'content':
|
elif key == "content":
|
||||||
self.content = value
|
self.content = value
|
||||||
else:
|
else:
|
||||||
self.kwargs[key] = value
|
self.kwargs[key] = value
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
if key == 'type':
|
if key == "type":
|
||||||
self.type = None
|
self.type = None
|
||||||
elif key == 'content':
|
elif key == "content":
|
||||||
self.content = None
|
self.content = None
|
||||||
else:
|
else:
|
||||||
del self.kwargs[key]
|
del self.kwargs[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
return "Context(type={}, content={}, kwargs={})".format(
|
||||||
|
self.type, self.content, self.kwargs
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
|
|
||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class ReplyType(Enum):
|
class ReplyType(Enum):
|
||||||
TEXT = 1 # 文本
|
TEXT = 1 # 文本
|
||||||
VOICE = 2 # 音频文件
|
VOICE = 2 # 音频文件
|
||||||
IMAGE = 3 # 图片文件
|
IMAGE = 3 # 图片文件
|
||||||
IMAGE_URL = 4 # 图片URL
|
IMAGE_URL = 4 # 图片URL
|
||||||
|
|
||||||
INFO = 9
|
INFO = 9
|
||||||
ERROR = 10
|
ERROR = 10
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class Reply:
|
class Reply:
|
||||||
def __init__(self, type : ReplyType = None , content = None):
|
def __init__(self, type: ReplyType = None, content=None):
|
||||||
self.type = type
|
self.type = type
|
||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from bridge.bridge import Bridge
|
|||||||
from bridge.context import Context
|
from bridge.context import Context
|
||||||
from bridge.reply import *
|
from bridge.reply import *
|
||||||
|
|
||||||
|
|
||||||
class Channel(object):
|
class Channel(object):
|
||||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||||
|
|
||||||
def startup(self):
|
def startup(self):
|
||||||
"""
|
"""
|
||||||
init channel
|
init channel
|
||||||
@@ -27,15 +29,15 @@ class Channel(object):
|
|||||||
send message to user
|
send message to user
|
||||||
:param msg: message content
|
:param msg: message content
|
||||||
:param receiver: receiver channel account
|
:param receiver: receiver channel account
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_reply_content(self, query, context : Context=None) -> Reply:
|
def build_reply_content(self, query, context: Context = None) -> Reply:
|
||||||
return Bridge().fetch_reply_content(query, context)
|
return Bridge().fetch_reply_content(query, context)
|
||||||
|
|
||||||
def build_voice_to_text(self, voice_file) -> Reply:
|
def build_voice_to_text(self, voice_file) -> Reply:
|
||||||
return Bridge().fetch_voice_to_text(voice_file)
|
return Bridge().fetch_voice_to_text(voice_file)
|
||||||
|
|
||||||
def build_text_to_voice(self, text) -> Reply:
|
def build_text_to_voice(self, text) -> Reply:
|
||||||
return Bridge().fetch_text_to_voice(text)
|
return Bridge().fetch_text_to_voice(text)
|
||||||
|
|||||||
@@ -2,25 +2,31 @@
|
|||||||
channel factory
|
channel factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_channel(channel_type):
|
def create_channel(channel_type):
|
||||||
"""
|
"""
|
||||||
create a channel instance
|
create a channel instance
|
||||||
:param channel_type: channel type code
|
:param channel_type: channel type code
|
||||||
:return: channel instance
|
:return: channel instance
|
||||||
"""
|
"""
|
||||||
if channel_type == 'wx':
|
if channel_type == "wx":
|
||||||
from channel.wechat.wechat_channel import WechatChannel
|
from channel.wechat.wechat_channel import WechatChannel
|
||||||
|
|
||||||
return WechatChannel()
|
return WechatChannel()
|
||||||
elif channel_type == 'wxy':
|
elif channel_type == "wxy":
|
||||||
from channel.wechat.wechaty_channel import WechatyChannel
|
from channel.wechat.wechaty_channel import WechatyChannel
|
||||||
|
|
||||||
return WechatyChannel()
|
return WechatyChannel()
|
||||||
elif channel_type == 'terminal':
|
elif channel_type == "terminal":
|
||||||
from channel.terminal.terminal_channel import TerminalChannel
|
from channel.terminal.terminal_channel import TerminalChannel
|
||||||
|
|
||||||
return TerminalChannel()
|
return TerminalChannel()
|
||||||
elif channel_type == 'wechatmp':
|
elif channel_type == "wechatmp":
|
||||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||||
return WechatMPChannel(passive_reply = True)
|
|
||||||
elif channel_type == 'wechatmp_service':
|
return WechatMPChannel(passive_reply=True)
|
||||||
|
elif channel_type == "wechatmp_service":
|
||||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||||
return WechatMPChannel(passive_reply = False)
|
|
||||||
|
return WechatMPChannel(passive_reply=False)
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|||||||
@@ -1,137 +1,172 @@
|
|||||||
|
|
||||||
|
|
||||||
from asyncio import CancelledError
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from common.dequeue import Dequeue
|
from asyncio import CancelledError
|
||||||
from channel.channel import Channel
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from bridge.reply import *
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
from config import conf
|
from bridge.reply import *
|
||||||
|
from channel.channel import Channel
|
||||||
|
from common.dequeue import Dequeue
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
from plugins import *
|
from plugins import *
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from voice.audio_convert import any_to_wav
|
from voice.audio_convert import any_to_wav
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||||
class ChatChannel(Channel):
|
class ChatChannel(Channel):
|
||||||
name = None # 登录的用户名
|
name = None # 登录的用户名
|
||||||
user_id = None # 登录的用户id
|
user_id = None # 登录的用户id
|
||||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||||
lock = threading.Lock() # 用于控制对sessions的访问
|
lock = threading.Lock() # 用于控制对sessions的访问
|
||||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
_thread = threading.Thread(target=self.consume)
|
_thread = threading.Thread(target=self.consume)
|
||||||
_thread.setDaemon(True)
|
_thread.setDaemon(True)
|
||||||
_thread.start()
|
_thread.start()
|
||||||
|
|
||||||
|
|
||||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||||
context = Context(ctype, content)
|
context = Context(ctype, content)
|
||||||
context.kwargs = kwargs
|
context.kwargs = kwargs
|
||||||
# context首次传入时,origin_ctype是None,
|
# context首次传入时,origin_ctype是None,
|
||||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||||
if 'origin_ctype' not in context:
|
if "origin_ctype" not in context:
|
||||||
context['origin_ctype'] = ctype
|
context["origin_ctype"] = ctype
|
||||||
# context首次传入时,receiver是None,根据类型设置receiver
|
# context首次传入时,receiver是None,根据类型设置receiver
|
||||||
first_in = 'receiver' not in context
|
first_in = "receiver" not in context
|
||||||
# 群名匹配过程,设置session_id和receiver
|
# 群名匹配过程,设置session_id和receiver
|
||||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||||
config = conf()
|
config = conf()
|
||||||
cmsg = context['msg']
|
cmsg = context["msg"]
|
||||||
if context.get("isgroup", False):
|
if context.get("isgroup", False):
|
||||||
group_name = cmsg.other_user_nickname
|
group_name = cmsg.other_user_nickname
|
||||||
group_id = cmsg.other_user_id
|
group_id = cmsg.other_user_id
|
||||||
|
|
||||||
group_name_white_list = config.get('group_name_white_list', [])
|
group_name_white_list = config.get("group_name_white_list", [])
|
||||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
group_name_keyword_white_list = config.get(
|
||||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
"group_name_keyword_white_list", []
|
||||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
)
|
||||||
|
if any(
|
||||||
|
[
|
||||||
|
group_name in group_name_white_list,
|
||||||
|
"ALL_GROUP" in group_name_white_list,
|
||||||
|
check_contain(group_name, group_name_keyword_white_list),
|
||||||
|
]
|
||||||
|
):
|
||||||
|
group_chat_in_one_session = conf().get(
|
||||||
|
"group_chat_in_one_session", []
|
||||||
|
)
|
||||||
session_id = cmsg.actual_user_id
|
session_id = cmsg.actual_user_id
|
||||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
if any(
|
||||||
|
[
|
||||||
|
group_name in group_chat_in_one_session,
|
||||||
|
"ALL_GROUP" in group_chat_in_one_session,
|
||||||
|
]
|
||||||
|
):
|
||||||
session_id = group_id
|
session_id = group_id
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
context['session_id'] = session_id
|
context["session_id"] = session_id
|
||||||
context['receiver'] = group_id
|
context["receiver"] = group_id
|
||||||
else:
|
else:
|
||||||
context['session_id'] = cmsg.other_user_id
|
context["session_id"] = cmsg.other_user_id
|
||||||
context['receiver'] = cmsg.other_user_id
|
context["receiver"] = cmsg.other_user_id
|
||||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context}))
|
e_context = PluginManager().emit_event(
|
||||||
context = e_context['context']
|
EventContext(
|
||||||
|
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
context = e_context["context"]
|
||||||
if e_context.is_pass() or context is None:
|
if e_context.is_pass() or context is None:
|
||||||
return context
|
return context
|
||||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
|
if cmsg.from_user_id == self.user_id and not config.get(
|
||||||
|
"trigger_by_self", True
|
||||||
|
):
|
||||||
logger.debug("[WX]self message skipped")
|
logger.debug("[WX]self message skipped")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 消息内容匹配过程,并处理content
|
# 消息内容匹配过程,并处理content
|
||||||
if ctype == ContextType.TEXT:
|
if ctype == ContextType.TEXT:
|
||||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||||
logger.debug("[WX]reference query skipped")
|
logger.debug("[WX]reference query skipped")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if context.get("isgroup", False): # 群聊
|
if context.get("isgroup", False): # 群聊
|
||||||
# 校验关键字
|
# 校验关键字
|
||||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||||
flag = False
|
flag = False
|
||||||
if match_prefix is not None or match_contain is not None:
|
if match_prefix is not None or match_contain is not None:
|
||||||
flag = True
|
flag = True
|
||||||
if match_prefix:
|
if match_prefix:
|
||||||
content = content.replace(match_prefix, '', 1).strip()
|
content = content.replace(match_prefix, "", 1).strip()
|
||||||
if context['msg'].is_at:
|
if context["msg"].is_at:
|
||||||
logger.info("[WX]receive group at")
|
logger.info("[WX]receive group at")
|
||||||
if not conf().get("group_at_off", False):
|
if not conf().get("group_at_off", False):
|
||||||
flag = True
|
flag = True
|
||||||
pattern = f'@{self.name}(\u2005|\u0020)'
|
pattern = f"@{self.name}(\u2005|\u0020)"
|
||||||
content = re.sub(pattern, r'', content)
|
content = re.sub(pattern, r"", content)
|
||||||
|
|
||||||
if not flag:
|
if not flag:
|
||||||
if context["origin_ctype"] == ContextType.VOICE:
|
if context["origin_ctype"] == ContextType.VOICE:
|
||||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
logger.info(
|
||||||
|
"[WX]receive group voice, but checkprefix didn't match"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
else: # 单聊
|
else: # 单聊
|
||||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
|
match_prefix = check_prefix(
|
||||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
content, conf().get("single_chat_prefix", [""])
|
||||||
content = content.replace(match_prefix, '', 1).strip()
|
)
|
||||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||||
|
content = content.replace(match_prefix, "", 1).strip()
|
||||||
|
elif (
|
||||||
|
context["origin_ctype"] == ContextType.VOICE
|
||||||
|
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||||
if img_match_prefix:
|
if img_match_prefix:
|
||||||
content = content.replace(img_match_prefix, '', 1)
|
content = content.replace(img_match_prefix, "", 1)
|
||||||
context.type = ContextType.IMAGE_CREATE
|
context.type = ContextType.IMAGE_CREATE
|
||||||
else:
|
else:
|
||||||
context.type = ContextType.TEXT
|
context.type = ContextType.TEXT
|
||||||
context.content = content.strip()
|
context.content = content.strip()
|
||||||
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
if (
|
||||||
context['desire_rtype'] = ReplyType.VOICE
|
"desire_rtype" not in context
|
||||||
elif context.type == ContextType.VOICE:
|
and conf().get("always_reply_voice")
|
||||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||||
context['desire_rtype'] = ReplyType.VOICE
|
):
|
||||||
|
context["desire_rtype"] = ReplyType.VOICE
|
||||||
|
elif context.type == ContextType.VOICE:
|
||||||
|
if (
|
||||||
|
"desire_rtype" not in context
|
||||||
|
and conf().get("voice_reply_voice")
|
||||||
|
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||||
|
):
|
||||||
|
context["desire_rtype"] = ReplyType.VOICE
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def _handle(self, context: Context):
|
def _handle(self, context: Context):
|
||||||
if context is None or not context.content:
|
if context is None or not context.content:
|
||||||
return
|
return
|
||||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
logger.debug("[WX] ready to handle context: {}".format(context))
|
||||||
# reply的构建步骤
|
# reply的构建步骤
|
||||||
reply = self._generate_reply(context)
|
reply = self._generate_reply(context)
|
||||||
|
|
||||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
logger.debug("[WX] ready to decorate reply: {}".format(reply))
|
||||||
# reply的包装步骤
|
# reply的包装步骤
|
||||||
reply = self._decorate_reply(context, reply)
|
reply = self._decorate_reply(context, reply)
|
||||||
|
|
||||||
@@ -139,20 +174,31 @@ class ChatChannel(Channel):
|
|||||||
self._send_reply(context, reply)
|
self._send_reply(context, reply)
|
||||||
|
|
||||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
e_context = PluginManager().emit_event(
|
||||||
'channel': self, 'context': context, 'reply': reply}))
|
EventContext(
|
||||||
reply = e_context['reply']
|
Event.ON_HANDLE_CONTEXT,
|
||||||
|
{"channel": self, "context": context, "reply": reply},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
reply = e_context["reply"]
|
||||||
if not e_context.is_pass():
|
if not e_context.is_pass():
|
||||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
logger.debug(
|
||||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
"[WX] ready to handle context: type={}, content={}".format(
|
||||||
|
context.type, context.content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
context.type == ContextType.TEXT
|
||||||
|
or context.type == ContextType.IMAGE_CREATE
|
||||||
|
): # 文字和图片消息
|
||||||
reply = super().build_reply_content(context.content, context)
|
reply = super().build_reply_content(context.content, context)
|
||||||
elif context.type == ContextType.VOICE: # 语音消息
|
elif context.type == ContextType.VOICE: # 语音消息
|
||||||
cmsg = context['msg']
|
cmsg = context["msg"]
|
||||||
cmsg.prepare()
|
cmsg.prepare()
|
||||||
file_path = context.content
|
file_path = context.content
|
||||||
wav_path = os.path.splitext(file_path)[0] + '.wav'
|
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
||||||
try:
|
try:
|
||||||
any_to_wav(file_path, wav_path)
|
any_to_wav(file_path, wav_path)
|
||||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||||
wav_path = file_path
|
wav_path = file_path
|
||||||
@@ -169,7 +215,8 @@ class ChatChannel(Channel):
|
|||||||
|
|
||||||
if reply.type == ReplyType.TEXT:
|
if reply.type == ReplyType.TEXT:
|
||||||
new_context = self._compose_context(
|
new_context = self._compose_context(
|
||||||
ContextType.TEXT, reply.content, **context.kwargs)
|
ContextType.TEXT, reply.content, **context.kwargs
|
||||||
|
)
|
||||||
if new_context:
|
if new_context:
|
||||||
reply = self._generate_reply(new_context)
|
reply = self._generate_reply(new_context)
|
||||||
else:
|
else:
|
||||||
@@ -177,18 +224,21 @@ class ChatChannel(Channel):
|
|||||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
logger.error("[WX] unknown context type: {}".format(context.type))
|
||||||
return
|
return
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||||
if reply and reply.type:
|
if reply and reply.type:
|
||||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
|
e_context = PluginManager().emit_event(
|
||||||
'channel': self, 'context': context, 'reply': reply}))
|
EventContext(
|
||||||
reply = e_context['reply']
|
Event.ON_DECORATE_REPLY,
|
||||||
desire_rtype = context.get('desire_rtype')
|
{"channel": self, "context": context, "reply": reply},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
reply = e_context["reply"]
|
||||||
|
desire_rtype = context.get("desire_rtype")
|
||||||
if not e_context.is_pass() and reply and reply.type:
|
if not e_context.is_pass() and reply and reply.type:
|
||||||
|
|
||||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||||
reply.type = ReplyType.ERROR
|
reply.type = ReplyType.ERROR
|
||||||
@@ -196,59 +246,91 @@ class ChatChannel(Channel):
|
|||||||
|
|
||||||
if reply.type == ReplyType.TEXT:
|
if reply.type == ReplyType.TEXT:
|
||||||
reply_text = reply.content
|
reply_text = reply.content
|
||||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
if (
|
||||||
|
desire_rtype == ReplyType.VOICE
|
||||||
|
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||||
|
):
|
||||||
reply = super().build_text_to_voice(reply.content)
|
reply = super().build_text_to_voice(reply.content)
|
||||||
return self._decorate_reply(context, reply)
|
return self._decorate_reply(context, reply)
|
||||||
if context.get("isgroup", False):
|
if context.get("isgroup", False):
|
||||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
reply_text = (
|
||||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
"@"
|
||||||
|
+ context["msg"].actual_user_nickname
|
||||||
|
+ " "
|
||||||
|
+ reply_text.strip()
|
||||||
|
)
|
||||||
|
reply_text = (
|
||||||
|
conf().get("group_chat_reply_prefix", "") + reply_text
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
|
reply_text = (
|
||||||
|
conf().get("single_chat_reply_prefix", "") + reply_text
|
||||||
|
)
|
||||||
reply.content = reply_text
|
reply.content = reply_text
|
||||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||||
reply.content = "["+str(reply.type)+"]\n" + reply.content
|
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
elif (
|
||||||
|
reply.type == ReplyType.IMAGE_URL
|
||||||
|
or reply.type == ReplyType.VOICE
|
||||||
|
or reply.type == ReplyType.IMAGE
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||||
return
|
return
|
||||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
if (
|
||||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
|
desire_rtype
|
||||||
|
and desire_rtype != reply.type
|
||||||
|
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"[WX] desire_rtype: {}, but reply type: {}".format(
|
||||||
|
context.get("desire_rtype"), reply.type
|
||||||
|
)
|
||||||
|
)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def _send_reply(self, context: Context, reply: Reply):
|
def _send_reply(self, context: Context, reply: Reply):
|
||||||
if reply and reply.type:
|
if reply and reply.type:
|
||||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
|
e_context = PluginManager().emit_event(
|
||||||
'channel': self, 'context': context, 'reply': reply}))
|
EventContext(
|
||||||
reply = e_context['reply']
|
Event.ON_SEND_REPLY,
|
||||||
|
{"channel": self, "context": context, "reply": reply},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
reply = e_context["reply"]
|
||||||
if not e_context.is_pass() and reply and reply.type:
|
if not e_context.is_pass() and reply and reply.type:
|
||||||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
|
logger.debug(
|
||||||
|
"[WX] ready to send reply: {}, context: {}".format(reply, context)
|
||||||
|
)
|
||||||
self._send(reply, context)
|
self._send(reply, context)
|
||||||
|
|
||||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0):
|
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||||
try:
|
try:
|
||||||
self.send(reply, context)
|
self.send(reply, context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('[WX] sendMsg error: {}'.format(str(e)))
|
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||||
if isinstance(e, NotImplementedError):
|
if isinstance(e, NotImplementedError):
|
||||||
return
|
return
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
if retry_cnt < 2:
|
if retry_cnt < 2:
|
||||||
time.sleep(3+3*retry_cnt)
|
time.sleep(3 + 3 * retry_cnt)
|
||||||
self._send(reply, context, retry_cnt+1)
|
self._send(reply, context, retry_cnt + 1)
|
||||||
|
|
||||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
|
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
||||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||||
|
|
||||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||||
logger.exception("Worker return exception: {}".format(exception))
|
logger.exception("Worker return exception: {}".format(exception))
|
||||||
|
|
||||||
def _thread_pool_callback(self, session_id, **kwargs):
|
def _thread_pool_callback(self, session_id, **kwargs):
|
||||||
def func(worker:Future):
|
def func(worker: Future):
|
||||||
try:
|
try:
|
||||||
worker_exception = worker.exception()
|
worker_exception = worker.exception()
|
||||||
if worker_exception:
|
if worker_exception:
|
||||||
self._fail_callback(session_id, exception = worker_exception, **kwargs)
|
self._fail_callback(
|
||||||
|
session_id, exception=worker_exception, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._success_callback(session_id, **kwargs)
|
self._success_callback(session_id, **kwargs)
|
||||||
except CancelledError as e:
|
except CancelledError as e:
|
||||||
@@ -257,15 +339,19 @@ class ChatChannel(Channel):
|
|||||||
logger.exception("Worker raise exception: {}".format(e))
|
logger.exception("Worker raise exception: {}".format(e))
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.sessions[session_id][1].release()
|
self.sessions[session_id][1].release()
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
def produce(self, context: Context):
|
def produce(self, context: Context):
|
||||||
session_id = context['session_id']
|
session_id = context["session_id"]
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if session_id not in self.sessions:
|
if session_id not in self.sessions:
|
||||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
|
self.sessions[session_id] = [
|
||||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
Dequeue(),
|
||||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||||
|
]
|
||||||
|
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||||
|
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||||
else:
|
else:
|
||||||
self.sessions[session_id][0].put(context)
|
self.sessions[session_id][0].put(context)
|
||||||
|
|
||||||
@@ -276,44 +362,58 @@ class ChatChannel(Channel):
|
|||||||
session_ids = list(self.sessions.keys())
|
session_ids = list(self.sessions.keys())
|
||||||
for session_id in session_ids:
|
for session_id in session_ids:
|
||||||
context_queue, semaphore = self.sessions[session_id]
|
context_queue, semaphore = self.sessions[session_id]
|
||||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||||
if not context_queue.empty():
|
if not context_queue.empty():
|
||||||
context = context_queue.get()
|
context = context_queue.get()
|
||||||
logger.debug("[WX] consume context: {}".format(context))
|
logger.debug("[WX] consume context: {}".format(context))
|
||||||
future:Future = self.handler_pool.submit(self._handle, context)
|
future: Future = self.handler_pool.submit(
|
||||||
future.add_done_callback(self._thread_pool_callback(session_id, context = context))
|
self._handle, context
|
||||||
|
)
|
||||||
|
future.add_done_callback(
|
||||||
|
self._thread_pool_callback(session_id, context=context)
|
||||||
|
)
|
||||||
if session_id not in self.futures:
|
if session_id not in self.futures:
|
||||||
self.futures[session_id] = []
|
self.futures[session_id] = []
|
||||||
self.futures[session_id].append(future)
|
self.futures[session_id].append(future)
|
||||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
elif (
|
||||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
semaphore._initial_value == semaphore._value + 1
|
||||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||||
|
self.futures[session_id] = [
|
||||||
|
t for t in self.futures[session_id] if not t.done()
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
len(self.futures[session_id]) == 0
|
||||||
|
), "thread pool error"
|
||||||
del self.sessions[session_id]
|
del self.sessions[session_id]
|
||||||
else:
|
else:
|
||||||
semaphore.release()
|
semaphore.release()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||||
def cancel_session(self, session_id):
|
def cancel_session(self, session_id):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if session_id in self.sessions:
|
if session_id in self.sessions:
|
||||||
for future in self.futures[session_id]:
|
for future in self.futures[session_id]:
|
||||||
future.cancel()
|
future.cancel()
|
||||||
cnt = self.sessions[session_id][0].qsize()
|
cnt = self.sessions[session_id][0].qsize()
|
||||||
if cnt>0:
|
if cnt > 0:
|
||||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
logger.info(
|
||||||
|
"Cancel {} messages in session {}".format(cnt, session_id)
|
||||||
|
)
|
||||||
self.sessions[session_id][0] = Dequeue()
|
self.sessions[session_id][0] = Dequeue()
|
||||||
|
|
||||||
def cancel_all_session(self):
|
def cancel_all_session(self):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
for session_id in self.sessions:
|
for session_id in self.sessions:
|
||||||
for future in self.futures[session_id]:
|
for future in self.futures[session_id]:
|
||||||
future.cancel()
|
future.cancel()
|
||||||
cnt = self.sessions[session_id][0].qsize()
|
cnt = self.sessions[session_id][0].qsize()
|
||||||
if cnt>0:
|
if cnt > 0:
|
||||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
logger.info(
|
||||||
|
"Cancel {} messages in session {}".format(cnt, session_id)
|
||||||
|
)
|
||||||
self.sessions[session_id][0] = Dequeue()
|
self.sessions[session_id][0] = Dequeue()
|
||||||
|
|
||||||
|
|
||||||
def check_prefix(content, prefix_list):
|
def check_prefix(content, prefix_list):
|
||||||
if not prefix_list:
|
if not prefix_list:
|
||||||
@@ -323,6 +423,7 @@ def check_prefix(content, prefix_list):
|
|||||||
return prefix
|
return prefix
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_contain(content, keyword_list):
|
def check_contain(content, keyword_list):
|
||||||
if not keyword_list:
|
if not keyword_list:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
|
"""
|
||||||
"""
|
|
||||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||||
|
|
||||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||||
@@ -20,7 +19,7 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id
|
|||||||
other_user_nickname: 同上
|
other_user_nickname: 同上
|
||||||
|
|
||||||
is_group: 是否是群消息 (群聊必填)
|
is_group: 是否是群消息 (群聊必填)
|
||||||
is_at: 是否被at
|
is_at: 是否被at
|
||||||
|
|
||||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||||
actual_user_id: 实际发送者id (群聊必填)
|
actual_user_id: 实际发送者id (群聊必填)
|
||||||
@@ -34,20 +33,22 @@ _prepared: 是否已经调用过准备函数
|
|||||||
_rawmsg: 原始消息对象
|
_rawmsg: 原始消息对象
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(object):
|
class ChatMessage(object):
|
||||||
msg_id = None
|
msg_id = None
|
||||||
create_time = None
|
create_time = None
|
||||||
|
|
||||||
ctype = None
|
ctype = None
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
from_user_id = None
|
from_user_id = None
|
||||||
from_user_nickname = None
|
from_user_nickname = None
|
||||||
to_user_id = None
|
to_user_id = None
|
||||||
to_user_nickname = None
|
to_user_nickname = None
|
||||||
other_user_id = None
|
other_user_id = None
|
||||||
other_user_nickname = None
|
other_user_nickname = None
|
||||||
|
|
||||||
is_group = False
|
is_group = False
|
||||||
is_at = False
|
is_at = False
|
||||||
actual_user_id = None
|
actual_user_id = None
|
||||||
@@ -57,8 +58,7 @@ class ChatMessage(object):
|
|||||||
_prepared = False
|
_prepared = False
|
||||||
_rawmsg = None
|
_rawmsg = None
|
||||||
|
|
||||||
|
def __init__(self, _rawmsg):
|
||||||
def __init__(self,_rawmsg):
|
|
||||||
self._rawmsg = _rawmsg
|
self._rawmsg = _rawmsg
|
||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
@@ -67,7 +67,7 @@ class ChatMessage(object):
|
|||||||
self._prepare_fn()
|
self._prepare_fn()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
|
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
|
||||||
self.msg_id,
|
self.msg_id,
|
||||||
self.create_time,
|
self.create_time,
|
||||||
self.ctype,
|
self.ctype,
|
||||||
@@ -82,4 +82,4 @@ class ChatMessage(object):
|
|||||||
self.is_at,
|
self.is_at,
|
||||||
self.actual_user_id,
|
self.actual_user_id,
|
||||||
self.actual_user_nickname,
|
self.actual_user_nickname,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,23 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from channel.chat_channel import ChatChannel, check_prefix
|
from channel.chat_channel import ChatChannel, check_prefix
|
||||||
from channel.chat_message import ChatMessage
|
from channel.chat_message import ChatMessage
|
||||||
import sys
|
|
||||||
|
|
||||||
from config import conf
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
class TerminalMessage(ChatMessage):
|
class TerminalMessage(ChatMessage):
|
||||||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
msg_id,
|
||||||
|
content,
|
||||||
|
ctype=ContextType.TEXT,
|
||||||
|
from_user_id="User",
|
||||||
|
to_user_id="Chatgpt",
|
||||||
|
other_user_id="Chatgpt",
|
||||||
|
):
|
||||||
self.msg_id = msg_id
|
self.msg_id = msg_id
|
||||||
self.ctype = ctype
|
self.ctype = ctype
|
||||||
self.content = content
|
self.content = content
|
||||||
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
|
|||||||
self.to_user_id = to_user_id
|
self.to_user_id = to_user_id
|
||||||
self.other_user_id = other_user_id
|
self.other_user_id = other_user_id
|
||||||
|
|
||||||
|
|
||||||
class TerminalChannel(ChatChannel):
|
class TerminalChannel(ChatChannel):
|
||||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||||
|
|
||||||
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
|
|||||||
print("\nBot:")
|
print("\nBot:")
|
||||||
if reply.type == ReplyType.IMAGE:
|
if reply.type == ReplyType.IMAGE:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
image_storage = reply.content
|
image_storage = reply.content
|
||||||
image_storage.seek(0)
|
image_storage.seek(0)
|
||||||
img = Image.open(image_storage)
|
img = Image.open(image_storage)
|
||||||
print("<IMAGE>")
|
print("<IMAGE>")
|
||||||
img.show()
|
img.show()
|
||||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||||
|
import io
|
||||||
|
|
||||||
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests,io
|
|
||||||
img_url = reply.content
|
img_url = reply.content
|
||||||
pic_res = requests.get(img_url, stream=True)
|
pic_res = requests.get(img_url, stream=True)
|
||||||
image_storage = io.BytesIO()
|
image_storage = io.BytesIO()
|
||||||
@@ -59,11 +73,13 @@ class TerminalChannel(ChatChannel):
|
|||||||
print("\nExiting...")
|
print("\nExiting...")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
msg_id += 1
|
msg_id += 1
|
||||||
trigger_prefixs = conf().get("single_chat_prefix",[""])
|
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||||
if check_prefix(prompt, trigger_prefixs) is None:
|
if check_prefix(prompt, trigger_prefixs) is None:
|
||||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||||
|
|
||||||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
|
context = self._compose_context(
|
||||||
|
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
|
||||||
|
)
|
||||||
if context:
|
if context:
|
||||||
self.produce(context)
|
self.produce(context)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,40 +4,45 @@
|
|||||||
wechat channel
|
wechat channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import requests
|
|
||||||
import io
|
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from bridge.context import *
|
||||||
|
from bridge.reply import *
|
||||||
from channel.chat_channel import ChatChannel
|
from channel.chat_channel import ChatChannel
|
||||||
from channel.wechat.wechat_message import *
|
from channel.wechat.wechat_message import *
|
||||||
from common.singleton import singleton
|
from common.expired_dict import ExpiredDict
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from common.singleton import singleton
|
||||||
|
from common.time_check import time_checker
|
||||||
|
from config import conf
|
||||||
from lib import itchat
|
from lib import itchat
|
||||||
from lib.itchat.content import *
|
from lib.itchat.content import *
|
||||||
from bridge.reply import *
|
|
||||||
from bridge.context import *
|
|
||||||
from config import conf
|
|
||||||
from common.time_check import time_checker
|
|
||||||
from common.expired_dict import ExpiredDict
|
|
||||||
from plugins import *
|
from plugins import *
|
||||||
|
|
||||||
@itchat.msg_register([TEXT,VOICE,PICTURE])
|
|
||||||
|
@itchat.msg_register([TEXT, VOICE, PICTURE])
|
||||||
def handler_single_msg(msg):
|
def handler_single_msg(msg):
|
||||||
# logger.debug("handler_single_msg: {}".format(msg))
|
# logger.debug("handler_single_msg: {}".format(msg))
|
||||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
|
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
|
||||||
return None
|
return None
|
||||||
WechatChannel().handle_single(WeChatMessage(msg))
|
WechatChannel().handle_single(WeChatMessage(msg))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True)
|
|
||||||
|
@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True)
|
||||||
def handler_group_msg(msg):
|
def handler_group_msg(msg):
|
||||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
|
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
|
||||||
return None
|
return None
|
||||||
WechatChannel().handle_group(WeChatMessage(msg,True))
|
WechatChannel().handle_group(WeChatMessage(msg, True))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _check(func):
|
def _check(func):
|
||||||
def wrapper(self, cmsg: ChatMessage):
|
def wrapper(self, cmsg: ChatMessage):
|
||||||
msgId = cmsg.msg_id
|
msgId = cmsg.msg_id
|
||||||
@@ -45,21 +50,27 @@ def _check(func):
|
|||||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||||
return
|
return
|
||||||
self.receivedMsgs[msgId] = cmsg
|
self.receivedMsgs[msgId] = cmsg
|
||||||
create_time = cmsg.create_time # 消息时间戳
|
create_time = cmsg.create_time # 消息时间戳
|
||||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
if (
|
||||||
|
conf().get("hot_reload") == True
|
||||||
|
and int(create_time) < int(time.time()) - 60
|
||||||
|
): # 跳过1分钟前的历史消息
|
||||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||||
return
|
return
|
||||||
return func(self, cmsg)
|
return func(self, cmsg)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
#可用的二维码生成接口
|
|
||||||
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
# 可用的二维码生成接口
|
||||||
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
||||||
def qrCallback(uuid,status,qrcode):
|
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
||||||
|
def qrCallback(uuid, status, qrcode):
|
||||||
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
||||||
if status == '0':
|
if status == "0":
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
img = Image.open(io.BytesIO(qrcode))
|
img = Image.open(io.BytesIO(qrcode))
|
||||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||||
_thread.setDaemon(True)
|
_thread.setDaemon(True)
|
||||||
@@ -68,35 +79,43 @@ def qrCallback(uuid,status,qrcode):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
import qrcode
|
import qrcode
|
||||||
|
|
||||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||||
|
|
||||||
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||||
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
qr_api2 = (
|
||||||
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
|
||||||
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
url
|
||||||
|
)
|
||||||
|
)
|
||||||
|
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||||
|
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
|
||||||
|
url
|
||||||
|
)
|
||||||
print("You can also scan QRCode in any website below:")
|
print("You can also scan QRCode in any website below:")
|
||||||
print(qr_api3)
|
print(qr_api3)
|
||||||
print(qr_api4)
|
print(qr_api4)
|
||||||
print(qr_api2)
|
print(qr_api2)
|
||||||
print(qr_api1)
|
print(qr_api1)
|
||||||
|
|
||||||
qr = qrcode.QRCode(border=1)
|
qr = qrcode.QRCode(border=1)
|
||||||
qr.add_data(url)
|
qr.add_data(url)
|
||||||
qr.make(fit=True)
|
qr.make(fit=True)
|
||||||
qr.print_ascii(invert=True)
|
qr.print_ascii(invert=True)
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class WechatChannel(ChatChannel):
|
class WechatChannel(ChatChannel):
|
||||||
NOT_SUPPORT_REPLYTYPE = []
|
NOT_SUPPORT_REPLYTYPE = []
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.receivedMsgs = ExpiredDict(60*60*24)
|
self.receivedMsgs = ExpiredDict(60 * 60 * 24)
|
||||||
|
|
||||||
def startup(self):
|
def startup(self):
|
||||||
|
|
||||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||||
# login by scan QRCode
|
# login by scan QRCode
|
||||||
hotReload = conf().get('hot_reload', False)
|
hotReload = conf().get("hot_reload", False)
|
||||||
try:
|
try:
|
||||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -104,12 +123,18 @@ class WechatChannel(ChatChannel):
|
|||||||
logger.error("Hot reload failed, try to login without hot reload")
|
logger.error("Hot reload failed, try to login without hot reload")
|
||||||
itchat.logout()
|
itchat.logout()
|
||||||
os.remove("itchat.pkl")
|
os.remove("itchat.pkl")
|
||||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
itchat.auto_login(
|
||||||
|
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
self.user_id = itchat.instance.storageClass.userName
|
self.user_id = itchat.instance.storageClass.userName
|
||||||
self.name = itchat.instance.storageClass.nickName
|
self.name = itchat.instance.storageClass.nickName
|
||||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
logger.info(
|
||||||
|
"Wechat login success, user_id: {}, nickname: {}".format(
|
||||||
|
self.user_id, self.name
|
||||||
|
)
|
||||||
|
)
|
||||||
# start message listener
|
# start message listener
|
||||||
itchat.run()
|
itchat.run()
|
||||||
|
|
||||||
@@ -127,24 +152,30 @@ class WechatChannel(ChatChannel):
|
|||||||
|
|
||||||
@time_checker
|
@time_checker
|
||||||
@_check
|
@_check
|
||||||
def handle_single(self, cmsg : ChatMessage):
|
def handle_single(self, cmsg: ChatMessage):
|
||||||
if cmsg.ctype == ContextType.VOICE:
|
if cmsg.ctype == ContextType.VOICE:
|
||||||
if conf().get('speech_recognition') != True:
|
if conf().get("speech_recognition") != True:
|
||||||
return
|
return
|
||||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||||
elif cmsg.ctype == ContextType.IMAGE:
|
elif cmsg.ctype == ContextType.IMAGE:
|
||||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||||
else:
|
else:
|
||||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
logger.debug(
|
||||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
"[WX]receive text msg: {}, cmsg={}".format(
|
||||||
|
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
|
||||||
|
)
|
||||||
|
)
|
||||||
|
context = self._compose_context(
|
||||||
|
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
|
||||||
|
)
|
||||||
if context:
|
if context:
|
||||||
self.produce(context)
|
self.produce(context)
|
||||||
|
|
||||||
@time_checker
|
@time_checker
|
||||||
@_check
|
@_check
|
||||||
def handle_group(self, cmsg : ChatMessage):
|
def handle_group(self, cmsg: ChatMessage):
|
||||||
if cmsg.ctype == ContextType.VOICE:
|
if cmsg.ctype == ContextType.VOICE:
|
||||||
if conf().get('speech_recognition') != True:
|
if conf().get("speech_recognition") != True:
|
||||||
return
|
return
|
||||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||||
elif cmsg.ctype == ContextType.IMAGE:
|
elif cmsg.ctype == ContextType.IMAGE:
|
||||||
@@ -152,23 +183,25 @@ class WechatChannel(ChatChannel):
|
|||||||
else:
|
else:
|
||||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||||
pass
|
pass
|
||||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
context = self._compose_context(
|
||||||
|
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
|
||||||
|
)
|
||||||
if context:
|
if context:
|
||||||
self.produce(context)
|
self.produce(context)
|
||||||
|
|
||||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||||
def send(self, reply: Reply, context: Context):
|
def send(self, reply: Reply, context: Context):
|
||||||
receiver = context["receiver"]
|
receiver = context["receiver"]
|
||||||
if reply.type == ReplyType.TEXT:
|
if reply.type == ReplyType.TEXT:
|
||||||
itchat.send(reply.content, toUserName=receiver)
|
itchat.send(reply.content, toUserName=receiver)
|
||||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||||
itchat.send(reply.content, toUserName=receiver)
|
itchat.send(reply.content, toUserName=receiver)
|
||||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.VOICE:
|
elif reply.type == ReplyType.VOICE:
|
||||||
itchat.send_file(reply.content, toUserName=receiver)
|
itchat.send_file(reply.content, toUserName=receiver)
|
||||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
|
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||||
img_url = reply.content
|
img_url = reply.content
|
||||||
pic_res = requests.get(img_url, stream=True)
|
pic_res = requests.get(img_url, stream=True)
|
||||||
image_storage = io.BytesIO()
|
image_storage = io.BytesIO()
|
||||||
@@ -176,9 +209,9 @@ class WechatChannel(ChatChannel):
|
|||||||
image_storage.write(block)
|
image_storage.write(block)
|
||||||
image_storage.seek(0)
|
image_storage.seek(0)
|
||||||
itchat.send_image(image_storage, toUserName=receiver)
|
itchat.send_image(image_storage, toUserName=receiver)
|
||||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
|
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||||
image_storage = reply.content
|
image_storage = reply.content
|
||||||
image_storage.seek(0)
|
image_storage.seek(0)
|
||||||
itchat.send_image(image_storage, toUserName=receiver)
|
itchat.send_image(image_storage, toUserName=receiver)
|
||||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||||
|
|||||||
@@ -1,54 +1,54 @@
|
|||||||
|
|
||||||
|
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from channel.chat_message import ChatMessage
|
from channel.chat_message import ChatMessage
|
||||||
from common.tmp_dir import TmpDir
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from lib.itchat.content import *
|
from common.tmp_dir import TmpDir
|
||||||
from lib import itchat
|
from lib import itchat
|
||||||
|
from lib.itchat.content import *
|
||||||
|
|
||||||
|
|
||||||
class WeChatMessage(ChatMessage):
|
class WeChatMessage(ChatMessage):
|
||||||
|
|
||||||
def __init__(self, itchat_msg, is_group=False):
|
def __init__(self, itchat_msg, is_group=False):
|
||||||
super().__init__( itchat_msg)
|
super().__init__(itchat_msg)
|
||||||
self.msg_id = itchat_msg['MsgId']
|
self.msg_id = itchat_msg["MsgId"]
|
||||||
self.create_time = itchat_msg['CreateTime']
|
self.create_time = itchat_msg["CreateTime"]
|
||||||
self.is_group = is_group
|
self.is_group = is_group
|
||||||
|
|
||||||
if itchat_msg['Type'] == TEXT:
|
if itchat_msg["Type"] == TEXT:
|
||||||
self.ctype = ContextType.TEXT
|
self.ctype = ContextType.TEXT
|
||||||
self.content = itchat_msg['Text']
|
self.content = itchat_msg["Text"]
|
||||||
elif itchat_msg['Type'] == VOICE:
|
elif itchat_msg["Type"] == VOICE:
|
||||||
self.ctype = ContextType.VOICE
|
self.ctype = ContextType.VOICE
|
||||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
|
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||||
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3:
|
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
|
||||||
self.ctype = ContextType.IMAGE
|
self.ctype = ContextType.IMAGE
|
||||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
|
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))
|
raise NotImplementedError(
|
||||||
|
"Unsupported message type: {}".format(itchat_msg["Type"])
|
||||||
self.from_user_id = itchat_msg['FromUserName']
|
)
|
||||||
self.to_user_id = itchat_msg['ToUserName']
|
|
||||||
|
self.from_user_id = itchat_msg["FromUserName"]
|
||||||
|
self.to_user_id = itchat_msg["ToUserName"]
|
||||||
|
|
||||||
user_id = itchat.instance.storageClass.userName
|
user_id = itchat.instance.storageClass.userName
|
||||||
nickname = itchat.instance.storageClass.nickName
|
nickname = itchat.instance.storageClass.nickName
|
||||||
|
|
||||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
||||||
# 以下很繁琐,一句话总结:能填的都填了。
|
# 以下很繁琐,一句话总结:能填的都填了。
|
||||||
if self.from_user_id == user_id:
|
if self.from_user_id == user_id:
|
||||||
self.from_user_nickname = nickname
|
self.from_user_nickname = nickname
|
||||||
if self.to_user_id == user_id:
|
if self.to_user_id == user_id:
|
||||||
self.to_user_nickname = nickname
|
self.to_user_nickname = nickname
|
||||||
try: # 陌生人时候, 'User'字段可能不存在
|
try: # 陌生人时候, 'User'字段可能不存在
|
||||||
self.other_user_id = itchat_msg['User']['UserName']
|
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||||
self.other_user_nickname = itchat_msg['User']['NickName']
|
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||||
if self.other_user_id == self.from_user_id:
|
if self.other_user_id == self.from_user_id:
|
||||||
self.from_user_nickname = self.other_user_nickname
|
self.from_user_nickname = self.other_user_nickname
|
||||||
if self.other_user_id == self.to_user_id:
|
if self.other_user_id == self.to_user_id:
|
||||||
self.to_user_nickname = self.other_user_nickname
|
self.to_user_nickname = self.other_user_nickname
|
||||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||||
if self.from_user_id == user_id:
|
if self.from_user_id == user_id:
|
||||||
self.other_user_id = self.to_user_id
|
self.other_user_id = self.to_user_id
|
||||||
@@ -56,6 +56,6 @@ class WeChatMessage(ChatMessage):
|
|||||||
self.other_user_id = self.from_user_id
|
self.other_user_id = self.from_user_id
|
||||||
|
|
||||||
if self.is_group:
|
if self.is_group:
|
||||||
self.is_at = itchat_msg['IsAt']
|
self.is_at = itchat_msg["IsAt"]
|
||||||
self.actual_user_id = itchat_msg['ActualUserName']
|
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||||
self.actual_user_nickname = itchat_msg['ActualNickName']
|
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||||
|
|||||||
@@ -4,104 +4,118 @@
|
|||||||
wechaty channel
|
wechaty channel
|
||||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import asyncio
|
|
||||||
from bridge.context import Context
|
from wechaty import Contact, Wechaty
|
||||||
from wechaty_puppet import FileBox
|
|
||||||
from wechaty import Wechaty, Contact
|
|
||||||
from wechaty.user import Message
|
from wechaty.user import Message
|
||||||
from bridge.reply import *
|
from wechaty_puppet import FileBox
|
||||||
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
|
from bridge.context import Context
|
||||||
|
from bridge.reply import *
|
||||||
from channel.chat_channel import ChatChannel
|
from channel.chat_channel import ChatChannel
|
||||||
from channel.wechat.wechaty_message import WechatyMessage
|
from channel.wechat.wechaty_message import WechatyMessage
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.singleton import singleton
|
from common.singleton import singleton
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from voice.audio_convert import any_to_sil
|
from voice.audio_convert import any_to_sil
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class WechatyChannel(ChatChannel):
|
class WechatyChannel(ChatChannel):
|
||||||
NOT_SUPPORT_REPLYTYPE = []
|
NOT_SUPPORT_REPLYTYPE = []
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def startup(self):
|
def startup(self):
|
||||||
config = conf()
|
config = conf()
|
||||||
token = config.get('wechaty_puppet_service_token')
|
token = config.get("wechaty_puppet_service_token")
|
||||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
|
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
|
||||||
asyncio.run(self.main())
|
asyncio.run(self.main())
|
||||||
|
|
||||||
async def main(self):
|
async def main(self):
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
#将asyncio的loop传入处理线程
|
# 将asyncio的loop传入处理线程
|
||||||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
|
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
|
||||||
self.bot = Wechaty()
|
self.bot = Wechaty()
|
||||||
self.bot.on('login', self.on_login)
|
self.bot.on("login", self.on_login)
|
||||||
self.bot.on('message', self.on_message)
|
self.bot.on("message", self.on_message)
|
||||||
await self.bot.start()
|
await self.bot.start()
|
||||||
|
|
||||||
async def on_login(self, contact: Contact):
|
async def on_login(self, contact: Contact):
|
||||||
self.user_id = contact.contact_id
|
self.user_id = contact.contact_id
|
||||||
self.name = contact.name
|
self.name = contact.name
|
||||||
logger.info('[WX] login user={}'.format(contact))
|
logger.info("[WX] login user={}".format(contact))
|
||||||
|
|
||||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||||
def send(self, reply: Reply, context: Context):
|
def send(self, reply: Reply, context: Context):
|
||||||
receiver_id = context['receiver']
|
receiver_id = context["receiver"]
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if context['isgroup']:
|
if context["isgroup"]:
|
||||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
|
receiver = asyncio.run_coroutine_threadsafe(
|
||||||
|
self.bot.Room.find(receiver_id), loop
|
||||||
|
).result()
|
||||||
else:
|
else:
|
||||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
|
receiver = asyncio.run_coroutine_threadsafe(
|
||||||
|
self.bot.Contact.find(receiver_id), loop
|
||||||
|
).result()
|
||||||
msg = None
|
msg = None
|
||||||
if reply.type == ReplyType.TEXT:
|
if reply.type == ReplyType.TEXT:
|
||||||
msg = reply.content
|
msg = reply.content
|
||||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||||
msg = reply.content
|
msg = reply.content
|
||||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.VOICE:
|
elif reply.type == ReplyType.VOICE:
|
||||||
voiceLength = None
|
voiceLength = None
|
||||||
file_path = reply.content
|
file_path = reply.content
|
||||||
sil_file = os.path.splitext(file_path)[0] + '.sil'
|
sil_file = os.path.splitext(file_path)[0] + ".sil"
|
||||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||||
if voiceLength >= 60000:
|
if voiceLength >= 60000:
|
||||||
voiceLength = 60000
|
voiceLength = 60000
|
||||||
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
|
logger.info(
|
||||||
|
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
|
||||||
|
)
|
||||||
# 发送语音
|
# 发送语音
|
||||||
t = int(time.time())
|
t = int(time.time())
|
||||||
msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
|
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
|
||||||
if voiceLength is not None:
|
if voiceLength is not None:
|
||||||
msg.metadata['voiceLength'] = voiceLength
|
msg.metadata["voiceLength"] = voiceLength
|
||||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||||
try:
|
try:
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
if sil_file != file_path:
|
if sil_file != file_path:
|
||||||
os.remove(sil_file)
|
os.remove(sil_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
|
logger.info(
|
||||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
|
||||||
|
)
|
||||||
|
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||||
img_url = reply.content
|
img_url = reply.content
|
||||||
t = int(time.time())
|
t = int(time.time())
|
||||||
msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
|
||||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
|
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||||
image_storage = reply.content
|
image_storage = reply.content
|
||||||
image_storage.seek(0)
|
image_storage.seek(0)
|
||||||
t = int(time.time())
|
t = int(time.time())
|
||||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
|
msg = FileBox.from_base64(
|
||||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
base64.b64encode(image_storage.read()), str(t) + ".png"
|
||||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
)
|
||||||
|
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||||
|
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||||
|
|
||||||
async def on_message(self, msg: Message):
|
async def on_message(self, msg: Message):
|
||||||
"""
|
"""
|
||||||
@@ -110,16 +124,16 @@ class WechatyChannel(ChatChannel):
|
|||||||
try:
|
try:
|
||||||
cmsg = await WechatyMessage(msg)
|
cmsg = await WechatyMessage(msg)
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
logger.debug('[WX] {}'.format(e))
|
logger.debug("[WX] {}".format(e))
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception('[WX] {}'.format(e))
|
logger.exception("[WX] {}".format(e))
|
||||||
return
|
return
|
||||||
logger.debug('[WX] message:{}'.format(cmsg))
|
logger.debug("[WX] message:{}".format(cmsg))
|
||||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||||
isgroup = room is not None
|
isgroup = room is not None
|
||||||
ctype = cmsg.ctype
|
ctype = cmsg.ctype
|
||||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||||
if context:
|
if context:
|
||||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
|
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
|
||||||
self.produce(context)
|
self.produce(context)
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from wechaty import MessageType
|
from wechaty import MessageType
|
||||||
|
from wechaty.user import Message
|
||||||
|
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from channel.chat_message import ChatMessage
|
from channel.chat_message import ChatMessage
|
||||||
from common.tmp_dir import TmpDir
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from wechaty.user import Message
|
from common.tmp_dir import TmpDir
|
||||||
|
|
||||||
|
|
||||||
class aobject(object):
|
class aobject(object):
|
||||||
"""Inheriting this class allows you to define an async __init__.
|
"""Inheriting this class allows you to define an async __init__.
|
||||||
|
|
||||||
So you can create objects by doing something like `await MyClass(params)`
|
So you can create objects by doing something like `await MyClass(params)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def __new__(cls, *a, **kw):
|
async def __new__(cls, *a, **kw):
|
||||||
instance = super().__new__(cls)
|
instance = super().__new__(cls)
|
||||||
await instance.__init__(*a, **kw)
|
await instance.__init__(*a, **kw)
|
||||||
@@ -19,17 +23,18 @@ class aobject(object):
|
|||||||
|
|
||||||
async def __init__(self):
|
async def __init__(self):
|
||||||
pass
|
pass
|
||||||
class WechatyMessage(ChatMessage, aobject):
|
|
||||||
|
|
||||||
|
|
||||||
|
class WechatyMessage(ChatMessage, aobject):
|
||||||
async def __init__(self, wechaty_msg: Message):
|
async def __init__(self, wechaty_msg: Message):
|
||||||
super().__init__(wechaty_msg)
|
super().__init__(wechaty_msg)
|
||||||
|
|
||||||
room = wechaty_msg.room()
|
room = wechaty_msg.room()
|
||||||
|
|
||||||
self.msg_id = wechaty_msg.message_id
|
self.msg_id = wechaty_msg.message_id
|
||||||
self.create_time = wechaty_msg.payload.timestamp
|
self.create_time = wechaty_msg.payload.timestamp
|
||||||
self.is_group = room is not None
|
self.is_group = room is not None
|
||||||
|
|
||||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||||
self.ctype = ContextType.TEXT
|
self.ctype = ContextType.TEXT
|
||||||
self.content = wechaty_msg.text()
|
self.content = wechaty_msg.text()
|
||||||
@@ -40,12 +45,17 @@ class WechatyMessage(ChatMessage, aobject):
|
|||||||
|
|
||||||
def func():
|
def func():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result()
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
voice_file.to_file(self.content), loop
|
||||||
|
).result()
|
||||||
|
|
||||||
self._prepare_fn = func
|
self._prepare_fn = func
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
raise NotImplementedError(
|
||||||
|
"Unsupported message type: {}".format(wechaty_msg.type())
|
||||||
|
)
|
||||||
|
|
||||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||||
self.from_user_id = from_contact.contact_id
|
self.from_user_id = from_contact.contact_id
|
||||||
self.from_user_nickname = from_contact.name
|
self.from_user_nickname = from_contact.name
|
||||||
@@ -54,7 +64,7 @@ class WechatyMessage(ChatMessage, aobject):
|
|||||||
# wecahty: from是消息实际发送者, to:所在群
|
# wecahty: from是消息实际发送者, to:所在群
|
||||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
||||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
||||||
|
|
||||||
if self.is_group:
|
if self.is_group:
|
||||||
self.to_user_id = room.room_id
|
self.to_user_id = room.room_id
|
||||||
self.to_user_nickname = await room.topic()
|
self.to_user_nickname = await room.topic()
|
||||||
@@ -63,22 +73,22 @@ class WechatyMessage(ChatMessage, aobject):
|
|||||||
self.to_user_id = to_contact.contact_id
|
self.to_user_id = to_contact.contact_id
|
||||||
self.to_user_nickname = to_contact.name
|
self.to_user_nickname = to_contact.name
|
||||||
|
|
||||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
if (
|
||||||
|
self.is_group or wechaty_msg.is_self()
|
||||||
|
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||||
self.other_user_id = self.to_user_id
|
self.other_user_id = self.to_user_id
|
||||||
self.other_user_nickname = self.to_user_nickname
|
self.other_user_nickname = self.to_user_nickname
|
||||||
else:
|
else:
|
||||||
self.other_user_id = self.from_user_id
|
self.other_user_id = self.from_user_id
|
||||||
self.other_user_nickname = self.from_user_nickname
|
self.other_user_nickname = self.from_user_nickname
|
||||||
|
|
||||||
|
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
||||||
|
|
||||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
|
||||||
self.is_at = await wechaty_msg.mention_self()
|
self.is_at = await wechaty_msg.mention_self()
|
||||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||||
name = wechaty_msg.wechaty.user_self().name
|
name = wechaty_msg.wechaty.user_self().name
|
||||||
pattern = f'@{name}(\u2005|\u0020)'
|
pattern = f"@{name}(\u2005|\u0020)"
|
||||||
if re.search(pattern,self.content):
|
if re.search(pattern, self.content):
|
||||||
logger.debug(f'wechaty message {self.msg_id} include at')
|
logger.debug(f"wechaty message {self.msg_id} include at")
|
||||||
self.is_at = True
|
self.is_at = True
|
||||||
|
|
||||||
self.actual_user_id = self.from_user_id
|
self.actual_user_id = self.from_user_id
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ pip3 install web.py
|
|||||||
|
|
||||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
||||||
```
|
```
|
||||||
"channel_type": "wechatmp",
|
"channel_type": "wechatmp",
|
||||||
"wechatmp_token": "Token", # 微信公众平台的Token
|
"wechatmp_token": "Token", # 微信公众平台的Token
|
||||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||||
```
|
```
|
||||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径):
|
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径):
|
||||||
```
|
```
|
||||||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
||||||
|
|||||||
@@ -1,46 +1,66 @@
|
|||||||
import web
|
|
||||||
import time
|
import time
|
||||||
import channel.wechatmp.reply as reply
|
|
||||||
|
import web
|
||||||
|
|
||||||
import channel.wechatmp.receive as receive
|
import channel.wechatmp.receive as receive
|
||||||
from config import conf
|
import channel.wechatmp.reply as reply
|
||||||
from common.log import logger
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
from channel.wechatmp.common import *
|
from channel.wechatmp.common import *
|
||||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
# This class is instantiated once per query
|
# This class is instantiated once per query
|
||||||
class Query():
|
class Query:
|
||||||
|
|
||||||
def GET(self):
|
def GET(self):
|
||||||
return verify_server(web.input())
|
return verify_server(web.input())
|
||||||
|
|
||||||
def POST(self):
|
def POST(self):
|
||||||
# Make sure to return the instance that first created, @singleton will do that.
|
# Make sure to return the instance that first created, @singleton will do that.
|
||||||
channel = WechatMPChannel()
|
channel = WechatMPChannel()
|
||||||
try:
|
try:
|
||||||
webData = web.data()
|
webData = web.data()
|
||||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||||
wechatmp_msg = receive.parse_xml(webData)
|
wechatmp_msg = receive.parse_xml(webData)
|
||||||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
|
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
|
||||||
from_user = wechatmp_msg.from_user_id
|
from_user = wechatmp_msg.from_user_id
|
||||||
message = wechatmp_msg.content.decode("utf-8")
|
message = wechatmp_msg.content.decode("utf-8")
|
||||||
message_id = wechatmp_msg.msg_id
|
message_id = wechatmp_msg.msg_id
|
||||||
|
|
||||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
logger.info(
|
||||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
|
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||||
|
web.ctx.env.get("REMOTE_ADDR"),
|
||||||
|
web.ctx.env.get("REMOTE_PORT"),
|
||||||
|
from_user,
|
||||||
|
message_id,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
context = channel._compose_context(
|
||||||
|
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
|
||||||
|
)
|
||||||
if context:
|
if context:
|
||||||
# set private openai_api_key
|
# set private openai_api_key
|
||||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||||
user_data = conf().get_user_data(from_user)
|
user_data = conf().get_user_data(from_user)
|
||||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
context["openai_api_key"] = user_data.get(
|
||||||
|
"openai_api_key"
|
||||||
|
) # None or user openai_api_key
|
||||||
channel.produce(context)
|
channel.produce(context)
|
||||||
# The reply will be sent by channel.send() in another thread
|
# The reply will be sent by channel.send() in another thread
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
elif wechatmp_msg.msg_type == 'event':
|
elif wechatmp_msg.msg_type == "event":
|
||||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id))
|
logger.info(
|
||||||
|
"[wechatmp] Event {} from {}".format(
|
||||||
|
wechatmp_msg.Event, wechatmp_msg.from_user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
content = subscribe_msg()
|
content = subscribe_msg()
|
||||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
replyMsg = reply.TextMsg(
|
||||||
|
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||||
|
)
|
||||||
return replyMsg.send()
|
return replyMsg.send()
|
||||||
else:
|
else:
|
||||||
logger.info("暂且不处理")
|
logger.info("暂且不处理")
|
||||||
@@ -48,4 +68,3 @@ class Query():
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(exc)
|
logger.exception(exc)
|
||||||
return exc
|
return exc
|
||||||
|
|
||||||
|
|||||||
@@ -1,81 +1,117 @@
|
|||||||
import web
|
|
||||||
import time
|
import time
|
||||||
import channel.wechatmp.reply as reply
|
|
||||||
|
import web
|
||||||
|
|
||||||
import channel.wechatmp.receive as receive
|
import channel.wechatmp.receive as receive
|
||||||
from config import conf
|
import channel.wechatmp.reply as reply
|
||||||
from common.log import logger
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
from channel.wechatmp.common import *
|
from channel.wechatmp.common import *
|
||||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
# This class is instantiated once per query
|
# This class is instantiated once per query
|
||||||
class Query():
|
class Query:
|
||||||
|
|
||||||
def GET(self):
|
def GET(self):
|
||||||
return verify_server(web.input())
|
return verify_server(web.input())
|
||||||
|
|
||||||
def POST(self):
|
def POST(self):
|
||||||
# Make sure to return the instance that first created, @singleton will do that.
|
# Make sure to return the instance that first created, @singleton will do that.
|
||||||
channel = WechatMPChannel()
|
channel = WechatMPChannel()
|
||||||
try:
|
try:
|
||||||
query_time = time.time()
|
query_time = time.time()
|
||||||
webData = web.data()
|
webData = web.data()
|
||||||
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||||
wechatmp_msg = receive.parse_xml(webData)
|
wechatmp_msg = receive.parse_xml(webData)
|
||||||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
|
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
|
||||||
from_user = wechatmp_msg.from_user_id
|
from_user = wechatmp_msg.from_user_id
|
||||||
to_user = wechatmp_msg.to_user_id
|
to_user = wechatmp_msg.to_user_id
|
||||||
message = wechatmp_msg.content.decode("utf-8")
|
message = wechatmp_msg.content.decode("utf-8")
|
||||||
message_id = wechatmp_msg.msg_id
|
message_id = wechatmp_msg.msg_id
|
||||||
|
|
||||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
logger.info(
|
||||||
|
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||||
|
web.ctx.env.get("REMOTE_ADDR"),
|
||||||
|
web.ctx.env.get("REMOTE_PORT"),
|
||||||
|
from_user,
|
||||||
|
message_id,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
)
|
||||||
supported = True
|
supported = True
|
||||||
if "【收到不支持的消息类型,暂无法显示】" in message:
|
if "【收到不支持的消息类型,暂无法显示】" in message:
|
||||||
supported = False # not supported, used to refresh
|
supported = False # not supported, used to refresh
|
||||||
cache_key = from_user
|
cache_key = from_user
|
||||||
|
|
||||||
reply_text = ""
|
reply_text = ""
|
||||||
# New request
|
# New request
|
||||||
if cache_key not in channel.cache_dict and cache_key not in channel.running:
|
if (
|
||||||
|
cache_key not in channel.cache_dict
|
||||||
|
and cache_key not in channel.running
|
||||||
|
):
|
||||||
# The first query begin, reset the cache
|
# The first query begin, reset the cache
|
||||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
|
context = channel._compose_context(
|
||||||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
|
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
|
||||||
if message_id in channel.received_msgs: # received and finished
|
)
|
||||||
|
logger.debug(
|
||||||
|
"[wechatmp] context: {} {}".format(context, wechatmp_msg)
|
||||||
|
)
|
||||||
|
if message_id in channel.received_msgs: # received and finished
|
||||||
# no return because of bandwords or other reasons
|
# no return because of bandwords or other reasons
|
||||||
return "success"
|
return "success"
|
||||||
if supported and context:
|
if supported and context:
|
||||||
# set private openai_api_key
|
# set private openai_api_key
|
||||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||||
user_data = conf().get_user_data(from_user)
|
user_data = conf().get_user_data(from_user)
|
||||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
context["openai_api_key"] = user_data.get(
|
||||||
|
"openai_api_key"
|
||||||
|
) # None or user openai_api_key
|
||||||
channel.received_msgs[message_id] = wechatmp_msg
|
channel.received_msgs[message_id] = wechatmp_msg
|
||||||
channel.running.add(cache_key)
|
channel.running.add(cache_key)
|
||||||
channel.produce(context)
|
channel.produce(context)
|
||||||
else:
|
else:
|
||||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||||
if trigger_prefix or not supported:
|
if trigger_prefix or not supported:
|
||||||
if trigger_prefix:
|
if trigger_prefix:
|
||||||
content = textwrap.dedent(f"""\
|
content = textwrap.dedent(
|
||||||
|
f"""\
|
||||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||||
例如:
|
例如:
|
||||||
{trigger_prefix}你好,很高兴见到你。""")
|
{trigger_prefix}你好,很高兴见到你。"""
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content = textwrap.dedent("""\
|
content = textwrap.dedent(
|
||||||
|
"""\
|
||||||
你好,很高兴见到你。
|
你好,很高兴见到你。
|
||||||
请跟我说话吧。""")
|
请跟我说话吧。"""
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"[wechatmp] unknown error")
|
logger.error(f"[wechatmp] unknown error")
|
||||||
content = textwrap.dedent("""\
|
content = textwrap.dedent(
|
||||||
未知错误,请稍后再试""")
|
"""\
|
||||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
未知错误,请稍后再试"""
|
||||||
|
)
|
||||||
|
replyMsg = reply.TextMsg(
|
||||||
|
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||||
|
)
|
||||||
return replyMsg.send()
|
return replyMsg.send()
|
||||||
channel.query1[cache_key] = False
|
channel.query1[cache_key] = False
|
||||||
channel.query2[cache_key] = False
|
channel.query2[cache_key] = False
|
||||||
channel.query3[cache_key] = False
|
channel.query3[cache_key] = False
|
||||||
# User request again, and the answer is not ready
|
# User request again, and the answer is not ready
|
||||||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
|
elif (
|
||||||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
|
cache_key in channel.running
|
||||||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
|
and channel.query1.get(cache_key) == True
|
||||||
|
and channel.query2.get(cache_key) == True
|
||||||
|
and channel.query3.get(cache_key) == True
|
||||||
|
):
|
||||||
|
channel.query1[
|
||||||
|
cache_key
|
||||||
|
] = False # To improve waiting experience, this can be set to True.
|
||||||
|
channel.query2[
|
||||||
|
cache_key
|
||||||
|
] = False # To improve waiting experience, this can be set to True.
|
||||||
channel.query3[cache_key] = False
|
channel.query3[cache_key] = False
|
||||||
# User request again, and the answer is ready
|
# User request again, and the answer is ready
|
||||||
elif cache_key in channel.cache_dict:
|
elif cache_key in channel.cache_dict:
|
||||||
@@ -84,7 +120,9 @@ class Query():
|
|||||||
channel.query2[cache_key] = True
|
channel.query2[cache_key] = True
|
||||||
channel.query3[cache_key] = True
|
channel.query3[cache_key] = True
|
||||||
|
|
||||||
assert not (cache_key in channel.cache_dict and cache_key in channel.running)
|
assert not (
|
||||||
|
cache_key in channel.cache_dict and cache_key in channel.running
|
||||||
|
)
|
||||||
|
|
||||||
if channel.query1.get(cache_key) == False:
|
if channel.query1.get(cache_key) == False:
|
||||||
# The first query from wechat official server
|
# The first query from wechat official server
|
||||||
@@ -128,14 +166,20 @@ class Query():
|
|||||||
# Have waiting for 3x5 seconds
|
# Have waiting for 3x5 seconds
|
||||||
# return timeout message
|
# return timeout message
|
||||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
|
logger.info(
|
||||||
|
"[wechatmp] Three queries has finished For {}: {}".format(
|
||||||
|
from_user, message_id
|
||||||
|
)
|
||||||
|
)
|
||||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||||
return replyPost
|
return replyPost
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if (
|
||||||
if cache_key not in channel.cache_dict and cache_key not in channel.running:
|
cache_key not in channel.cache_dict
|
||||||
|
and cache_key not in channel.running
|
||||||
|
):
|
||||||
# no return because of bandwords or other reasons
|
# no return because of bandwords or other reasons
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
@@ -147,26 +191,42 @@ class Query():
|
|||||||
|
|
||||||
if cache_key in channel.cache_dict:
|
if cache_key in channel.cache_dict:
|
||||||
content = channel.cache_dict[cache_key]
|
content = channel.cache_dict[cache_key]
|
||||||
if len(content.encode('utf8'))<=MAX_UTF8_LEN:
|
if len(content.encode("utf8")) <= MAX_UTF8_LEN:
|
||||||
reply_text = channel.cache_dict[cache_key]
|
reply_text = channel.cache_dict[cache_key]
|
||||||
channel.cache_dict.pop(cache_key)
|
channel.cache_dict.pop(cache_key)
|
||||||
else:
|
else:
|
||||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||||
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
|
splits = split_string_by_utf8_length(
|
||||||
|
content,
|
||||||
|
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
||||||
|
max_split=1,
|
||||||
|
)
|
||||||
reply_text = splits[0] + continue_text
|
reply_text = splits[0] + continue_text
|
||||||
channel.cache_dict[cache_key] = splits[1]
|
channel.cache_dict[cache_key] = splits[1]
|
||||||
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
|
logger.info(
|
||||||
|
"[wechatmp] {}:{} Do send {}".format(
|
||||||
|
web.ctx.env.get("REMOTE_ADDR"),
|
||||||
|
web.ctx.env.get("REMOTE_PORT"),
|
||||||
|
reply_text,
|
||||||
|
)
|
||||||
|
)
|
||||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||||
return replyPost
|
return replyPost
|
||||||
|
|
||||||
elif wechatmp_msg.msg_type == 'event':
|
elif wechatmp_msg.msg_type == "event":
|
||||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id))
|
logger.info(
|
||||||
|
"[wechatmp] Event {} from {}".format(
|
||||||
|
wechatmp_msg.content, wechatmp_msg.from_user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
content = subscribe_msg()
|
content = subscribe_msg()
|
||||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
replyMsg = reply.TextMsg(
|
||||||
|
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||||
|
)
|
||||||
return replyMsg.send()
|
return replyMsg.send()
|
||||||
else:
|
else:
|
||||||
logger.info("暂且不处理")
|
logger.info("暂且不处理")
|
||||||
return "success"
|
return "success"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(exc)
|
logger.exception(exc)
|
||||||
return exc
|
return exc
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from config import conf
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
from config import conf
|
||||||
|
|
||||||
MAX_UTF8_LEN = 2048
|
MAX_UTF8_LEN = 2048
|
||||||
|
|
||||||
|
|
||||||
class WeChatAPIException(Exception):
|
class WeChatAPIException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -16,13 +18,13 @@ def verify_server(data):
|
|||||||
timestamp = data.timestamp
|
timestamp = data.timestamp
|
||||||
nonce = data.nonce
|
nonce = data.nonce
|
||||||
echostr = data.echostr
|
echostr = data.echostr
|
||||||
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
|
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
|
||||||
|
|
||||||
data_list = [token, timestamp, nonce]
|
data_list = [token, timestamp, nonce]
|
||||||
data_list.sort()
|
data_list.sort()
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
# map(sha1.update, data_list) #python2
|
# map(sha1.update, data_list) #python2
|
||||||
sha1.update("".join(data_list).encode('utf-8'))
|
sha1.update("".join(data_list).encode("utf-8"))
|
||||||
hashcode = sha1.hexdigest()
|
hashcode = sha1.hexdigest()
|
||||||
print("handle/GET func: hashcode, signature: ", hashcode, signature)
|
print("handle/GET func: hashcode, signature: ", hashcode, signature)
|
||||||
if hashcode == signature:
|
if hashcode == signature:
|
||||||
@@ -32,9 +34,11 @@ def verify_server(data):
|
|||||||
except Exception as Argument:
|
except Exception as Argument:
|
||||||
return Argument
|
return Argument
|
||||||
|
|
||||||
|
|
||||||
def subscribe_msg():
|
def subscribe_msg():
|
||||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||||
msg = textwrap.dedent(f"""\
|
msg = textwrap.dedent(
|
||||||
|
f"""\
|
||||||
感谢您的关注!
|
感谢您的关注!
|
||||||
这里是ChatGPT,可以自由对话。
|
这里是ChatGPT,可以自由对话。
|
||||||
资源有限,回复较慢,请勿着急。
|
资源有限,回复较慢,请勿着急。
|
||||||
@@ -42,22 +46,23 @@ def subscribe_msg():
|
|||||||
暂时不支持图片输入。
|
暂时不支持图片输入。
|
||||||
支持图片输出,画字开头的问题将回复图片链接。
|
支持图片输出,画字开头的问题将回复图片链接。
|
||||||
支持角色扮演和文字冒险两种定制模式对话。
|
支持角色扮演和文字冒险两种定制模式对话。
|
||||||
输入'{trigger_prefix}#帮助' 查看详细指令。""")
|
输入'{trigger_prefix}#帮助' 查看详细指令。"""
|
||||||
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def split_string_by_utf8_length(string, max_length, max_split=0):
|
def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||||
encoded = string.encode('utf-8')
|
encoded = string.encode("utf-8")
|
||||||
start, end = 0, 0
|
start, end = 0, 0
|
||||||
result = []
|
result = []
|
||||||
while end < len(encoded):
|
while end < len(encoded):
|
||||||
if max_split > 0 and len(result) >= max_split:
|
if max_split > 0 and len(result) >= max_split:
|
||||||
result.append(encoded[start:].decode('utf-8'))
|
result.append(encoded[start:].decode("utf-8"))
|
||||||
break
|
break
|
||||||
end = start + max_length
|
end = start + max_length
|
||||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
||||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
||||||
end -= 1
|
end -= 1
|
||||||
result.append(encoded[start:end].decode('utf-8'))
|
result.append(encoded[start:end].decode("utf-8"))
|
||||||
start = end
|
start = end
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-#
|
# -*- coding: utf-8 -*-#
|
||||||
# filename: receive.py
|
# filename: receive.py
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from channel.chat_message import ChatMessage
|
from channel.chat_message import ChatMessage
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
@@ -12,34 +13,35 @@ def parse_xml(web_data):
|
|||||||
xmlData = ET.fromstring(web_data)
|
xmlData = ET.fromstring(web_data)
|
||||||
return WeChatMPMessage(xmlData)
|
return WeChatMPMessage(xmlData)
|
||||||
|
|
||||||
|
|
||||||
class WeChatMPMessage(ChatMessage):
|
class WeChatMPMessage(ChatMessage):
|
||||||
def __init__(self, xmlData):
|
def __init__(self, xmlData):
|
||||||
super().__init__(xmlData)
|
super().__init__(xmlData)
|
||||||
self.to_user_id = xmlData.find('ToUserName').text
|
self.to_user_id = xmlData.find("ToUserName").text
|
||||||
self.from_user_id = xmlData.find('FromUserName').text
|
self.from_user_id = xmlData.find("FromUserName").text
|
||||||
self.create_time = xmlData.find('CreateTime').text
|
self.create_time = xmlData.find("CreateTime").text
|
||||||
self.msg_type = xmlData.find('MsgType').text
|
self.msg_type = xmlData.find("MsgType").text
|
||||||
try:
|
try:
|
||||||
self.msg_id = xmlData.find('MsgId').text
|
self.msg_id = xmlData.find("MsgId").text
|
||||||
except:
|
except:
|
||||||
self.msg_id = self.from_user_id+self.create_time
|
self.msg_id = self.from_user_id + self.create_time
|
||||||
self.is_group = False
|
self.is_group = False
|
||||||
|
|
||||||
# reply to other_user_id
|
# reply to other_user_id
|
||||||
self.other_user_id = self.from_user_id
|
self.other_user_id = self.from_user_id
|
||||||
|
|
||||||
if self.msg_type == 'text':
|
if self.msg_type == "text":
|
||||||
self.ctype = ContextType.TEXT
|
self.ctype = ContextType.TEXT
|
||||||
self.content = xmlData.find('Content').text.encode("utf-8")
|
self.content = xmlData.find("Content").text.encode("utf-8")
|
||||||
elif self.msg_type == 'voice':
|
elif self.msg_type == "voice":
|
||||||
self.ctype = ContextType.TEXT
|
self.ctype = ContextType.TEXT
|
||||||
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果
|
self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果
|
||||||
elif self.msg_type == 'image':
|
elif self.msg_type == "image":
|
||||||
# not implemented
|
# not implemented
|
||||||
self.pic_url = xmlData.find('PicUrl').text
|
self.pic_url = xmlData.find("PicUrl").text
|
||||||
self.media_id = xmlData.find('MediaId').text
|
self.media_id = xmlData.find("MediaId").text
|
||||||
elif self.msg_type == 'event':
|
elif self.msg_type == "event":
|
||||||
self.content = xmlData.find('Event').text
|
self.content = xmlData.find("Event").text
|
||||||
else: # video, shortvideo, location, link
|
else: # video, shortvideo, location, link
|
||||||
# not implemented
|
# not implemented
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# filename: reply.py
|
# filename: reply.py
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class Msg(object):
|
class Msg(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -9,13 +10,14 @@ class Msg(object):
|
|||||||
def send(self):
|
def send(self):
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
class TextMsg(Msg):
|
class TextMsg(Msg):
|
||||||
def __init__(self, toUserName, fromUserName, content):
|
def __init__(self, toUserName, fromUserName, content):
|
||||||
self.__dict = dict()
|
self.__dict = dict()
|
||||||
self.__dict['ToUserName'] = toUserName
|
self.__dict["ToUserName"] = toUserName
|
||||||
self.__dict['FromUserName'] = fromUserName
|
self.__dict["FromUserName"] = fromUserName
|
||||||
self.__dict['CreateTime'] = int(time.time())
|
self.__dict["CreateTime"] = int(time.time())
|
||||||
self.__dict['Content'] = content
|
self.__dict["Content"] = content
|
||||||
|
|
||||||
def send(self):
|
def send(self):
|
||||||
XmlForm = """
|
XmlForm = """
|
||||||
@@ -29,13 +31,14 @@ class TextMsg(Msg):
|
|||||||
"""
|
"""
|
||||||
return XmlForm.format(**self.__dict)
|
return XmlForm.format(**self.__dict)
|
||||||
|
|
||||||
|
|
||||||
class ImageMsg(Msg):
|
class ImageMsg(Msg):
|
||||||
def __init__(self, toUserName, fromUserName, mediaId):
|
def __init__(self, toUserName, fromUserName, mediaId):
|
||||||
self.__dict = dict()
|
self.__dict = dict()
|
||||||
self.__dict['ToUserName'] = toUserName
|
self.__dict["ToUserName"] = toUserName
|
||||||
self.__dict['FromUserName'] = fromUserName
|
self.__dict["FromUserName"] = fromUserName
|
||||||
self.__dict['CreateTime'] = int(time.time())
|
self.__dict["CreateTime"] = int(time.time())
|
||||||
self.__dict['MediaId'] = mediaId
|
self.__dict["MediaId"] = mediaId
|
||||||
|
|
||||||
def send(self):
|
def send(self):
|
||||||
XmlForm = """
|
XmlForm = """
|
||||||
@@ -49,4 +52,4 @@ class ImageMsg(Msg):
|
|||||||
</Image>
|
</Image>
|
||||||
</xml>
|
</xml>
|
||||||
"""
|
"""
|
||||||
return XmlForm.format(**self.__dict)
|
return XmlForm.format(**self.__dict)
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import web
|
|
||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import requests
|
|
||||||
import threading
|
import threading
|
||||||
from common.singleton import singleton
|
import time
|
||||||
from common.log import logger
|
|
||||||
from common.expired_dict import ExpiredDict
|
import requests
|
||||||
from config import conf
|
import web
|
||||||
from bridge.reply import *
|
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
|
from bridge.reply import *
|
||||||
from channel.chat_channel import ChatChannel
|
from channel.chat_channel import ChatChannel
|
||||||
from channel.wechatmp.common import *
|
from channel.wechatmp.common import *
|
||||||
|
from common.expired_dict import ExpiredDict
|
||||||
|
from common.log import logger
|
||||||
|
from common.singleton import singleton
|
||||||
|
from config import conf
|
||||||
|
|
||||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||||
# from cheroot.server import HTTPServer
|
# from cheroot.server import HTTPServer
|
||||||
@@ -20,13 +22,14 @@ from channel.wechatmp.common import *
|
|||||||
# certificate='/ssl/cert.pem',
|
# certificate='/ssl/cert.pem',
|
||||||
# private_key='/ssl/cert.key')
|
# private_key='/ssl/cert.key')
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class WechatMPChannel(ChatChannel):
|
class WechatMPChannel(ChatChannel):
|
||||||
def __init__(self, passive_reply = True):
|
def __init__(self, passive_reply=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.passive_reply = passive_reply
|
self.passive_reply = passive_reply
|
||||||
self.running = set()
|
self.running = set()
|
||||||
self.received_msgs = ExpiredDict(60*60*24)
|
self.received_msgs = ExpiredDict(60 * 60 * 24)
|
||||||
if self.passive_reply:
|
if self.passive_reply:
|
||||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||||
self.cache_dict = dict()
|
self.cache_dict = dict()
|
||||||
@@ -36,8 +39,8 @@ class WechatMPChannel(ChatChannel):
|
|||||||
else:
|
else:
|
||||||
# TODO support image
|
# TODO support image
|
||||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||||
self.app_id = conf().get('wechatmp_app_id')
|
self.app_id = conf().get("wechatmp_app_id")
|
||||||
self.app_secret = conf().get('wechatmp_app_secret')
|
self.app_secret = conf().get("wechatmp_app_secret")
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
self.access_token_expires_time = 0
|
self.access_token_expires_time = 0
|
||||||
self.access_token_lock = threading.Lock()
|
self.access_token_lock = threading.Lock()
|
||||||
@@ -45,13 +48,12 @@ class WechatMPChannel(ChatChannel):
|
|||||||
|
|
||||||
def startup(self):
|
def startup(self):
|
||||||
if self.passive_reply:
|
if self.passive_reply:
|
||||||
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query')
|
urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query")
|
||||||
else:
|
else:
|
||||||
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query')
|
urls = ("/wx", "channel.wechatmp.ServiceAccount.Query")
|
||||||
app = web.application(urls, globals(), autoreload=False)
|
app = web.application(urls, globals(), autoreload=False)
|
||||||
port = conf().get('wechatmp_port', 8080)
|
port = conf().get("wechatmp_port", 8080)
|
||||||
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))
|
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||||
|
|
||||||
|
|
||||||
def wechatmp_request(self, method, url, **kwargs):
|
def wechatmp_request(self, method, url, **kwargs):
|
||||||
r = requests.request(method=method, url=url, **kwargs)
|
r = requests.request(method=method, url=url, **kwargs)
|
||||||
@@ -63,7 +65,6 @@ class WechatMPChannel(ChatChannel):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_access_token(self):
|
def get_access_token(self):
|
||||||
|
|
||||||
# return the access_token
|
# return the access_token
|
||||||
if self.access_token:
|
if self.access_token:
|
||||||
if self.access_token_expires_time - time.time() > 60:
|
if self.access_token_expires_time - time.time() > 60:
|
||||||
@@ -76,15 +77,15 @@ class WechatMPChannel(ChatChannel):
|
|||||||
# This happens every 2 hours, so it doesn't affect the experience very much
|
# This happens every 2 hours, so it doesn't affect the experience very much
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
url="https://api.weixin.qq.com/cgi-bin/token"
|
url = "https://api.weixin.qq.com/cgi-bin/token"
|
||||||
params={
|
params = {
|
||||||
"grant_type": "client_credential",
|
"grant_type": "client_credential",
|
||||||
"appid": self.app_id,
|
"appid": self.app_id,
|
||||||
"secret": self.app_secret
|
"secret": self.app_secret,
|
||||||
}
|
}
|
||||||
data = self.wechatmp_request(method='get', url=url, params=params)
|
data = self.wechatmp_request(method="get", url=url, params=params)
|
||||||
self.access_token = data['access_token']
|
self.access_token = data["access_token"]
|
||||||
self.access_token_expires_time = int(time.time()) + data['expires_in']
|
self.access_token_expires_time = int(time.time()) + data["expires_in"]
|
||||||
logger.info("[wechatmp] access_token: {}".format(self.access_token))
|
logger.info("[wechatmp] access_token: {}".format(self.access_token))
|
||||||
self.access_token_lock.release()
|
self.access_token_lock.release()
|
||||||
else:
|
else:
|
||||||
@@ -101,29 +102,37 @@ class WechatMPChannel(ChatChannel):
|
|||||||
else:
|
else:
|
||||||
receiver = context["receiver"]
|
receiver = context["receiver"]
|
||||||
reply_text = reply.content
|
reply_text = reply.content
|
||||||
url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
|
url = "https://api.weixin.qq.com/cgi-bin/message/custom/send"
|
||||||
params = {
|
params = {"access_token": self.get_access_token()}
|
||||||
"access_token": self.get_access_token()
|
|
||||||
}
|
|
||||||
json_data = {
|
json_data = {
|
||||||
"touser": receiver,
|
"touser": receiver,
|
||||||
"msgtype": "text",
|
"msgtype": "text",
|
||||||
"text": {"content": reply_text}
|
"text": {"content": reply_text},
|
||||||
}
|
}
|
||||||
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8'))
|
self.wechatmp_request(
|
||||||
|
method="post",
|
||||||
|
url=url,
|
||||||
|
params=params,
|
||||||
|
data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
|
||||||
|
)
|
||||||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
|
logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
logger.debug(
|
||||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
|
"[wechatmp] Success to generate reply, msgId={}".format(
|
||||||
|
context["msg"].msg_id
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.passive_reply:
|
if self.passive_reply:
|
||||||
self.running.remove(session_id)
|
self.running.remove(session_id)
|
||||||
|
|
||||||
|
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
||||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
logger.exception(
|
||||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
|
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
|
||||||
|
context["msg"].msg_id, exception
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.passive_reply:
|
if self.passive_reply:
|
||||||
assert session_id not in self.cache_dict
|
assert session_id not in self.cache_dict
|
||||||
self.running.remove(session_id)
|
self.running.remove(session_id)
|
||||||
|
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
OPEN_AI = "openAI"
|
OPEN_AI = "openAI"
|
||||||
CHATGPT = "chatGPT"
|
CHATGPT = "chatGPT"
|
||||||
BAIDU = "baidu"
|
BAIDU = "baidu"
|
||||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
from queue import Full, Queue
|
from queue import Full, Queue
|
||||||
from time import monotonic as time
|
from time import monotonic as time
|
||||||
|
|
||||||
|
|
||||||
# add implementation of putleft to Queue
|
# add implementation of putleft to Queue
|
||||||
class Dequeue(Queue):
|
class Dequeue(Queue):
|
||||||
def putleft(self, item, block=True, timeout=None):
|
def putleft(self, item, block=True, timeout=None):
|
||||||
@@ -30,4 +30,4 @@ class Dequeue(Queue):
|
|||||||
return self.putleft(item, block=False)
|
return self.putleft(item, block=False)
|
||||||
|
|
||||||
def _putleft(self, item):
|
def _putleft(self, item):
|
||||||
self.queue.appendleft(item)
|
self.queue.appendleft(item)
|
||||||
|
|||||||
@@ -39,4 +39,4 @@ class ExpiredDict(dict):
|
|||||||
return [(key, self[key]) for key in self.keys()]
|
return [(key, self[key]) for key in self.keys()]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self.keys().__iter__()
|
return self.keys().__iter__()
|
||||||
|
|||||||
@@ -10,20 +10,29 @@ def _reset_logger(log):
|
|||||||
log.handlers.clear()
|
log.handlers.clear()
|
||||||
log.propagate = False
|
log.propagate = False
|
||||||
console_handle = logging.StreamHandler(sys.stdout)
|
console_handle = logging.StreamHandler(sys.stdout)
|
||||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
console_handle.setFormatter(
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
logging.Formatter(
|
||||||
file_handle = logging.FileHandler('run.log', encoding='utf-8')
|
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||||
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
)
|
||||||
|
)
|
||||||
|
file_handle = logging.FileHandler("run.log", encoding="utf-8")
|
||||||
|
file_handle.setFormatter(
|
||||||
|
logging.Formatter(
|
||||||
|
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
)
|
||||||
log.addHandler(file_handle)
|
log.addHandler(file_handle)
|
||||||
log.addHandler(console_handle)
|
log.addHandler(console_handle)
|
||||||
|
|
||||||
|
|
||||||
def _get_logger():
|
def _get_logger():
|
||||||
log = logging.getLogger('log')
|
log = logging.getLogger("log")
|
||||||
_reset_logger(log)
|
_reset_logger(log)
|
||||||
log.setLevel(logging.INFO)
|
log.setLevel(logging.INFO)
|
||||||
return log
|
return log
|
||||||
|
|
||||||
|
|
||||||
# 日志句柄
|
# 日志句柄
|
||||||
logger = _get_logger()
|
logger = _get_logger()
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import pip
|
import pip
|
||||||
from pip._internal import main as pipmain
|
from pip._internal import main as pipmain
|
||||||
from common.log import logger,_reset_logger
|
|
||||||
|
from common.log import _reset_logger, logger
|
||||||
|
|
||||||
|
|
||||||
def install(package):
|
def install(package):
|
||||||
pipmain(['install', package])
|
pipmain(["install", package])
|
||||||
|
|
||||||
|
|
||||||
def install_requirements(file):
|
def install_requirements(file):
|
||||||
pipmain(['install', '-r', file, "--upgrade"])
|
pipmain(["install", "-r", file, "--upgrade"])
|
||||||
_reset_logger(logger)
|
_reset_logger(logger)
|
||||||
|
|
||||||
|
|
||||||
def check_dulwich():
|
def check_dulwich():
|
||||||
needwait = False
|
needwait = False
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@@ -18,13 +23,14 @@ def check_dulwich():
|
|||||||
needwait = False
|
needwait = False
|
||||||
try:
|
try:
|
||||||
import dulwich
|
import dulwich
|
||||||
|
|
||||||
return
|
return
|
||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
try:
|
||||||
install('dulwich')
|
install("dulwich")
|
||||||
except:
|
except:
|
||||||
needwait = True
|
needwait = True
|
||||||
try:
|
try:
|
||||||
import dulwich
|
import dulwich
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Unable to import dulwich")
|
raise ImportError("Unable to import dulwich")
|
||||||
|
|||||||
@@ -62,4 +62,4 @@ class SortedDict(dict):
|
|||||||
return iter(self.keys())
|
return iter(self.keys())
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
|
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
import time,re,hashlib
|
import hashlib
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
import config
|
import config
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
def time_checker(f):
|
def time_checker(f):
|
||||||
def _time_checker(self, *args, **kwargs):
|
def _time_checker(self, *args, **kwargs):
|
||||||
_config = config.conf()
|
_config = config.conf()
|
||||||
@@ -9,17 +13,25 @@ def time_checker(f):
|
|||||||
if chat_time_module:
|
if chat_time_module:
|
||||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||||
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00
|
time_regex = re.compile(
|
||||||
|
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
|
||||||
|
) # 时间匹配,包含24:00
|
||||||
|
|
||||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||||
|
|
||||||
# 时间格式检查
|
# 时间格式检查
|
||||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
if not (
|
||||||
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check))
|
starttime_format_check and stoptime_format_check and chat_time_check
|
||||||
if chat_start_time>"23:59":
|
):
|
||||||
logger.error('启动时间可能存在问题,请修改!')
|
logger.warn(
|
||||||
|
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
|
||||||
|
starttime_format_check, stoptime_format_check
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if chat_start_time > "23:59":
|
||||||
|
logger.error("启动时间可能存在问题,请修改!")
|
||||||
|
|
||||||
# 服务时间检查
|
# 服务时间检查
|
||||||
now_time = time.strftime("%H:%M", time.localtime())
|
now_time = time.strftime("%H:%M", time.localtime())
|
||||||
@@ -27,12 +39,12 @@ def time_checker(f):
|
|||||||
f(self, *args, **kwargs)
|
f(self, *args, **kwargs)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置
|
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||||
f(self, *args, **kwargs)
|
f(self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
logger.info('非服务时间内,不接受访问')
|
logger.info("非服务时间内,不接受访问")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
||||||
return _time_checker
|
|
||||||
|
|
||||||
|
return _time_checker
|
||||||
|
|||||||
@@ -1,20 +1,18 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
class TmpDir(object):
|
class TmpDir(object):
|
||||||
"""A temporary directory that is deleted when the object is destroyed.
|
"""A temporary directory that is deleted when the object is destroyed."""
|
||||||
"""
|
|
||||||
|
tmpFilePath = pathlib.Path("./tmp/")
|
||||||
|
|
||||||
tmpFilePath = pathlib.Path('./tmp/')
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pathExists = os.path.exists(self.tmpFilePath)
|
pathExists = os.path.exists(self.tmpFilePath)
|
||||||
if not pathExists:
|
if not pathExists:
|
||||||
os.makedirs(self.tmpFilePath)
|
os.makedirs(self.tmpFilePath)
|
||||||
|
|
||||||
def path(self):
|
def path(self):
|
||||||
return str(self.tmpFilePath) + '/'
|
return str(self.tmpFilePath) + "/"
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,30 @@
|
|||||||
"open_ai_api_key": "YOUR API KEY",
|
"open_ai_api_key": "YOUR API KEY",
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"proxy": "",
|
"proxy": "",
|
||||||
"single_chat_prefix": ["bot", "@bot"],
|
"single_chat_prefix": [
|
||||||
|
"bot",
|
||||||
|
"@bot"
|
||||||
|
],
|
||||||
"single_chat_reply_prefix": "[bot] ",
|
"single_chat_reply_prefix": "[bot] ",
|
||||||
"group_chat_prefix": ["@bot"],
|
"group_chat_prefix": [
|
||||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
|
"@bot"
|
||||||
"group_chat_in_one_session": ["ChatGPT测试群"],
|
],
|
||||||
"image_create_prefix": ["画", "看", "找"],
|
"group_name_white_list": [
|
||||||
|
"ChatGPT测试群",
|
||||||
|
"ChatGPT测试群2"
|
||||||
|
],
|
||||||
|
"group_chat_in_one_session": [
|
||||||
|
"ChatGPT测试群"
|
||||||
|
],
|
||||||
|
"image_create_prefix": [
|
||||||
|
"画",
|
||||||
|
"看",
|
||||||
|
"找"
|
||||||
|
],
|
||||||
"speech_recognition": false,
|
"speech_recognition": false,
|
||||||
"group_speech_recognition": false,
|
"group_speech_recognition": false,
|
||||||
"voice_reply_voice": false,
|
"voice_reply_voice": false,
|
||||||
"conversation_max_tokens": 1000,
|
"conversation_max_tokens": 1000,
|
||||||
"expires_in_seconds": 3600,
|
"expires_in_seconds": 3600,
|
||||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
|
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
|
||||||
}
|
}
|
||||||
|
|||||||
53
config.py
53
config.py
@@ -3,9 +3,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from common.log import logger
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
# 将所有可用的配置项写在字典里, 请使用小写字母
|
# 将所有可用的配置项写在字典里, 请使用小写字母
|
||||||
available_setting = {
|
available_setting = {
|
||||||
# openai api配置
|
# openai api配置
|
||||||
@@ -16,8 +17,7 @@ available_setting = {
|
|||||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||||
"azure_deployment_id": "", #azure 模型部署名称
|
"azure_deployment_id": "", # azure 模型部署名称
|
||||||
|
|
||||||
# Bot触发配置
|
# Bot触发配置
|
||||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||||
@@ -30,25 +30,21 @@ available_setting = {
|
|||||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||||
"trigger_by_self": False, # 是否允许机器人触发
|
"trigger_by_self": False, # 是否允许机器人触发
|
||||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||||
|
|
||||||
# chatgpt会话参数
|
# chatgpt会话参数
|
||||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||||
|
|
||||||
# chatgpt限流配置
|
# chatgpt限流配置
|
||||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||||
|
|
||||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||||
"temperature": 0.9,
|
"temperature": 0.9,
|
||||||
"top_p": 1,
|
"top_p": 1,
|
||||||
"frequency_penalty": 0,
|
"frequency_penalty": 0,
|
||||||
"presence_penalty": 0,
|
"presence_penalty": 0,
|
||||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
||||||
|
|
||||||
# 语音设置
|
# 语音设置
|
||||||
"speech_recognition": False, # 是否开启语音识别
|
"speech_recognition": False, # 是否开启语音识别
|
||||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||||
@@ -56,50 +52,40 @@ available_setting = {
|
|||||||
"always_reply_voice": False, # 是否一直使用语音回复
|
"always_reply_voice": False, # 是否一直使用语音回复
|
||||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
|
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
|
||||||
|
|
||||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||||
"baidu_app_id": "",
|
"baidu_app_id": "",
|
||||||
"baidu_api_key": "",
|
"baidu_api_key": "",
|
||||||
"baidu_secret_key": "",
|
"baidu_secret_key": "",
|
||||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||||
"baidu_dev_pid": "1536",
|
"baidu_dev_pid": "1536",
|
||||||
|
|
||||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||||
"azure_voice_api_key": "",
|
"azure_voice_api_key": "",
|
||||||
"azure_voice_region": "japaneast",
|
"azure_voice_region": "japaneast",
|
||||||
|
|
||||||
# 服务时间限制,目前支持itchat
|
# 服务时间限制,目前支持itchat
|
||||||
"chat_time_module": False, # 是否开启服务时间限制
|
"chat_time_module": False, # 是否开启服务时间限制
|
||||||
"chat_start_time": "00:00", # 服务开始时间
|
"chat_start_time": "00:00", # 服务开始时间
|
||||||
"chat_stop_time": "24:00", # 服务结束时间
|
"chat_stop_time": "24:00", # 服务结束时间
|
||||||
|
|
||||||
# itchat的配置
|
# itchat的配置
|
||||||
"hot_reload": False, # 是否开启热重载
|
"hot_reload": False, # 是否开启热重载
|
||||||
|
|
||||||
# wechaty的配置
|
# wechaty的配置
|
||||||
"wechaty_puppet_service_token": "", # wechaty的token
|
"wechaty_puppet_service_token": "", # wechaty的token
|
||||||
|
|
||||||
# wechatmp的配置
|
# wechatmp的配置
|
||||||
"wechatmp_token": "", # 微信公众平台的Token
|
"wechatmp_token": "", # 微信公众平台的Token
|
||||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||||
|
|
||||||
# chatgpt指令自定义触发词
|
# chatgpt指令自定义触发词
|
||||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头
|
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||||
|
|
||||||
# channel配置
|
# channel配置
|
||||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
|
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
|
||||||
|
|
||||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||||
|
|
||||||
# 插件配置
|
# 插件配置
|
||||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Config(dict):
|
class Config(dict):
|
||||||
def __init__(self, d:dict={}):
|
def __init__(self, d: dict = {}):
|
||||||
super().__init__(d)
|
super().__init__(d)
|
||||||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict
|
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict
|
||||||
self.user_datas = {}
|
self.user_datas = {}
|
||||||
@@ -130,7 +116,7 @@ class Config(dict):
|
|||||||
|
|
||||||
def load_user_datas(self):
|
def load_user_datas(self):
|
||||||
try:
|
try:
|
||||||
with open('user_datas.pkl', 'rb') as f:
|
with open("user_datas.pkl", "rb") as f:
|
||||||
self.user_datas = pickle.load(f)
|
self.user_datas = pickle.load(f)
|
||||||
logger.info("[Config] User datas loaded.")
|
logger.info("[Config] User datas loaded.")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
@@ -141,12 +127,13 @@ class Config(dict):
|
|||||||
|
|
||||||
def save_user_datas(self):
|
def save_user_datas(self):
|
||||||
try:
|
try:
|
||||||
with open('user_datas.pkl', 'wb') as f:
|
with open("user_datas.pkl", "wb") as f:
|
||||||
pickle.dump(self.user_datas, f)
|
pickle.dump(self.user_datas, f)
|
||||||
logger.info("[Config] User datas saved.")
|
logger.info("[Config] User datas saved.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("[Config] User datas error: {}".format(e))
|
logger.info("[Config] User datas error: {}".format(e))
|
||||||
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|
||||||
|
|
||||||
@@ -154,7 +141,7 @@ def load_config():
|
|||||||
global config
|
global config
|
||||||
config_path = "./config.json"
|
config_path = "./config.json"
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
logger.info('配置文件不存在,将使用config-template.json模板')
|
logger.info("配置文件不存在,将使用config-template.json模板")
|
||||||
config_path = "./config-template.json"
|
config_path = "./config-template.json"
|
||||||
|
|
||||||
config_str = read_file(config_path)
|
config_str = read_file(config_path)
|
||||||
@@ -169,7 +156,8 @@ def load_config():
|
|||||||
name = name.lower()
|
name = name.lower()
|
||||||
if name in available_setting:
|
if name in available_setting:
|
||||||
logger.info(
|
logger.info(
|
||||||
"[INIT] override config by environ args: {}={}".format(name, value))
|
"[INIT] override config by environ args: {}={}".format(name, value)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
config[name] = eval(value)
|
config[name] = eval(value)
|
||||||
except:
|
except:
|
||||||
@@ -182,18 +170,19 @@ def load_config():
|
|||||||
|
|
||||||
if config.get("debug", False):
|
if config.get("debug", False):
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
logger.debug("[INIT] set log level to DEBUG")
|
logger.debug("[INIT] set log level to DEBUG")
|
||||||
|
|
||||||
logger.info("[INIT] load config: {}".format(config))
|
logger.info("[INIT] load config: {}".format(config))
|
||||||
|
|
||||||
config.load_user_datas()
|
config.load_user_datas()
|
||||||
|
|
||||||
|
|
||||||
def get_root():
|
def get_root():
|
||||||
return os.path.dirname(os.path.abspath(__file__))
|
return os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
def read_file(path):
|
def read_file(path):
|
||||||
with open(path, mode='r', encoding='utf-8') as f:
|
with open(path, mode="r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh
|
|||||||
RUN chmod +x /entrypoint.sh \
|
RUN chmod +x /entrypoint.sh \
|
||||||
&& groupadd -r noroot \
|
&& groupadd -r noroot \
|
||||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||||
|
|
||||||
USER noroot
|
USER noroot
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ RUN apt-get update \
|
|||||||
&& pip install --no-cache -r requirements.txt \
|
&& pip install --no-cache -r requirements.txt \
|
||||||
&& pip install --no-cache -r requirements-optional.txt \
|
&& pip install --no-cache -r requirements-optional.txt \
|
||||||
&& pip install azure-cognitiveservices-speech
|
&& pip install azure-cognitiveservices-speech
|
||||||
|
|
||||||
WORKDIR ${BUILD_PREFIX}
|
WORKDIR ${BUILD_PREFIX}
|
||||||
|
|
||||||
ADD docker/entrypoint.sh /entrypoint.sh
|
ADD docker/entrypoint.sh /entrypoint.sh
|
||||||
|
|||||||
@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \
|
|||||||
-t zhayujie/chatgpt-on-wechat .
|
-t zhayujie/chatgpt-on-wechat .
|
||||||
|
|
||||||
# tag image
|
# tag image
|
||||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
|
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
|
||||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
|
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
|
||||||
|
|
||||||
@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \
|
|||||||
-t zhayujie/chatgpt-on-wechat .
|
-t zhayujie/chatgpt-on-wechat .
|
||||||
|
|
||||||
# tag image
|
# tag image
|
||||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
|
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
|
||||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian
|
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian
|
||||||
@@ -9,7 +9,7 @@ RUN apk add --no-cache \
|
|||||||
ffmpeg \
|
ffmpeg \
|
||||||
espeak \
|
espeak \
|
||||||
&& pip install --no-cache \
|
&& pip install --no-cache \
|
||||||
baidu-aip \
|
baidu-aip \
|
||||||
chardet \
|
chardet \
|
||||||
SpeechRecognition
|
SpeechRecognition
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ RUN apt-get update \
|
|||||||
ffmpeg \
|
ffmpeg \
|
||||||
espeak \
|
espeak \
|
||||||
&& pip install --no-cache \
|
&& pip install --no-cache \
|
||||||
baidu-aip \
|
baidu-aip \
|
||||||
chardet \
|
chardet \
|
||||||
SpeechRecognition
|
SpeechRecognition
|
||||||
|
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ run_d:
|
|||||||
docker rm $(CONTAINER_NAME) || echo
|
docker rm $(CONTAINER_NAME) || echo
|
||||||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
|
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||||
--env-file=$(DOTENV) \
|
--env-file=$(DOTENV) \
|
||||||
$(MOUNT) $(IMG)
|
$(MOUNT) $(IMG)
|
||||||
|
|
||||||
run_i:
|
run_i:
|
||||||
docker rm $(CONTAINER_NAME) || echo
|
docker rm $(CONTAINER_NAME) || echo
|
||||||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
|
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||||
--env-file=$(DOTENV) \
|
--env-file=$(DOTENV) \
|
||||||
$(MOUNT) $(IMG)
|
$(MOUNT) $(IMG)
|
||||||
|
|
||||||
stop:
|
stop:
|
||||||
docker stop $(CONTAINER_NAME)
|
docker stop $(CONTAINER_NAME)
|
||||||
|
|||||||
@@ -24,17 +24,17 @@
|
|||||||
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
|
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
|
||||||
|
|
||||||
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
|
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
|
||||||
|
|
||||||
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
|
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
|
||||||
|
|
||||||
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
|
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
|
||||||
|
|
||||||
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
|
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
|
||||||
|
|
||||||
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
|
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
|
||||||
|
|
||||||
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
|
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
|
||||||
|
|
||||||
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
|
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
|
||||||
|
|
||||||
## 插件化实现
|
## 插件化实现
|
||||||
@@ -107,14 +107,14 @@
|
|||||||
```
|
```
|
||||||
|
|
||||||
回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
|
回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ReplyType(Enum):
|
class ReplyType(Enum):
|
||||||
TEXT = 1 # 文本
|
TEXT = 1 # 文本
|
||||||
VOICE = 2 # 音频文件
|
VOICE = 2 # 音频文件
|
||||||
IMAGE = 3 # 图片文件
|
IMAGE = 3 # 图片文件
|
||||||
IMAGE_URL = 4 # 图片URL
|
IMAGE_URL = 4 # 图片URL
|
||||||
|
|
||||||
INFO = 9
|
INFO = 9
|
||||||
ERROR = 10
|
ERROR = 10
|
||||||
class Reply:
|
class Reply:
|
||||||
@@ -159,12 +159,12 @@
|
|||||||
|
|
||||||
目前支持三类触发事件:
|
目前支持三类触发事件:
|
||||||
```
|
```
|
||||||
1.收到消息
|
1.收到消息
|
||||||
---> `ON_HANDLE_CONTEXT`
|
---> `ON_HANDLE_CONTEXT`
|
||||||
2.产生回复
|
2.产生回复
|
||||||
---> `ON_DECORATE_REPLY`
|
---> `ON_DECORATE_REPLY`
|
||||||
3.装饰回复
|
3.装饰回复
|
||||||
---> `ON_SEND_REPLY`
|
---> `ON_SEND_REPLY`
|
||||||
4.发送回复
|
4.发送回复
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -268,6 +268,6 @@ class Hello(Plugin):
|
|||||||
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
|
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
|
||||||
|
|
||||||
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
|
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
|
||||||
|
|
||||||
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
|
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
|
||||||
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。
|
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from .plugin_manager import PluginManager
|
|
||||||
from .event import *
|
from .event import *
|
||||||
from .plugin import *
|
from .plugin import *
|
||||||
|
from .plugin_manager import PluginManager
|
||||||
|
|
||||||
instance = PluginManager()
|
instance = PluginManager()
|
||||||
|
|
||||||
register = instance.register
|
register = instance.register
|
||||||
# load_plugins = instance.load_plugins
|
# load_plugins = instance.load_plugins
|
||||||
# emit_event = instance.emit_event
|
# emit_event = instance.emit_event
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .banwords import *
|
from .banwords import *
|
||||||
|
|||||||
@@ -2,56 +2,67 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
from .lib.WordsSearch import WordsSearch
|
from .lib.WordsSearch import WordsSearch
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent")
|
@plugins.register(
|
||||||
|
name="Banwords",
|
||||||
|
desire_priority=100,
|
||||||
|
hidden=True,
|
||||||
|
desc="判断消息中是否有敏感词、决定是否回复。",
|
||||||
|
version="1.0",
|
||||||
|
author="lanvent",
|
||||||
|
)
|
||||||
class Banwords(Plugin):
|
class Banwords(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
curdir=os.path.dirname(__file__)
|
curdir = os.path.dirname(__file__)
|
||||||
config_path=os.path.join(curdir,"config.json")
|
config_path = os.path.join(curdir, "config.json")
|
||||||
conf=None
|
conf = None
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
conf={"action":"ignore"}
|
conf = {"action": "ignore"}
|
||||||
with open(config_path,"w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(conf,f,indent=4)
|
json.dump(conf, f, indent=4)
|
||||||
else:
|
else:
|
||||||
with open(config_path,"r") as f:
|
with open(config_path, "r") as f:
|
||||||
conf=json.load(f)
|
conf = json.load(f)
|
||||||
self.searchr = WordsSearch()
|
self.searchr = WordsSearch()
|
||||||
self.action = conf["action"]
|
self.action = conf["action"]
|
||||||
banwords_path = os.path.join(curdir,"banwords.txt")
|
banwords_path = os.path.join(curdir, "banwords.txt")
|
||||||
with open(banwords_path, 'r', encoding='utf-8') as f:
|
with open(banwords_path, "r", encoding="utf-8") as f:
|
||||||
words=[]
|
words = []
|
||||||
for line in f:
|
for line in f:
|
||||||
word = line.strip()
|
word = line.strip()
|
||||||
if word:
|
if word:
|
||||||
words.append(word)
|
words.append(word)
|
||||||
self.searchr.SetKeywords(words)
|
self.searchr.SetKeywords(words)
|
||||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||||
if conf.get("reply_filter",True):
|
if conf.get("reply_filter", True):
|
||||||
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
|
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
|
||||||
self.reply_action = conf.get("reply_action","ignore")
|
self.reply_action = conf.get("reply_action", "ignore")
|
||||||
logger.info("[Banwords] inited")
|
logger.info("[Banwords] inited")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
|
logger.warn(
|
||||||
|
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type not in [
|
||||||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
|
ContextType.TEXT,
|
||||||
|
ContextType.IMAGE_CREATE,
|
||||||
|
]:
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
logger.debug("[Banwords] on_handle_context. content: %s" % content)
|
logger.debug("[Banwords] on_handle_context. content: %s" % content)
|
||||||
if self.action == "ignore":
|
if self.action == "ignore":
|
||||||
f = self.searchr.FindFirst(content)
|
f = self.searchr.FindFirst(content)
|
||||||
@@ -61,31 +72,34 @@ class Banwords(Plugin):
|
|||||||
return
|
return
|
||||||
elif self.action == "replace":
|
elif self.action == "replace":
|
||||||
if self.searchr.ContainsAny(content):
|
if self.searchr.ContainsAny(content):
|
||||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
|
reply = Reply(
|
||||||
e_context['reply'] = reply
|
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
|
||||||
|
)
|
||||||
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
|
|
||||||
def on_decorate_reply(self, e_context: EventContext):
|
|
||||||
|
|
||||||
if e_context['reply'].type not in [ReplyType.TEXT]:
|
def on_decorate_reply(self, e_context: EventContext):
|
||||||
|
if e_context["reply"].type not in [ReplyType.TEXT]:
|
||||||
return
|
return
|
||||||
|
|
||||||
reply = e_context['reply']
|
reply = e_context["reply"]
|
||||||
content = reply.content
|
content = reply.content
|
||||||
if self.reply_action == "ignore":
|
if self.reply_action == "ignore":
|
||||||
f = self.searchr.FindFirst(content)
|
f = self.searchr.FindFirst(content)
|
||||||
if f:
|
if f:
|
||||||
logger.info("[Banwords] %s in reply" % f["Keyword"])
|
logger.info("[Banwords] %s in reply" % f["Keyword"])
|
||||||
e_context['reply'] = None
|
e_context["reply"] = None
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
elif self.reply_action == "replace":
|
elif self.reply_action == "replace":
|
||||||
if self.searchr.ContainsAny(content):
|
if self.searchr.ContainsAny(content):
|
||||||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content))
|
reply = Reply(
|
||||||
e_context['reply'] = reply
|
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
|
||||||
|
)
|
||||||
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.CONTINUE
|
e_context.action = EventAction.CONTINUE
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_help_text(self, **kwargs):
|
def get_help_text(self, **kwargs):
|
||||||
return Banwords.desc
|
return Banwords.desc
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"action": "replace",
|
"action": "replace",
|
||||||
"reply_filter": true,
|
"reply_filter": true,
|
||||||
"reply_action": "ignore"
|
"reply_action": "ignore"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087
|
|||||||
``` json
|
``` json
|
||||||
{
|
{
|
||||||
"service_id": "s...", #"机器人ID"
|
"service_id": "s...", #"机器人ID"
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"secret_key": ""
|
"secret_key": ""
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -1 +1 @@
|
|||||||
from .bdunit import *
|
from .bdunit import *
|
||||||
|
|||||||
@@ -2,21 +2,29 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from uuid import getnode as get_mac
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
import plugins
|
|
||||||
from plugins import *
|
from plugins import *
|
||||||
from uuid import getnode as get_mac
|
|
||||||
|
|
||||||
|
|
||||||
"""利用百度UNIT实现智能对话
|
"""利用百度UNIT实现智能对话
|
||||||
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
|
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson")
|
@plugins.register(
|
||||||
|
name="BDunit",
|
||||||
|
desire_priority=0,
|
||||||
|
hidden=True,
|
||||||
|
desc="Baidu unit bot system",
|
||||||
|
version="0.1",
|
||||||
|
author="jackson",
|
||||||
|
)
|
||||||
class BDunit(Plugin):
|
class BDunit(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -40,11 +48,10 @@ class BDunit(Plugin):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
logger.debug("[BDunit] on_handle_context. content: %s" % content)
|
logger.debug("[BDunit] on_handle_context. content: %s" % content)
|
||||||
parsed = self.getUnit2(content)
|
parsed = self.getUnit2(content)
|
||||||
intent = self.getIntent(parsed)
|
intent = self.getIntent(parsed)
|
||||||
@@ -53,7 +60,7 @@ class BDunit(Plugin):
|
|||||||
reply = Reply()
|
reply = Reply()
|
||||||
reply.type = ReplyType.TEXT
|
reply.type = ReplyType.TEXT
|
||||||
reply.content = self.getSay(parsed)
|
reply.content = self.getSay(parsed)
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||||
else:
|
else:
|
||||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||||
@@ -70,17 +77,15 @@ class BDunit(Plugin):
|
|||||||
string: access_token
|
string: access_token
|
||||||
"""
|
"""
|
||||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
|
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
|
||||||
self.api_key, self.secret_key)
|
self.api_key, self.secret_key
|
||||||
|
)
|
||||||
payload = ""
|
payload = ""
|
||||||
headers = {
|
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.request("POST", url, headers=headers, data=payload)
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
|
|
||||||
# print(response.text)
|
# print(response.text)
|
||||||
return response.json()['access_token']
|
return response.json()["access_token"]
|
||||||
|
|
||||||
def getUnit(self, query):
|
def getUnit(self, query):
|
||||||
"""
|
"""
|
||||||
@@ -90,11 +95,14 @@ class BDunit(Plugin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
url = (
|
url = (
|
||||||
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token='
|
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
|
||||||
+ self.access_token
|
+ self.access_token
|
||||||
)
|
)
|
||||||
request = {"query": query, "user_id": str(
|
request = {
|
||||||
get_mac())[:32], "terminal_id": "88888"}
|
"query": query,
|
||||||
|
"user_id": str(get_mac())[:32],
|
||||||
|
"terminal_id": "88888",
|
||||||
|
}
|
||||||
body = {
|
body = {
|
||||||
"log_id": str(uuid.uuid1()),
|
"log_id": str(uuid.uuid1()),
|
||||||
"version": "3.0",
|
"version": "3.0",
|
||||||
@@ -142,11 +150,7 @@ class BDunit(Plugin):
|
|||||||
:param parsed: UNIT 解析结果
|
:param parsed: UNIT 解析结果
|
||||||
:returns: 意图数组
|
:returns: 意图数组
|
||||||
"""
|
"""
|
||||||
if (
|
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||||
parsed
|
|
||||||
and "result" in parsed
|
|
||||||
and "response_list" in parsed["result"]
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
return parsed["result"]["response_list"][0]["schema"]["intent"]
|
return parsed["result"]["response_list"][0]["schema"]["intent"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -163,11 +167,7 @@ class BDunit(Plugin):
|
|||||||
:param intent: 意图的名称
|
:param intent: 意图的名称
|
||||||
:returns: True: 包含; False: 不包含
|
:returns: True: 包含; False: 不包含
|
||||||
"""
|
"""
|
||||||
if (
|
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||||
parsed
|
|
||||||
and "result" in parsed
|
|
||||||
and "response_list" in parsed["result"]
|
|
||||||
):
|
|
||||||
response_list = parsed["result"]["response_list"]
|
response_list = parsed["result"]["response_list"]
|
||||||
for response in response_list:
|
for response in response_list:
|
||||||
if (
|
if (
|
||||||
@@ -189,11 +189,7 @@ class BDunit(Plugin):
|
|||||||
:returns: 词槽列表。你可以通过 name 属性筛选词槽,
|
:returns: 词槽列表。你可以通过 name 属性筛选词槽,
|
||||||
再通过 normalized_word 属性取出相应的值
|
再通过 normalized_word 属性取出相应的值
|
||||||
"""
|
"""
|
||||||
if (
|
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||||
parsed
|
|
||||||
and "result" in parsed
|
|
||||||
and "response_list" in parsed["result"]
|
|
||||||
):
|
|
||||||
response_list = parsed["result"]["response_list"]
|
response_list = parsed["result"]["response_list"]
|
||||||
if intent == "":
|
if intent == "":
|
||||||
try:
|
try:
|
||||||
@@ -236,11 +232,7 @@ class BDunit(Plugin):
|
|||||||
:param parsed: UNIT 解析结果
|
:param parsed: UNIT 解析结果
|
||||||
:returns: UNIT 的回复文本
|
:returns: UNIT 的回复文本
|
||||||
"""
|
"""
|
||||||
if (
|
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||||
parsed
|
|
||||||
and "result" in parsed
|
|
||||||
and "response_list" in parsed["result"]
|
|
||||||
):
|
|
||||||
response_list = parsed["result"]["response_list"]
|
response_list = parsed["result"]["response_list"]
|
||||||
answer = {}
|
answer = {}
|
||||||
for response in response_list:
|
for response in response_list:
|
||||||
@@ -266,11 +258,7 @@ class BDunit(Plugin):
|
|||||||
:param intent: 意图的名称
|
:param intent: 意图的名称
|
||||||
:returns: UNIT 的回复文本
|
:returns: UNIT 的回复文本
|
||||||
"""
|
"""
|
||||||
if (
|
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||||
parsed
|
|
||||||
and "result" in parsed
|
|
||||||
and "response_list" in parsed["result"]
|
|
||||||
):
|
|
||||||
response_list = parsed["result"]["response_list"]
|
response_list = parsed["result"]["response_list"]
|
||||||
if intent == "":
|
if intent == "":
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"service_id": "s...",
|
"service_id": "s...",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"secret_key": ""
|
"secret_key": ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .dungeon import *
|
from .dungeon import *
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.expired_dict import ExpiredDict
|
|
||||||
from config import conf
|
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common.log import logger
|
|
||||||
from common import const
|
from common import const
|
||||||
|
from common.expired_dict import ExpiredDict
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/bupticybee/ChineseAiDungeonChatGPT
|
# https://github.com/bupticybee/ChineseAiDungeonChatGPT
|
||||||
class StoryTeller():
|
class StoryTeller:
|
||||||
def __init__(self, bot, sessionid, story):
|
def __init__(self, bot, sessionid, story):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.sessionid = sessionid
|
self.sessionid = sessionid
|
||||||
@@ -27,67 +28,85 @@ class StoryTeller():
|
|||||||
if user_action[-1] != "。":
|
if user_action[-1] != "。":
|
||||||
user_action = user_action + "。"
|
user_action = user_action + "。"
|
||||||
if self.first_interact:
|
if self.first_interact:
|
||||||
prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
|
prompt = (
|
||||||
开头是,""" + self.story + " " + user_action
|
"""现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
|
||||||
|
开头是,"""
|
||||||
|
+ self.story
|
||||||
|
+ " "
|
||||||
|
+ user_action
|
||||||
|
)
|
||||||
self.first_interact = False
|
self.first_interact = False
|
||||||
else:
|
else:
|
||||||
prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
|
prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent")
|
@plugins.register(
|
||||||
|
name="Dungeon",
|
||||||
|
desire_priority=0,
|
||||||
|
namecn="文字冒险",
|
||||||
|
desc="A plugin to play dungeon game",
|
||||||
|
version="1.0",
|
||||||
|
author="lanvent",
|
||||||
|
)
|
||||||
class Dungeon(Plugin):
|
class Dungeon(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||||
logger.info("[Dungeon] inited")
|
logger.info("[Dungeon] inited")
|
||||||
# 目前没有设计session过期事件,这里先暂时使用过期字典
|
# 目前没有设计session过期事件,这里先暂时使用过期字典
|
||||||
if conf().get('expires_in_seconds'):
|
if conf().get("expires_in_seconds"):
|
||||||
self.games = ExpiredDict(conf().get('expires_in_seconds'))
|
self.games = ExpiredDict(conf().get("expires_in_seconds"))
|
||||||
else:
|
else:
|
||||||
self.games = dict()
|
self.games = dict()
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
|
||||||
return
|
return
|
||||||
bottype = Bridge().get_bot_type("chat")
|
bottype = Bridge().get_bot_type("chat")
|
||||||
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
||||||
return
|
return
|
||||||
bot = Bridge().get_bot("chat")
|
bot = Bridge().get_bot("chat")
|
||||||
content = e_context['context'].content[:]
|
content = e_context["context"].content[:]
|
||||||
clist = e_context['context'].content.split(maxsplit=1)
|
clist = e_context["context"].content.split(maxsplit=1)
|
||||||
sessionid = e_context['context']['session_id']
|
sessionid = e_context["context"]["session_id"]
|
||||||
logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
|
logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
if clist[0] == f"{trigger_prefix}停止冒险":
|
if clist[0] == f"{trigger_prefix}停止冒险":
|
||||||
if sessionid in self.games:
|
if sessionid in self.games:
|
||||||
self.games[sessionid].reset()
|
self.games[sessionid].reset()
|
||||||
del self.games[sessionid]
|
del self.games[sessionid]
|
||||||
reply = Reply(ReplyType.INFO, "冒险结束!")
|
reply = Reply(ReplyType.INFO, "冒险结束!")
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
|
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
|
||||||
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
|
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
|
||||||
if len(clist)>1 :
|
if len(clist) > 1:
|
||||||
story = clist[1]
|
story = clist[1]
|
||||||
else:
|
else:
|
||||||
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
|
story = (
|
||||||
|
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
|
||||||
|
)
|
||||||
self.games[sessionid] = StoryTeller(bot, sessionid, story)
|
self.games[sessionid] = StoryTeller(bot, sessionid, story)
|
||||||
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
|
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||||
else:
|
else:
|
||||||
prompt = self.games[sessionid].action(content)
|
prompt = self.games[sessionid].action(content)
|
||||||
e_context['context'].type = ContextType.TEXT
|
e_context["context"].type = ContextType.TEXT
|
||||||
e_context['context'].content = prompt
|
e_context["context"].content = prompt
|
||||||
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑
|
||||||
|
|
||||||
def get_help_text(self, **kwargs):
|
def get_help_text(self, **kwargs):
|
||||||
help_text = "可以和机器人一起玩文字冒险游戏。\n"
|
help_text = "可以和机器人一起玩文字冒险游戏。\n"
|
||||||
if kwargs.get('verbose') != True:
|
if kwargs.get("verbose") != True:
|
||||||
return help_text
|
return help_text
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n"
|
help_text = (
|
||||||
if kwargs.get('verbose') == True:
|
f"{trigger_prefix}开始冒险 "
|
||||||
|
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
|
||||||
|
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
|
||||||
|
)
|
||||||
|
if kwargs.get("verbose") == True:
|
||||||
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
|
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
|
||||||
return help_text
|
return help_text
|
||||||
|
|||||||
@@ -9,17 +9,17 @@ class Event(Enum):
|
|||||||
e_context = { "channel": 消息channel, "context" : 本次消息的context}
|
e_context = { "channel": 消息channel, "context" : 本次消息的context}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ON_HANDLE_CONTEXT = 2 # 处理消息前
|
ON_HANDLE_CONTEXT = 2 # 处理消息前
|
||||||
"""
|
"""
|
||||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
|
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
|
||||||
"""
|
"""
|
||||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ON_SEND_REPLY = 4 # 发送回复前
|
ON_SEND_REPLY = 4 # 发送回复前
|
||||||
"""
|
"""
|
||||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
||||||
"""
|
"""
|
||||||
@@ -28,9 +28,9 @@ class Event(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class EventAction(Enum):
|
class EventAction(Enum):
|
||||||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
|
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
|
||||||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
|
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
|
||||||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
|
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
|
||||||
|
|
||||||
|
|
||||||
class EventContext:
|
class EventContext:
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .finish import *
|
from .finish import *
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from config import conf
|
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000")
|
@plugins.register(
|
||||||
|
name="Finish",
|
||||||
|
desire_priority=-999,
|
||||||
|
hidden=True,
|
||||||
|
desc="A plugin that check unknown command",
|
||||||
|
version="1.0",
|
||||||
|
author="js00000",
|
||||||
|
)
|
||||||
class Finish(Plugin):
|
class Finish(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -16,19 +23,18 @@ class Finish(Plugin):
|
|||||||
logger.info("[Finish] inited")
|
logger.info("[Finish] inited")
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
logger.debug("[Finish] on_handle_context. content: %s" % content)
|
logger.debug("[Finish] on_handle_context. content: %s" % content)
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix',"$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
if content.startswith(trigger_prefix):
|
if content.startswith(trigger_prefix):
|
||||||
reply = Reply()
|
reply = Reply()
|
||||||
reply.type = ReplyType.ERROR
|
reply.type = ReplyType.ERROR
|
||||||
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
|
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||||
|
|
||||||
def get_help_text(self, **kwargs):
|
def get_help_text(self, **kwargs):
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .godcmd import *
|
from .godcmd import *
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"password": "",
|
"password": "",
|
||||||
"admin_users": []
|
"admin_users": []
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,14 +6,16 @@ import random
|
|||||||
import string
|
import string
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from config import conf, load_config
|
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common import const
|
from common import const
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf, load_config
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
# 定义指令集
|
# 定义指令集
|
||||||
COMMANDS = {
|
COMMANDS = {
|
||||||
"help": {
|
"help": {
|
||||||
@@ -41,7 +43,7 @@ COMMANDS = {
|
|||||||
},
|
},
|
||||||
"id": {
|
"id": {
|
||||||
"alias": ["id", "用户"],
|
"alias": ["id", "用户"],
|
||||||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
|
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
|
||||||
},
|
},
|
||||||
"reset": {
|
"reset": {
|
||||||
"alias": ["reset", "重置会话"],
|
"alias": ["reset", "重置会话"],
|
||||||
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = {
|
|||||||
"desc": "开启机器调试日志",
|
"desc": "开启机器调试日志",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 定义帮助函数
|
# 定义帮助函数
|
||||||
def get_help_text(isadmin, isgroup):
|
def get_help_text(isadmin, isgroup):
|
||||||
help_text = "通用指令:\n"
|
help_text = "通用指令:\n"
|
||||||
for cmd, info in COMMANDS.items():
|
for cmd, info in COMMANDS.items():
|
||||||
if cmd=="auth": #不提示认证指令
|
if cmd == "auth": # 不提示认证指令
|
||||||
continue
|
continue
|
||||||
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]:
|
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
|
||||||
continue
|
continue
|
||||||
alias=["#"+a for a in info['alias'][:1]]
|
alias = ["#" + a for a in info["alias"][:1]]
|
||||||
help_text += f"{','.join(alias)} "
|
help_text += f"{','.join(alias)} "
|
||||||
if 'args' in info:
|
if "args" in info:
|
||||||
args=[a for a in info['args']]
|
args = [a for a in info["args"]]
|
||||||
help_text += f"{' '.join(args)}"
|
help_text += f"{' '.join(args)}"
|
||||||
help_text += f": {info['desc']}\n"
|
help_text += f": {info['desc']}\n"
|
||||||
|
|
||||||
@@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup):
|
|||||||
for plugin in plugins:
|
for plugin in plugins:
|
||||||
if plugins[plugin].enabled and not plugins[plugin].hidden:
|
if plugins[plugin].enabled and not plugins[plugin].hidden:
|
||||||
namecn = plugins[plugin].namecn
|
namecn = plugins[plugin].namecn
|
||||||
help_text += "\n%s:"%namecn
|
help_text += "\n%s:" % namecn
|
||||||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
|
help_text += (
|
||||||
|
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
|
||||||
|
)
|
||||||
|
|
||||||
if ADMIN_COMMANDS and isadmin:
|
if ADMIN_COMMANDS and isadmin:
|
||||||
help_text += "\n\n管理员指令:\n"
|
help_text += "\n\n管理员指令:\n"
|
||||||
for cmd, info in ADMIN_COMMANDS.items():
|
for cmd, info in ADMIN_COMMANDS.items():
|
||||||
alias=["#"+a for a in info['alias'][:1]]
|
alias = ["#" + a for a in info["alias"][:1]]
|
||||||
help_text += f"{','.join(alias)} "
|
help_text += f"{','.join(alias)} "
|
||||||
if 'args' in info:
|
if "args" in info:
|
||||||
args=[a for a in info['args']]
|
args = [a for a in info["args"]]
|
||||||
help_text += f"{' '.join(args)}"
|
help_text += f"{' '.join(args)}"
|
||||||
help_text += f": {info['desc']}\n"
|
help_text += f": {info['desc']}\n"
|
||||||
return help_text
|
return help_text
|
||||||
|
|
||||||
@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent")
|
|
||||||
class Godcmd(Plugin):
|
|
||||||
|
|
||||||
|
@plugins.register(
|
||||||
|
name="Godcmd",
|
||||||
|
desire_priority=999,
|
||||||
|
hidden=True,
|
||||||
|
desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证",
|
||||||
|
version="1.0",
|
||||||
|
author="lanvent",
|
||||||
|
)
|
||||||
|
class Godcmd(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
curdir=os.path.dirname(__file__)
|
curdir = os.path.dirname(__file__)
|
||||||
config_path=os.path.join(curdir,"config.json")
|
config_path = os.path.join(curdir, "config.json")
|
||||||
gconf=None
|
gconf = None
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
gconf={"password":"","admin_users":[]}
|
gconf = {"password": "", "admin_users": []}
|
||||||
with open(config_path,"w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(gconf,f,indent=4)
|
json.dump(gconf, f, indent=4)
|
||||||
else:
|
else:
|
||||||
with open(config_path,"r") as f:
|
with open(config_path, "r") as f:
|
||||||
gconf=json.load(f)
|
gconf = json.load(f)
|
||||||
if gconf["password"] == "":
|
if gconf["password"] == "":
|
||||||
self.temp_password = "".join(random.sample(string.digits, 4))
|
self.temp_password = "".join(random.sample(string.digits, 4))
|
||||||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password)
|
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password)
|
||||||
else:
|
else:
|
||||||
self.temp_password = None
|
self.temp_password = None
|
||||||
custom_commands = conf().get("clear_memory_commands", [])
|
custom_commands = conf().get("clear_memory_commands", [])
|
||||||
@@ -178,41 +191,42 @@ class Godcmd(Plugin):
|
|||||||
COMMANDS["reset"]["alias"].append(custom_command)
|
COMMANDS["reset"]["alias"].append(custom_command)
|
||||||
|
|
||||||
self.password = gconf["password"]
|
self.password = gconf["password"]
|
||||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
self.admin_users = gconf[
|
||||||
self.isrunning = True # 机器人是否运行中
|
"admin_users"
|
||||||
|
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
||||||
|
self.isrunning = True # 机器人是否运行中
|
||||||
|
|
||||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||||
logger.info("[Godcmd] inited")
|
logger.info("[Godcmd] inited")
|
||||||
|
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
context_type = e_context['context'].type
|
context_type = e_context["context"].type
|
||||||
if context_type != ContextType.TEXT:
|
if context_type != ContextType.TEXT:
|
||||||
if not self.isrunning:
|
if not self.isrunning:
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
|
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
|
||||||
if content.startswith("#"):
|
if content.startswith("#"):
|
||||||
# msg = e_context['context']['msg']
|
# msg = e_context['context']['msg']
|
||||||
channel = e_context['channel']
|
channel = e_context["channel"]
|
||||||
user = e_context['context']['receiver']
|
user = e_context["context"]["receiver"]
|
||||||
session_id = e_context['context']['session_id']
|
session_id = e_context["context"]["session_id"]
|
||||||
isgroup = e_context['context'].get("isgroup", False)
|
isgroup = e_context["context"].get("isgroup", False)
|
||||||
bottype = Bridge().get_bot_type("chat")
|
bottype = Bridge().get_bot_type("chat")
|
||||||
bot = Bridge().get_bot("chat")
|
bot = Bridge().get_bot("chat")
|
||||||
# 将命令和参数分割
|
# 将命令和参数分割
|
||||||
command_parts = content[1:].strip().split()
|
command_parts = content[1:].strip().split()
|
||||||
cmd = command_parts[0]
|
cmd = command_parts[0]
|
||||||
args = command_parts[1:]
|
args = command_parts[1:]
|
||||||
isadmin=False
|
isadmin = False
|
||||||
if user in self.admin_users:
|
if user in self.admin_users:
|
||||||
isadmin=True
|
isadmin = True
|
||||||
ok=False
|
ok = False
|
||||||
result="string"
|
result = "string"
|
||||||
if any(cmd in info['alias'] for info in COMMANDS.values()):
|
if any(cmd in info["alias"] for info in COMMANDS.values()):
|
||||||
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
|
cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"])
|
||||||
if cmd == "auth":
|
if cmd == "auth":
|
||||||
ok, result = self.authenticate(user, args, isadmin, isgroup)
|
ok, result = self.authenticate(user, args, isadmin, isgroup)
|
||||||
elif cmd == "help" or cmd == "helpp":
|
elif cmd == "help" or cmd == "helpp":
|
||||||
@@ -224,10 +238,14 @@ class Godcmd(Plugin):
|
|||||||
query_name = args[0].upper()
|
query_name = args[0].upper()
|
||||||
# search name and namecn
|
# search name and namecn
|
||||||
for name, plugincls in plugins.items():
|
for name, plugincls in plugins.items():
|
||||||
if not plugincls.enabled :
|
if not plugincls.enabled:
|
||||||
continue
|
continue
|
||||||
if query_name == name or query_name == plugincls.namecn:
|
if query_name == name or query_name == plugincls.namecn:
|
||||||
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
|
ok, result = True, PluginManager().instances[
|
||||||
|
name
|
||||||
|
].get_help_text(
|
||||||
|
isgroup=isgroup, isadmin=isadmin, verbose=True
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if not ok:
|
if not ok:
|
||||||
result = "插件不存在或未启用"
|
result = "插件不存在或未启用"
|
||||||
@@ -236,14 +254,14 @@ class Godcmd(Plugin):
|
|||||||
elif cmd == "set_openai_api_key":
|
elif cmd == "set_openai_api_key":
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
user_data = conf().get_user_data(user)
|
user_data = conf().get_user_data(user)
|
||||||
user_data['openai_api_key'] = args[0]
|
user_data["openai_api_key"] = args[0]
|
||||||
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
|
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
|
||||||
else:
|
else:
|
||||||
ok, result = False, "请提供一个api_key"
|
ok, result = False, "请提供一个api_key"
|
||||||
elif cmd == "reset_openai_api_key":
|
elif cmd == "reset_openai_api_key":
|
||||||
try:
|
try:
|
||||||
user_data = conf().get_user_data(user)
|
user_data = conf().get_user_data(user)
|
||||||
user_data.pop('openai_api_key')
|
user_data.pop("openai_api_key")
|
||||||
ok, result = True, "你的OpenAI私有api_key已清除"
|
ok, result = True, "你的OpenAI私有api_key已清除"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ok, result = False, "你没有设置私有api_key"
|
ok, result = False, "你没有设置私有api_key"
|
||||||
@@ -255,12 +273,16 @@ class Godcmd(Plugin):
|
|||||||
else:
|
else:
|
||||||
ok, result = False, "当前对话机器人不支持重置会话"
|
ok, result = False, "当前对话机器人不支持重置会话"
|
||||||
logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
|
logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
|
||||||
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
|
elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()):
|
||||||
if isadmin:
|
if isadmin:
|
||||||
if isgroup:
|
if isgroup:
|
||||||
ok, result = False, "群聊不可执行管理员指令"
|
ok, result = False, "群聊不可执行管理员指令"
|
||||||
else:
|
else:
|
||||||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
|
cmd = next(
|
||||||
|
c
|
||||||
|
for c, info in ADMIN_COMMANDS.items()
|
||||||
|
if cmd in info["alias"]
|
||||||
|
)
|
||||||
if cmd == "stop":
|
if cmd == "stop":
|
||||||
self.isrunning = False
|
self.isrunning = False
|
||||||
ok, result = True, "服务已暂停"
|
ok, result = True, "服务已暂停"
|
||||||
@@ -278,13 +300,13 @@ class Godcmd(Plugin):
|
|||||||
else:
|
else:
|
||||||
ok, result = False, "当前对话机器人不支持重置会话"
|
ok, result = False, "当前对话机器人不支持重置会话"
|
||||||
elif cmd == "debug":
|
elif cmd == "debug":
|
||||||
logger.setLevel('DEBUG')
|
logger.setLevel("DEBUG")
|
||||||
ok, result = True, "DEBUG模式已开启"
|
ok, result = True, "DEBUG模式已开启"
|
||||||
elif cmd == "plist":
|
elif cmd == "plist":
|
||||||
plugins = PluginManager().list_plugins()
|
plugins = PluginManager().list_plugins()
|
||||||
ok = True
|
ok = True
|
||||||
result = "插件列表:\n"
|
result = "插件列表:\n"
|
||||||
for name,plugincls in plugins.items():
|
for name, plugincls in plugins.items():
|
||||||
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
|
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
|
||||||
if plugincls.enabled:
|
if plugincls.enabled:
|
||||||
result += "已启用\n"
|
result += "已启用\n"
|
||||||
@@ -294,16 +316,20 @@ class Godcmd(Plugin):
|
|||||||
new_plugins = PluginManager().scan_plugins()
|
new_plugins = PluginManager().scan_plugins()
|
||||||
ok, result = True, "插件扫描完成"
|
ok, result = True, "插件扫描完成"
|
||||||
PluginManager().activate_plugins()
|
PluginManager().activate_plugins()
|
||||||
if len(new_plugins) >0 :
|
if len(new_plugins) > 0:
|
||||||
result += "\n发现新插件:\n"
|
result += "\n发现新插件:\n"
|
||||||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
|
result += "\n".join(
|
||||||
else :
|
[f"{p.name}_v{p.version}" for p in new_plugins]
|
||||||
result +=", 未发现新插件"
|
)
|
||||||
|
else:
|
||||||
|
result += ", 未发现新插件"
|
||||||
elif cmd == "setpri":
|
elif cmd == "setpri":
|
||||||
if len(args) != 2:
|
if len(args) != 2:
|
||||||
ok, result = False, "请提供插件名和优先级"
|
ok, result = False, "请提供插件名和优先级"
|
||||||
else:
|
else:
|
||||||
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
|
ok = PluginManager().set_plugin_priority(
|
||||||
|
args[0], int(args[1])
|
||||||
|
)
|
||||||
if ok:
|
if ok:
|
||||||
result = "插件" + args[0] + "优先级已设置为" + args[1]
|
result = "插件" + args[0] + "优先级已设置为" + args[1]
|
||||||
else:
|
else:
|
||||||
@@ -350,42 +376,42 @@ class Godcmd(Plugin):
|
|||||||
else:
|
else:
|
||||||
ok, result = False, "需要管理员权限才能执行该指令"
|
ok, result = False, "需要管理员权限才能执行该指令"
|
||||||
else:
|
else:
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix',"$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
|
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
|
||||||
return
|
return
|
||||||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
|
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
|
||||||
|
|
||||||
reply = Reply()
|
reply = Reply()
|
||||||
if ok:
|
if ok:
|
||||||
reply.type = ReplyType.INFO
|
reply.type = ReplyType.INFO
|
||||||
else:
|
else:
|
||||||
reply.type = ReplyType.ERROR
|
reply.type = ReplyType.ERROR
|
||||||
reply.content = result
|
reply.content = result
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
|
|
||||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||||
elif not self.isrunning:
|
elif not self.isrunning:
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
|
|
||||||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] :
|
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
|
||||||
if isgroup:
|
if isgroup:
|
||||||
return False,"请勿在群聊中认证"
|
return False, "请勿在群聊中认证"
|
||||||
|
|
||||||
if isadmin:
|
if isadmin:
|
||||||
return False,"管理员账号无需认证"
|
return False, "管理员账号无需认证"
|
||||||
|
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
return False,"请提供口令"
|
return False, "请提供口令"
|
||||||
|
|
||||||
password = args[0]
|
password = args[0]
|
||||||
if password == self.password:
|
if password == self.password:
|
||||||
self.admin_users.append(userid)
|
self.admin_users.append(userid)
|
||||||
return True,"认证成功"
|
return True, "认证成功"
|
||||||
elif password == self.temp_password:
|
elif password == self.temp_password:
|
||||||
self.admin_users.append(userid)
|
self.admin_users.append(userid)
|
||||||
return True,"认证成功,请尽快设置口令"
|
return True, "认证成功,请尽快设置口令"
|
||||||
else:
|
else:
|
||||||
return False,"认证失败"
|
return False, "认证失败"
|
||||||
|
|
||||||
def get_help_text(self, isadmin = False, isgroup = False, **kwargs):
|
def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
|
||||||
return get_help_text(isadmin, isgroup)
|
return get_help_text(isadmin, isgroup)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .hello import *
|
from .hello import *
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from channel.chat_message import ChatMessage
|
from channel.chat_message import ChatMessage
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent")
|
@plugins.register(
|
||||||
|
name="Hello",
|
||||||
|
desire_priority=-1,
|
||||||
|
hidden=True,
|
||||||
|
desc="A simple plugin that says hello",
|
||||||
|
version="0.1",
|
||||||
|
author="lanvent",
|
||||||
|
)
|
||||||
class Hello(Plugin):
|
class Hello(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -16,33 +23,34 @@ class Hello(Plugin):
|
|||||||
logger.info("[Hello] inited")
|
logger.info("[Hello] inited")
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
logger.debug("[Hello] on_handle_context. content: %s" % content)
|
logger.debug("[Hello] on_handle_context. content: %s" % content)
|
||||||
if content == "Hello":
|
if content == "Hello":
|
||||||
reply = Reply()
|
reply = Reply()
|
||||||
reply.type = ReplyType.TEXT
|
reply.type = ReplyType.TEXT
|
||||||
msg:ChatMessage = e_context['context']['msg']
|
msg: ChatMessage = e_context["context"]["msg"]
|
||||||
if e_context['context']['isgroup']:
|
if e_context["context"]["isgroup"]:
|
||||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
reply.content = (
|
||||||
|
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reply.content = f"Hello, {msg.from_user_nickname}"
|
reply.content = f"Hello, {msg.from_user_nickname}"
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||||
|
|
||||||
if content == "Hi":
|
if content == "Hi":
|
||||||
reply = Reply()
|
reply = Reply()
|
||||||
reply.type = ReplyType.TEXT
|
reply.type = ReplyType.TEXT
|
||||||
reply.content = "Hi"
|
reply.content = "Hi"
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply
|
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply
|
||||||
|
|
||||||
if content == "End":
|
if content == "End":
|
||||||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
|
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
|
||||||
e_context['context'].type = ContextType.IMAGE_CREATE
|
e_context["context"].type = ContextType.IMAGE_CREATE
|
||||||
content = "The World"
|
content = "The World"
|
||||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||||
|
|
||||||
|
|||||||
@@ -3,4 +3,4 @@ class Plugin:
|
|||||||
self.handlers = {}
|
self.handlers = {}
|
||||||
|
|
||||||
def get_help_text(self, **kwargs):
|
def get_help_text(self, **kwargs):
|
||||||
return "暂无帮助信息"
|
return "暂无帮助信息"
|
||||||
|
|||||||
@@ -5,17 +5,19 @@ import importlib.util
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from common.log import logger
|
||||||
from common.singleton import singleton
|
from common.singleton import singleton
|
||||||
from common.sorted_dict import SortedDict
|
from common.sorted_dict import SortedDict
|
||||||
from .event import *
|
|
||||||
from common.log import logger
|
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
|
from .event import *
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True)
|
self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
|
||||||
self.listening_plugins = {}
|
self.listening_plugins = {}
|
||||||
self.instances = {}
|
self.instances = {}
|
||||||
self.pconf = {}
|
self.pconf = {}
|
||||||
@@ -26,17 +28,27 @@ class PluginManager:
|
|||||||
def wrapper(plugincls):
|
def wrapper(plugincls):
|
||||||
plugincls.name = name
|
plugincls.name = name
|
||||||
plugincls.priority = desire_priority
|
plugincls.priority = desire_priority
|
||||||
plugincls.desc = kwargs.get('desc')
|
plugincls.desc = kwargs.get("desc")
|
||||||
plugincls.author = kwargs.get('author')
|
plugincls.author = kwargs.get("author")
|
||||||
plugincls.path = self.current_plugin_path
|
plugincls.path = self.current_plugin_path
|
||||||
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0"
|
plugincls.version = (
|
||||||
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name
|
kwargs.get("version") if kwargs.get("version") != None else "1.0"
|
||||||
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False
|
)
|
||||||
|
plugincls.namecn = (
|
||||||
|
kwargs.get("namecn") if kwargs.get("namecn") != None else name
|
||||||
|
)
|
||||||
|
plugincls.hidden = (
|
||||||
|
kwargs.get("hidden") if kwargs.get("hidden") != None else False
|
||||||
|
)
|
||||||
plugincls.enabled = True
|
plugincls.enabled = True
|
||||||
if self.current_plugin_path == None:
|
if self.current_plugin_path == None:
|
||||||
raise Exception("Plugin path not set")
|
raise Exception("Plugin path not set")
|
||||||
self.plugins[name.upper()] = plugincls
|
self.plugins[name.upper()] = plugincls
|
||||||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
|
logger.info(
|
||||||
|
"Plugin %s_v%s registered, path=%s"
|
||||||
|
% (name, plugincls.version, plugincls.path)
|
||||||
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def save_config(self):
|
def save_config(self):
|
||||||
@@ -50,10 +62,12 @@ class PluginManager:
|
|||||||
if os.path.exists("./plugins/plugins.json"):
|
if os.path.exists("./plugins/plugins.json"):
|
||||||
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
|
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
|
||||||
pconf = json.load(f)
|
pconf = json.load(f)
|
||||||
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True)
|
pconf["plugins"] = SortedDict(
|
||||||
|
lambda k, v: v["priority"], pconf["plugins"], reverse=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
modified = True
|
modified = True
|
||||||
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)}
|
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
|
||||||
self.pconf = pconf
|
self.pconf = pconf
|
||||||
if modified:
|
if modified:
|
||||||
self.save_config()
|
self.save_config()
|
||||||
@@ -67,7 +81,7 @@ class PluginManager:
|
|||||||
plugin_path = os.path.join(plugins_dir, plugin_name)
|
plugin_path = os.path.join(plugins_dir, plugin_name)
|
||||||
if os.path.isdir(plugin_path):
|
if os.path.isdir(plugin_path):
|
||||||
# 判断插件是否包含同名__init__.py文件
|
# 判断插件是否包含同名__init__.py文件
|
||||||
main_module_path = os.path.join(plugin_path,"__init__.py")
|
main_module_path = os.path.join(plugin_path, "__init__.py")
|
||||||
if os.path.isfile(main_module_path):
|
if os.path.isfile(main_module_path):
|
||||||
# 导入插件
|
# 导入插件
|
||||||
import_path = "plugins.{}".format(plugin_name)
|
import_path = "plugins.{}".format(plugin_name)
|
||||||
@@ -76,16 +90,26 @@ class PluginManager:
|
|||||||
if plugin_path in self.loaded:
|
if plugin_path in self.loaded:
|
||||||
if self.loaded[plugin_path] == None:
|
if self.loaded[plugin_path] == None:
|
||||||
logger.info("reload module %s" % plugin_name)
|
logger.info("reload module %s" % plugin_name)
|
||||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
|
self.loaded[plugin_path] = importlib.reload(
|
||||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')]
|
sys.modules[import_path]
|
||||||
|
)
|
||||||
|
dependent_module_names = [
|
||||||
|
name
|
||||||
|
for name in sys.modules.keys()
|
||||||
|
if name.startswith(import_path + ".")
|
||||||
|
]
|
||||||
for name in dependent_module_names:
|
for name in dependent_module_names:
|
||||||
logger.info("reload module %s" % name)
|
logger.info("reload module %s" % name)
|
||||||
importlib.reload(sys.modules[name])
|
importlib.reload(sys.modules[name])
|
||||||
else:
|
else:
|
||||||
self.loaded[plugin_path] = importlib.import_module(import_path)
|
self.loaded[plugin_path] = importlib.import_module(
|
||||||
|
import_path
|
||||||
|
)
|
||||||
self.current_plugin_path = None
|
self.current_plugin_path = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
|
logger.exception(
|
||||||
|
"Failed to import plugin %s: %s" % (plugin_name, e)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
pconf = self.pconf
|
pconf = self.pconf
|
||||||
news = [self.plugins[name] for name in self.plugins]
|
news = [self.plugins[name] for name in self.plugins]
|
||||||
@@ -95,21 +119,28 @@ class PluginManager:
|
|||||||
rawname = plugincls.name
|
rawname = plugincls.name
|
||||||
if rawname not in pconf["plugins"]:
|
if rawname not in pconf["plugins"]:
|
||||||
modified = True
|
modified = True
|
||||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
|
logger.info(
|
||||||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
|
"Plugin %s not found in pconfig, adding to pconfig..." % name
|
||||||
|
)
|
||||||
|
pconf["plugins"][rawname] = {
|
||||||
|
"enabled": plugincls.enabled,
|
||||||
|
"priority": plugincls.priority,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
|
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
|
||||||
self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
|
self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
|
||||||
self.plugins._update_heap(name) # 更新下plugins中的顺序
|
self.plugins._update_heap(name) # 更新下plugins中的顺序
|
||||||
if modified:
|
if modified:
|
||||||
self.save_config()
|
self.save_config()
|
||||||
return new_plugins
|
return new_plugins
|
||||||
|
|
||||||
def refresh_order(self):
|
def refresh_order(self):
|
||||||
for event in self.listening_plugins.keys():
|
for event in self.listening_plugins.keys():
|
||||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
|
self.listening_plugins[event].sort(
|
||||||
|
key=lambda name: self.plugins[name].priority, reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
def activate_plugins(self): # 生成新开启的插件实例
|
def activate_plugins(self): # 生成新开启的插件实例
|
||||||
failed_plugins = []
|
failed_plugins = []
|
||||||
for name, plugincls in self.plugins.items():
|
for name, plugincls in self.plugins.items():
|
||||||
if plugincls.enabled:
|
if plugincls.enabled:
|
||||||
@@ -129,7 +160,7 @@ class PluginManager:
|
|||||||
self.refresh_order()
|
self.refresh_order()
|
||||||
return failed_plugins
|
return failed_plugins
|
||||||
|
|
||||||
def reload_plugin(self, name:str):
|
def reload_plugin(self, name: str):
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name in self.instances:
|
if name in self.instances:
|
||||||
for event in self.listening_plugins:
|
for event in self.listening_plugins:
|
||||||
@@ -139,13 +170,13 @@ class PluginManager:
|
|||||||
self.activate_plugins()
|
self.activate_plugins()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_plugins(self):
|
def load_plugins(self):
|
||||||
self.load_config()
|
self.load_config()
|
||||||
self.scan_plugins()
|
self.scan_plugins()
|
||||||
pconf = self.pconf
|
pconf = self.pconf
|
||||||
logger.debug("plugins.json config={}".format(pconf))
|
logger.debug("plugins.json config={}".format(pconf))
|
||||||
for name,plugin in pconf["plugins"].items():
|
for name, plugin in pconf["plugins"].items():
|
||||||
if name.upper() not in self.plugins:
|
if name.upper() not in self.plugins:
|
||||||
logger.error("Plugin %s not found, but found in plugins.json" % name)
|
logger.error("Plugin %s not found, but found in plugins.json" % name)
|
||||||
self.activate_plugins()
|
self.activate_plugins()
|
||||||
@@ -153,13 +184,18 @@ class PluginManager:
|
|||||||
def emit_event(self, e_context: EventContext, *args, **kwargs):
|
def emit_event(self, e_context: EventContext, *args, **kwargs):
|
||||||
if e_context.event in self.listening_plugins:
|
if e_context.event in self.listening_plugins:
|
||||||
for name in self.listening_plugins[e_context.event]:
|
for name in self.listening_plugins[e_context.event]:
|
||||||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
|
if (
|
||||||
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
|
self.plugins[name].enabled
|
||||||
|
and e_context.action == EventAction.CONTINUE
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
"Plugin %s triggered by event %s" % (name, e_context.event)
|
||||||
|
)
|
||||||
instance = self.instances[name]
|
instance = self.instances[name]
|
||||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||||
return e_context
|
return e_context
|
||||||
|
|
||||||
def set_plugin_priority(self, name:str, priority:int):
|
def set_plugin_priority(self, name: str, priority: int):
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name not in self.plugins:
|
if name not in self.plugins:
|
||||||
return False
|
return False
|
||||||
@@ -174,11 +210,11 @@ class PluginManager:
|
|||||||
self.refresh_order()
|
self.refresh_order()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def enable_plugin(self, name:str):
|
def enable_plugin(self, name: str):
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name not in self.plugins:
|
if name not in self.plugins:
|
||||||
return False, "插件不存在"
|
return False, "插件不存在"
|
||||||
if not self.plugins[name].enabled :
|
if not self.plugins[name].enabled:
|
||||||
self.plugins[name].enabled = True
|
self.plugins[name].enabled = True
|
||||||
rawname = self.plugins[name].name
|
rawname = self.plugins[name].name
|
||||||
self.pconf["plugins"][rawname]["enabled"] = True
|
self.pconf["plugins"][rawname]["enabled"] = True
|
||||||
@@ -188,43 +224,47 @@ class PluginManager:
|
|||||||
return False, "插件开启失败"
|
return False, "插件开启失败"
|
||||||
return True, "插件已开启"
|
return True, "插件已开启"
|
||||||
return True, "插件已开启"
|
return True, "插件已开启"
|
||||||
|
|
||||||
def disable_plugin(self, name:str):
|
def disable_plugin(self, name: str):
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name not in self.plugins:
|
if name not in self.plugins:
|
||||||
return False
|
return False
|
||||||
if self.plugins[name].enabled :
|
if self.plugins[name].enabled:
|
||||||
self.plugins[name].enabled = False
|
self.plugins[name].enabled = False
|
||||||
rawname = self.plugins[name].name
|
rawname = self.plugins[name].name
|
||||||
self.pconf["plugins"][rawname]["enabled"] = False
|
self.pconf["plugins"][rawname]["enabled"] = False
|
||||||
self.save_config()
|
self.save_config()
|
||||||
return True
|
return True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def list_plugins(self):
|
def list_plugins(self):
|
||||||
return self.plugins
|
return self.plugins
|
||||||
|
|
||||||
def install_plugin(self, repo:str):
|
def install_plugin(self, repo: str):
|
||||||
try:
|
try:
|
||||||
import common.package_manager as pkgmgr
|
import common.package_manager as pkgmgr
|
||||||
|
|
||||||
pkgmgr.check_dulwich()
|
pkgmgr.check_dulwich()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to install plugin, {}".format(e))
|
logger.error("Failed to install plugin, {}".format(e))
|
||||||
return False, "无法导入dulwich,安装插件失败"
|
return False, "无法导入dulwich,安装插件失败"
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from dulwich import porcelain
|
from dulwich import porcelain
|
||||||
|
|
||||||
logger.info("clone git repo: {}".format(repo))
|
logger.info("clone git repo: {}".format(repo))
|
||||||
|
|
||||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
||||||
|
|
||||||
if not match:
|
if not match:
|
||||||
try:
|
try:
|
||||||
with open("./plugins/source.json","r", encoding="utf-8") as f:
|
with open("./plugins/source.json", "r", encoding="utf-8") as f:
|
||||||
source = json.load(f)
|
source = json.load(f)
|
||||||
if repo in source["repo"]:
|
if repo in source["repo"]:
|
||||||
repo = source["repo"][repo]["url"]
|
repo = source["repo"][repo]["url"]
|
||||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
match = re.match(
|
||||||
|
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
|
||||||
|
)
|
||||||
if not match:
|
if not match:
|
||||||
return False, "安装插件失败,source中的仓库地址不合法"
|
return False, "安装插件失败,source中的仓库地址不合法"
|
||||||
else:
|
else:
|
||||||
@@ -232,42 +272,53 @@ class PluginManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to install plugin, {}".format(e))
|
logger.error("Failed to install plugin, {}".format(e))
|
||||||
return False, "安装插件失败,请检查仓库地址是否正确"
|
return False, "安装插件失败,请检查仓库地址是否正确"
|
||||||
dirname = os.path.join("./plugins",match.group(4))
|
dirname = os.path.join("./plugins", match.group(4))
|
||||||
try:
|
try:
|
||||||
repo = porcelain.clone(repo, dirname, checkout=True)
|
repo = porcelain.clone(repo, dirname, checkout=True)
|
||||||
if os.path.exists(os.path.join(dirname,"requirements.txt")):
|
if os.path.exists(os.path.join(dirname, "requirements.txt")):
|
||||||
logger.info("detect requirements.txt,installing...")
|
logger.info("detect requirements.txt,installing...")
|
||||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
|
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
|
||||||
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
|
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to install plugin, {}".format(e))
|
logger.error("Failed to install plugin, {}".format(e))
|
||||||
return False, "安装插件失败,"+str(e)
|
return False, "安装插件失败," + str(e)
|
||||||
|
|
||||||
def update_plugin(self, name:str):
|
def update_plugin(self, name: str):
|
||||||
try:
|
try:
|
||||||
import common.package_manager as pkgmgr
|
import common.package_manager as pkgmgr
|
||||||
|
|
||||||
pkgmgr.check_dulwich()
|
pkgmgr.check_dulwich()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to install plugin, {}".format(e))
|
logger.error("Failed to install plugin, {}".format(e))
|
||||||
return False, "无法导入dulwich,更新插件失败"
|
return False, "无法导入dulwich,更新插件失败"
|
||||||
from dulwich import porcelain
|
from dulwich import porcelain
|
||||||
|
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name not in self.plugins:
|
if name not in self.plugins:
|
||||||
return False, "插件不存在"
|
return False, "插件不存在"
|
||||||
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]:
|
if name in [
|
||||||
|
"HELLO",
|
||||||
|
"GODCMD",
|
||||||
|
"ROLE",
|
||||||
|
"TOOL",
|
||||||
|
"BDUNIT",
|
||||||
|
"BANWORDS",
|
||||||
|
"FINISH",
|
||||||
|
"DUNGEON",
|
||||||
|
]:
|
||||||
return False, "预置插件无法更新,请更新主程序仓库"
|
return False, "预置插件无法更新,请更新主程序仓库"
|
||||||
dirname = self.plugins[name].path
|
dirname = self.plugins[name].path
|
||||||
try:
|
try:
|
||||||
porcelain.pull(dirname, "origin")
|
porcelain.pull(dirname, "origin")
|
||||||
if os.path.exists(os.path.join(dirname,"requirements.txt")):
|
if os.path.exists(os.path.join(dirname, "requirements.txt")):
|
||||||
logger.info("detect requirements.txt,installing...")
|
logger.info("detect requirements.txt,installing...")
|
||||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
|
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
|
||||||
return True, "更新插件成功,请重新运行程序"
|
return True, "更新插件成功,请重新运行程序"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to update plugin, {}".format(e))
|
logger.error("Failed to update plugin, {}".format(e))
|
||||||
return False, "更新插件失败,"+str(e)
|
return False, "更新插件失败," + str(e)
|
||||||
|
|
||||||
def uninstall_plugin(self, name:str):
|
def uninstall_plugin(self, name: str):
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
if name not in self.plugins:
|
if name not in self.plugins:
|
||||||
return False, "插件不存在"
|
return False, "插件不存在"
|
||||||
@@ -276,6 +327,7 @@ class PluginManager:
|
|||||||
dirname = self.plugins[name].path
|
dirname = self.plugins[name].path
|
||||||
try:
|
try:
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(dirname)
|
shutil.rmtree(dirname)
|
||||||
rawname = self.plugins[name].name
|
rawname = self.plugins[name].name
|
||||||
for event in self.listening_plugins:
|
for event in self.listening_plugins:
|
||||||
@@ -288,4 +340,4 @@ class PluginManager:
|
|||||||
return True, "卸载插件成功"
|
return True, "卸载插件成功"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to uninstall plugin, {}".format(e))
|
logger.error("Failed to uninstall plugin, {}".format(e))
|
||||||
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e)
|
return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .role import *
|
from .role import *
|
||||||
|
|||||||
@@ -2,17 +2,18 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import plugins
|
||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common import const
|
from common import const
|
||||||
from config import conf
|
|
||||||
import plugins
|
|
||||||
from plugins import *
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
class RolePlay():
|
class RolePlay:
|
||||||
def __init__(self, bot, sessionid, desc, wrapper=None):
|
def __init__(self, bot, sessionid, desc, wrapper=None):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.sessionid = sessionid
|
self.sessionid = sessionid
|
||||||
@@ -25,12 +26,20 @@ class RolePlay():
|
|||||||
|
|
||||||
def action(self, user_action):
|
def action(self, user_action):
|
||||||
session = self.bot.sessions.build_session(self.sessionid)
|
session = self.bot.sessions.build_session(self.sessionid)
|
||||||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
|
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
|
||||||
session.set_system_prompt(self.desc)
|
session.set_system_prompt(self.desc)
|
||||||
prompt = self.wrapper % user_action
|
prompt = self.wrapper % user_action
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent")
|
|
||||||
|
@plugins.register(
|
||||||
|
name="Role",
|
||||||
|
desire_priority=0,
|
||||||
|
namecn="角色扮演",
|
||||||
|
desc="为你的Bot设置预设角色",
|
||||||
|
version="1.0",
|
||||||
|
author="lanvent",
|
||||||
|
)
|
||||||
class Role(Plugin):
|
class Role(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -39,7 +48,7 @@ class Role(Plugin):
|
|||||||
try:
|
try:
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()}
|
self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()}
|
||||||
self.roles = {}
|
self.roles = {}
|
||||||
for role in config["roles"]:
|
for role in config["roles"]:
|
||||||
self.roles[role["title"].lower()] = role
|
self.roles[role["title"].lower()] = role
|
||||||
@@ -60,12 +69,16 @@ class Role(Plugin):
|
|||||||
logger.info("[Role] inited")
|
logger.info("[Role] inited")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, FileNotFoundError):
|
if isinstance(e, FileNotFoundError):
|
||||||
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
|
logger.warn(
|
||||||
|
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
|
logger.warn(
|
||||||
|
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_role(self, name, find_closest=True, min_sim = 0.35):
|
def get_role(self, name, find_closest=True, min_sim=0.35):
|
||||||
name = name.lower()
|
name = name.lower()
|
||||||
found_role = None
|
found_role = None
|
||||||
if name in self.roles:
|
if name in self.roles:
|
||||||
@@ -75,6 +88,7 @@ class Role(Plugin):
|
|||||||
|
|
||||||
def str_simularity(a, b):
|
def str_simularity(a, b):
|
||||||
return difflib.SequenceMatcher(None, a, b).ratio()
|
return difflib.SequenceMatcher(None, a, b).ratio()
|
||||||
|
|
||||||
max_sim = min_sim
|
max_sim = min_sim
|
||||||
max_role = None
|
max_role = None
|
||||||
for role in self.roles:
|
for role in self.roles:
|
||||||
@@ -86,25 +100,24 @@ class Role(Plugin):
|
|||||||
return found_role
|
return found_role
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
|
||||||
return
|
return
|
||||||
bottype = Bridge().get_bot_type("chat")
|
bottype = Bridge().get_bot_type("chat")
|
||||||
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
||||||
return
|
return
|
||||||
bot = Bridge().get_bot("chat")
|
bot = Bridge().get_bot("chat")
|
||||||
content = e_context['context'].content[:]
|
content = e_context["context"].content[:]
|
||||||
clist = e_context['context'].content.split(maxsplit=1)
|
clist = e_context["context"].content.split(maxsplit=1)
|
||||||
desckey = None
|
desckey = None
|
||||||
customize = False
|
customize = False
|
||||||
sessionid = e_context['context']['session_id']
|
sessionid = e_context["context"]["session_id"]
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
if clist[0] == f"{trigger_prefix}停止扮演":
|
if clist[0] == f"{trigger_prefix}停止扮演":
|
||||||
if sessionid in self.roleplays:
|
if sessionid in self.roleplays:
|
||||||
self.roleplays[sessionid].reset()
|
self.roleplays[sessionid].reset()
|
||||||
del self.roleplays[sessionid]
|
del self.roleplays[sessionid]
|
||||||
reply = Reply(ReplyType.INFO, "角色扮演结束!")
|
reply = Reply(ReplyType.INFO, "角色扮演结束!")
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
elif clist[0] == f"{trigger_prefix}角色":
|
elif clist[0] == f"{trigger_prefix}角色":
|
||||||
@@ -114,10 +127,10 @@ class Role(Plugin):
|
|||||||
elif clist[0] == f"{trigger_prefix}设定扮演":
|
elif clist[0] == f"{trigger_prefix}设定扮演":
|
||||||
customize = True
|
customize = True
|
||||||
elif clist[0] == f"{trigger_prefix}角色类型":
|
elif clist[0] == f"{trigger_prefix}角色类型":
|
||||||
if len(clist) >1:
|
if len(clist) > 1:
|
||||||
tag = clist[1].strip()
|
tag = clist[1].strip()
|
||||||
help_text = "角色列表:\n"
|
help_text = "角色列表:\n"
|
||||||
for key,value in self.tags.items():
|
for key, value in self.tags.items():
|
||||||
if value[0] == tag:
|
if value[0] == tag:
|
||||||
tag = key
|
tag = key
|
||||||
break
|
break
|
||||||
@@ -130,57 +143,75 @@ class Role(Plugin):
|
|||||||
else:
|
else:
|
||||||
help_text = f"未知角色类型。\n"
|
help_text = f"未知角色类型。\n"
|
||||||
help_text += "目前的角色类型有: \n"
|
help_text += "目前的角色类型有: \n"
|
||||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
|
help_text += (
|
||||||
|
",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
help_text = f"请输入角色类型。\n"
|
help_text = f"请输入角色类型。\n"
|
||||||
help_text += "目前的角色类型有: \n"
|
help_text += "目前的角色类型有: \n"
|
||||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
|
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
|
||||||
reply = Reply(ReplyType.INFO, help_text)
|
reply = Reply(ReplyType.INFO, help_text)
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
elif sessionid not in self.roleplays:
|
elif sessionid not in self.roleplays:
|
||||||
return
|
return
|
||||||
logger.debug("[Role] on_handle_context. content: %s" % content)
|
logger.debug("[Role] on_handle_context. content: %s" % content)
|
||||||
if desckey is not None:
|
if desckey is not None:
|
||||||
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
|
if len(clist) == 1 or (
|
||||||
|
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
|
||||||
|
):
|
||||||
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
|
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
role = self.get_role(clist[1])
|
role = self.get_role(clist[1])
|
||||||
if role is None:
|
if role is None:
|
||||||
reply = Reply(ReplyType.ERROR, "角色不存在")
|
reply = Reply(ReplyType.ERROR, "角色不存在")
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s"))
|
self.roleplays[sessionid] = RolePlay(
|
||||||
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey])
|
bot,
|
||||||
e_context['reply'] = reply
|
sessionid,
|
||||||
|
self.roles[role][desckey],
|
||||||
|
self.roles[role].get("wrapper", "%s"),
|
||||||
|
)
|
||||||
|
reply = Reply(
|
||||||
|
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
|
||||||
|
)
|
||||||
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
elif customize == True:
|
elif customize == True:
|
||||||
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
|
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
|
||||||
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
|
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
else:
|
else:
|
||||||
prompt = self.roleplays[sessionid].action(content)
|
prompt = self.roleplays[sessionid].action(content)
|
||||||
e_context['context'].type = ContextType.TEXT
|
e_context["context"].type = ContextType.TEXT
|
||||||
e_context['context'].content = prompt
|
e_context["context"].content = prompt
|
||||||
e_context.action = EventAction.BREAK
|
e_context.action = EventAction.BREAK
|
||||||
|
|
||||||
def get_help_text(self, verbose=False, **kwargs):
|
def get_help_text(self, verbose=False, **kwargs):
|
||||||
help_text = "让机器人扮演不同的角色。\n"
|
help_text = "让机器人扮演不同的角色。\n"
|
||||||
if not verbose:
|
if not verbose:
|
||||||
return help_text
|
return help_text
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n"
|
help_text = (
|
||||||
help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n"
|
f"使用方法:\n{trigger_prefix}角色"
|
||||||
|
+ " 预设角色名: 设定角色为{预设角色名}。\n"
|
||||||
|
+ f"{trigger_prefix}role"
|
||||||
|
+ " 预设角色名: 同上,但使用英文设定。\n"
|
||||||
|
)
|
||||||
|
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
|
||||||
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
|
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
|
||||||
help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
|
help_text += (
|
||||||
|
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
|
||||||
|
)
|
||||||
help_text += "\n目前的角色类型有: \n"
|
help_text += "\n目前的角色类型有: \n"
|
||||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n"
|
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
|
||||||
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
|
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
|
||||||
help_text += f"{trigger_prefix}角色类型 所有\n"
|
help_text += f"{trigger_prefix}角色类型 所有\n"
|
||||||
help_text += f"{trigger_prefix}停止扮演\n"
|
help_text += f"{trigger_prefix}停止扮演\n"
|
||||||
|
|||||||
@@ -428,4 +428,4 @@
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"repo": {
|
"repo": {
|
||||||
"sdwebui": {
|
"sdwebui": {
|
||||||
"url": "https://github.com/lanvent/plugin_sdwebui.git",
|
"url": "https://github.com/lanvent/plugin_sdwebui.git",
|
||||||
"desc": "利用stable-diffusion画图的插件"
|
"desc": "利用stable-diffusion画图的插件"
|
||||||
},
|
},
|
||||||
"replicate": {
|
"replicate": {
|
||||||
"url": "https://github.com/lanvent/plugin_replicate.git",
|
"url": "https://github.com/lanvent/plugin_replicate.git",
|
||||||
"desc": "利用replicate api画图的插件"
|
"desc": "利用replicate api画图的插件"
|
||||||
},
|
},
|
||||||
"summary": {
|
"summary": {
|
||||||
"url": "https://github.com/lanvent/plugin_summary.git",
|
"url": "https://github.com/lanvent/plugin_summary.git",
|
||||||
"desc": "总结聊天记录的插件"
|
"desc": "总结聊天记录的插件"
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
## 插件描述
|
## 插件描述
|
||||||
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
|
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
|
||||||
使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功
|
使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功
|
||||||
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
|
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
|
||||||
|
|
||||||
|
|
||||||
## 使用说明
|
## 使用说明
|
||||||
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
|
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
|
||||||
### 1. python
|
### 1. python
|
||||||
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
|
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
|
||||||
|
|
||||||
### 2. url-get
|
### 2. url-get
|
||||||
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
|
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
|
||||||
|
|
||||||
@@ -23,16 +23,16 @@
|
|||||||
|
|
||||||
> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
|
> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
|
||||||
|
|
||||||
## 使用本插件对话(prompt)技巧
|
## 使用本插件对话(prompt)技巧
|
||||||
### 1. 有指引的询问
|
### 1. 有指引的询问
|
||||||
#### 例如:
|
#### 例如:
|
||||||
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
|
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
|
||||||
- 使用Terminal执行curl cip.cc
|
- 使用Terminal执行curl cip.cc
|
||||||
- 使用python查询今天日期
|
- 使用python查询今天日期
|
||||||
|
|
||||||
### 2. 使用搜索引擎工具
|
### 2. 使用搜索引擎工具
|
||||||
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
|
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
|
||||||
|
|
||||||
## 其他工具
|
## 其他工具
|
||||||
|
|
||||||
### 5. wikipedia
|
### 5. wikipedia
|
||||||
@@ -55,9 +55,9 @@
|
|||||||
### 10. google-search *
|
### 10. google-search *
|
||||||
###### google搜索引擎,申请流程较bing-search繁琐
|
###### google搜索引擎,申请流程较bing-search繁琐
|
||||||
|
|
||||||
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
|
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
|
||||||
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
|
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
|
||||||
|
|
||||||
## config.json 配置说明
|
## config.json 配置说明
|
||||||
###### 默认工具无需配置,其它工具需手动配置,一个例子:
|
###### 默认工具无需配置,其它工具需手动配置,一个例子:
|
||||||
```json
|
```json
|
||||||
@@ -71,15 +71,15 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
|
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
|
||||||
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
|
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
|
||||||
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
|
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
|
||||||
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
|
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
|
||||||
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
|
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
|
||||||
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
|
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
|
||||||
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
|
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
|
||||||
|
|
||||||
|
|
||||||
## 备注
|
## 备注
|
||||||
- 强烈建议申请搜索工具搭配使用,推荐bing-search
|
- 强烈建议申请搜索工具搭配使用,推荐bing-search
|
||||||
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤
|
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .tool import *
|
from .tool import *
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
{
|
{
|
||||||
"tools": ["python", "url-get", "terminal", "meteo-weather"],
|
"tools": [
|
||||||
|
"python",
|
||||||
|
"url-get",
|
||||||
|
"terminal",
|
||||||
|
"meteo-weather"
|
||||||
|
],
|
||||||
"kwargs": {
|
"kwargs": {
|
||||||
"top_k_results": 2,
|
"top_k_results": 2,
|
||||||
"no_default": false,
|
"no_default": false,
|
||||||
"model_name": "gpt-3.5-turbo"
|
"model_name": "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
from chatgpt_tool_hub.apps import load_app
|
from chatgpt_tool_hub.apps import load_app
|
||||||
from chatgpt_tool_hub.apps.app import App
|
from chatgpt_tool_hub.apps.app import App
|
||||||
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
|
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
|
||||||
|
|
||||||
import plugins
|
import plugins
|
||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
@@ -14,7 +15,13 @@ from config import conf
|
|||||||
from plugins import *
|
from plugins import *
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0)
|
@plugins.register(
|
||||||
|
name="tool",
|
||||||
|
desc="Arming your ChatGPT bot with various tools",
|
||||||
|
version="0.3",
|
||||||
|
author="goldfishh",
|
||||||
|
desire_priority=0,
|
||||||
|
)
|
||||||
class Tool(Plugin):
|
class Tool(Plugin):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -28,22 +35,26 @@ class Tool(Plugin):
|
|||||||
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
|
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
|
||||||
if not verbose:
|
if not verbose:
|
||||||
return help_text
|
return help_text
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
help_text += "使用说明:\n"
|
help_text += "使用说明:\n"
|
||||||
help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
|
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
|
||||||
help_text += f"{trigger_prefix}tool reset: 重置工具。\n"
|
help_text += f"{trigger_prefix}tool reset: 重置工具。\n"
|
||||||
return help_text
|
return help_text
|
||||||
|
|
||||||
def on_handle_context(self, e_context: EventContext):
|
def on_handle_context(self, e_context: EventContext):
|
||||||
if e_context['context'].type != ContextType.TEXT:
|
if e_context["context"].type != ContextType.TEXT:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 暂时不支持未来扩展的bot
|
# 暂时不支持未来扩展的bot
|
||||||
if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE):
|
if Bridge().get_bot_type("chat") not in (
|
||||||
|
const.CHATGPT,
|
||||||
|
const.OPEN_AI,
|
||||||
|
const.CHATGPTONAZURE,
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
content = e_context['context'].content
|
content = e_context["context"].content
|
||||||
content_list = e_context['context'].content.split(maxsplit=1)
|
content_list = e_context["context"].content.split(maxsplit=1)
|
||||||
|
|
||||||
if not content or len(content_list) < 1:
|
if not content or len(content_list) < 1:
|
||||||
e_context.action = EventAction.CONTINUE
|
e_context.action = EventAction.CONTINUE
|
||||||
@@ -52,13 +63,13 @@ class Tool(Plugin):
|
|||||||
logger.debug("[tool] on_handle_context. content: %s" % content)
|
logger.debug("[tool] on_handle_context. content: %s" % content)
|
||||||
reply = Reply()
|
reply = Reply()
|
||||||
reply.type = ReplyType.TEXT
|
reply.type = ReplyType.TEXT
|
||||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
|
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||||
# todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能
|
# todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能
|
||||||
if content.startswith(f"{trigger_prefix}tool"):
|
if content.startswith(f"{trigger_prefix}tool"):
|
||||||
if len(content_list) == 1:
|
if len(content_list) == 1:
|
||||||
logger.debug("[tool]: get help")
|
logger.debug("[tool]: get help")
|
||||||
reply.content = self.get_help_text()
|
reply.content = self.get_help_text()
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
elif len(content_list) > 1:
|
elif len(content_list) > 1:
|
||||||
@@ -66,12 +77,14 @@ class Tool(Plugin):
|
|||||||
logger.debug("[tool]: reset config")
|
logger.debug("[tool]: reset config")
|
||||||
self.app = self._reset_app()
|
self.app = self._reset_app()
|
||||||
reply.content = "重置工具成功"
|
reply.content = "重置工具成功"
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
return
|
return
|
||||||
elif content_list[1].startswith("reset"):
|
elif content_list[1].startswith("reset"):
|
||||||
logger.debug("[tool]: remind")
|
logger.debug("[tool]: remind")
|
||||||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
|
e_context[
|
||||||
|
"context"
|
||||||
|
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
|
||||||
|
|
||||||
e_context.action = EventAction.BREAK
|
e_context.action = EventAction.BREAK
|
||||||
return
|
return
|
||||||
@@ -80,34 +93,35 @@ class Tool(Plugin):
|
|||||||
|
|
||||||
# Don't modify bot name
|
# Don't modify bot name
|
||||||
all_sessions = Bridge().get_bot("chat").sessions
|
all_sessions = Bridge().get_bot("chat").sessions
|
||||||
user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages
|
user_session = all_sessions.session_query(
|
||||||
|
query, e_context["context"]["session_id"]
|
||||||
|
).messages
|
||||||
|
|
||||||
# chatgpt-tool-hub will reply you with many tools
|
# chatgpt-tool-hub will reply you with many tools
|
||||||
logger.debug("[tool]: just-go")
|
logger.debug("[tool]: just-go")
|
||||||
try:
|
try:
|
||||||
_reply = self.app.ask(query, user_session)
|
_reply = self.app.ask(query, user_session)
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
all_sessions.session_reply(_reply, e_context['context']['session_id'])
|
all_sessions.session_reply(
|
||||||
|
_reply, e_context["context"]["session_id"]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
|
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
|
||||||
reply.type = ReplyType.ERROR
|
reply.type = ReplyType.ERROR
|
||||||
e_context.action = EventAction.BREAK
|
e_context.action = EventAction.BREAK
|
||||||
return
|
return
|
||||||
|
|
||||||
reply.content = _reply
|
reply.content = _reply
|
||||||
e_context['reply'] = reply
|
e_context["reply"] = reply
|
||||||
return
|
return
|
||||||
|
|
||||||
def _read_json(self) -> dict:
|
def _read_json(self) -> dict:
|
||||||
curdir = os.path.dirname(__file__)
|
curdir = os.path.dirname(__file__)
|
||||||
config_path = os.path.join(curdir, "config.json")
|
config_path = os.path.join(curdir, "config.json")
|
||||||
tool_config = {
|
tool_config = {"tools": [], "kwargs": {}}
|
||||||
"tools": [],
|
|
||||||
"kwargs": {}
|
|
||||||
}
|
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
return tool_config
|
return tool_config
|
||||||
else:
|
else:
|
||||||
@@ -123,7 +137,9 @@ class Tool(Plugin):
|
|||||||
"proxy": conf().get("proxy", ""),
|
"proxy": conf().get("proxy", ""),
|
||||||
"request_timeout": conf().get("request_timeout", 60),
|
"request_timeout": conf().get("request_timeout", 60),
|
||||||
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
|
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
|
||||||
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
|
"model_name": tool_model_name
|
||||||
|
if tool_model_name
|
||||||
|
else conf().get("model", "gpt-3.5-turbo"),
|
||||||
"no_default": kwargs.get("no_default", False),
|
"no_default": kwargs.get("no_default", False),
|
||||||
"top_k_results": kwargs.get("top_k_results", 2),
|
"top_k_results": kwargs.get("top_k_results", 2),
|
||||||
# for news tool
|
# for news tool
|
||||||
@@ -160,4 +176,7 @@ class Tool(Plugin):
|
|||||||
# filter not support tool
|
# filter not support tool
|
||||||
tool_list = self._filter_tool_list(tool_config.get("tools", []))
|
tool_list = self._filter_tool_list(tool_config.get("tools", []))
|
||||||
|
|
||||||
return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {})))
|
return load_app(
|
||||||
|
tools_list=tool_list,
|
||||||
|
**self._build_tool_kwargs(tool_config.get("kwargs", {})),
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,3 +4,4 @@ PyQRCode>=1.2.1
|
|||||||
qrcode>=7.4.2
|
qrcode>=7.4.2
|
||||||
requests>=2.28.2
|
requests>=2.28.2
|
||||||
chardet>=5.1.0
|
chardet>=5.1.0
|
||||||
|
pre-commit
|
||||||
@@ -8,7 +8,7 @@ echo $BASE_DIR
|
|||||||
# check the nohup.out log output file
|
# check the nohup.out log output file
|
||||||
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
|
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
|
||||||
touch "${BASE_DIR}/nohup.out"
|
touch "${BASE_DIR}/nohup.out"
|
||||||
echo "create file ${BASE_DIR}/nohup.out"
|
echo "create file ${BASE_DIR}/nohup.out"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"
|
nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ echo $BASE_DIR
|
|||||||
|
|
||||||
# check the nohup.out log output file
|
# check the nohup.out log output file
|
||||||
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
|
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
|
||||||
echo "No file ${BASE_DIR}/nohup.out"
|
echo "No file ${BASE_DIR}/nohup.out"
|
||||||
exit -1;
|
exit -1;
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import pysilk
|
import pysilk
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
|
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
|
||||||
|
|
||||||
|
|
||||||
def find_closest_sil_supports(sample_rate):
|
def find_closest_sil_supports(sample_rate):
|
||||||
"""
|
"""
|
||||||
找到最接近的支持的采样率
|
找到最接近的支持的采样率
|
||||||
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate):
|
|||||||
mindiff = diff
|
mindiff = diff
|
||||||
return closest
|
return closest
|
||||||
|
|
||||||
|
|
||||||
def get_pcm_from_wav(wav_path):
|
def get_pcm_from_wav(wav_path):
|
||||||
"""
|
"""
|
||||||
从 wav 文件中读取 pcm
|
从 wav 文件中读取 pcm
|
||||||
@@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path):
|
|||||||
wav = wave.open(wav_path, "rb")
|
wav = wave.open(wav_path, "rb")
|
||||||
return wav.readframes(wav.getnframes())
|
return wav.readframes(wav.getnframes())
|
||||||
|
|
||||||
|
|
||||||
def any_to_wav(any_path, wav_path):
|
def any_to_wav(any_path, wav_path):
|
||||||
"""
|
"""
|
||||||
把任意格式转成wav文件
|
把任意格式转成wav文件
|
||||||
"""
|
"""
|
||||||
if any_path.endswith('.wav'):
|
if any_path.endswith(".wav"):
|
||||||
shutil.copy2(any_path, wav_path)
|
shutil.copy2(any_path, wav_path)
|
||||||
return
|
return
|
||||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
if (
|
||||||
|
any_path.endswith(".sil")
|
||||||
|
or any_path.endswith(".silk")
|
||||||
|
or any_path.endswith(".slk")
|
||||||
|
):
|
||||||
return sil_to_wav(any_path, wav_path)
|
return sil_to_wav(any_path, wav_path)
|
||||||
audio = AudioSegment.from_file(any_path)
|
audio = AudioSegment.from_file(any_path)
|
||||||
audio.export(wav_path, format="wav")
|
audio.export(wav_path, format="wav")
|
||||||
|
|
||||||
|
|
||||||
def any_to_sil(any_path, sil_path):
|
def any_to_sil(any_path, sil_path):
|
||||||
"""
|
"""
|
||||||
把任意格式转成sil文件
|
把任意格式转成sil文件
|
||||||
"""
|
"""
|
||||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
if (
|
||||||
|
any_path.endswith(".sil")
|
||||||
|
or any_path.endswith(".silk")
|
||||||
|
or any_path.endswith(".slk")
|
||||||
|
):
|
||||||
shutil.copy2(any_path, sil_path)
|
shutil.copy2(any_path, sil_path)
|
||||||
return 10000
|
return 10000
|
||||||
if any_path.endswith('.wav'):
|
if any_path.endswith(".wav"):
|
||||||
return pcm_to_sil(any_path, sil_path)
|
return pcm_to_sil(any_path, sil_path)
|
||||||
if any_path.endswith('.mp3'):
|
if any_path.endswith(".mp3"):
|
||||||
return mp3_to_sil(any_path, sil_path)
|
return mp3_to_sil(any_path, sil_path)
|
||||||
raise NotImplementedError("Not support file type: {}".format(any_path))
|
raise NotImplementedError("Not support file type: {}".format(any_path))
|
||||||
|
|
||||||
|
|
||||||
def mp3_to_wav(mp3_path, wav_path):
|
def mp3_to_wav(mp3_path, wav_path):
|
||||||
"""
|
"""
|
||||||
把mp3格式转成pcm文件
|
把mp3格式转成pcm文件
|
||||||
@@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path):
|
|||||||
audio = AudioSegment.from_mp3(mp3_path)
|
audio = AudioSegment.from_mp3(mp3_path)
|
||||||
audio.export(wav_path, format="wav")
|
audio.export(wav_path, format="wav")
|
||||||
|
|
||||||
|
|
||||||
def pcm_to_sil(pcm_path, silk_path):
|
def pcm_to_sil(pcm_path, silk_path):
|
||||||
"""
|
"""
|
||||||
wav 文件转成 silk
|
wav 文件转成 silk
|
||||||
@@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path):
|
|||||||
pcm_s16 = audio.set_sample_width(2)
|
pcm_s16 = audio.set_sample_width(2)
|
||||||
pcm_s16 = pcm_s16.set_frame_rate(rate)
|
pcm_s16 = pcm_s16.set_frame_rate(rate)
|
||||||
wav_data = pcm_s16.raw_data
|
wav_data = pcm_s16.raw_data
|
||||||
silk_data = pysilk.encode(
|
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
|
||||||
wav_data, data_rate=rate, sample_rate=rate)
|
|
||||||
with open(silk_path, "wb") as f:
|
with open(silk_path, "wb") as f:
|
||||||
f.write(silk_data)
|
f.write(silk_data)
|
||||||
return audio.duration_seconds * 1000
|
return audio.duration_seconds * 1000
|
||||||
|
|
||||||
|
|
||||||
def mp3_to_sil(mp3_path, silk_path):
|
def mp3_to_sil(mp3_path, silk_path):
|
||||||
"""
|
"""
|
||||||
mp3 文件转成 silk
|
mp3 文件转成 silk
|
||||||
@@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path):
|
|||||||
f.write(silk_data)
|
f.write(silk_data)
|
||||||
return audio.duration_seconds * 1000
|
return audio.duration_seconds * 1000
|
||||||
|
|
||||||
|
|
||||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
|
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
|
||||||
"""
|
"""
|
||||||
silk 文件转 wav
|
silk 文件转 wav
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
azure voice service
|
azure voice service
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import azure.cognitiveservices.speech as speechsdk
|
import azure.cognitiveservices.speech as speechsdk
|
||||||
|
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.tmp_dir import TmpDir
|
from common.tmp_dir import TmpDir
|
||||||
from voice.voice import Voice
|
|
||||||
from config import conf
|
from config import conf
|
||||||
|
from voice.voice import Voice
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Azure voice
|
Azure voice
|
||||||
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
|
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
|
||||||
@@ -19,50 +21,68 @@ Azure voice
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class AzureVoice(Voice):
|
|
||||||
|
|
||||||
|
class AzureVoice(Voice):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
curdir = os.path.dirname(__file__)
|
curdir = os.path.dirname(__file__)
|
||||||
config_path = os.path.join(curdir, "config.json")
|
config_path = os.path.join(curdir, "config.json")
|
||||||
config = None
|
config = None
|
||||||
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
|
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
|
||||||
config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
|
config = {
|
||||||
|
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
|
||||||
|
"speech_recognition_language": "zh-CN",
|
||||||
|
}
|
||||||
with open(config_path, "w") as fw:
|
with open(config_path, "w") as fw:
|
||||||
json.dump(config, fw, indent=4)
|
json.dump(config, fw, indent=4)
|
||||||
else:
|
else:
|
||||||
with open(config_path, "r") as fr:
|
with open(config_path, "r") as fr:
|
||||||
config = json.load(fr)
|
config = json.load(fr)
|
||||||
self.api_key = conf().get('azure_voice_api_key')
|
self.api_key = conf().get("azure_voice_api_key")
|
||||||
self.api_region = conf().get('azure_voice_region')
|
self.api_region = conf().get("azure_voice_region")
|
||||||
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
|
self.speech_config = speechsdk.SpeechConfig(
|
||||||
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
|
subscription=self.api_key, region=self.api_region
|
||||||
self.speech_config.speech_recognition_language = config["speech_recognition_language"]
|
)
|
||||||
|
self.speech_config.speech_synthesis_voice_name = config[
|
||||||
|
"speech_synthesis_voice_name"
|
||||||
|
]
|
||||||
|
self.speech_config.speech_recognition_language = config[
|
||||||
|
"speech_recognition_language"
|
||||||
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("AzureVoice init failed: %s, ignore " % e)
|
logger.warn("AzureVoice init failed: %s, ignore " % e)
|
||||||
|
|
||||||
def voiceToText(self, voice_file):
|
def voiceToText(self, voice_file):
|
||||||
audio_config = speechsdk.AudioConfig(filename=voice_file)
|
audio_config = speechsdk.AudioConfig(filename=voice_file)
|
||||||
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
|
speech_recognizer = speechsdk.SpeechRecognizer(
|
||||||
|
speech_config=self.speech_config, audio_config=audio_config
|
||||||
|
)
|
||||||
result = speech_recognizer.recognize_once()
|
result = speech_recognizer.recognize_once()
|
||||||
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
|
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
|
||||||
logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
|
logger.info(
|
||||||
|
"[Azure] voiceToText voice file name={} text={}".format(
|
||||||
|
voice_file, result.text
|
||||||
|
)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.TEXT, result.text)
|
reply = Reply(ReplyType.TEXT, result.text)
|
||||||
else:
|
else:
|
||||||
logger.error('[Azure] voiceToText error, result={}'.format(result))
|
logger.error("[Azure] voiceToText error, result={}".format(result))
|
||||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def textToVoice(self, text):
|
def textToVoice(self, text):
|
||||||
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
|
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
|
||||||
audio_config = speechsdk.AudioConfig(filename=fileName)
|
audio_config = speechsdk.AudioConfig(filename=fileName)
|
||||||
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
|
speech_synthesizer = speechsdk.SpeechSynthesizer(
|
||||||
|
speech_config=self.speech_config, audio_config=audio_config
|
||||||
|
)
|
||||||
result = speech_synthesizer.speak_text(text)
|
result = speech_synthesizer.speak_text(text)
|
||||||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
|
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
|
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.VOICE, fileName)
|
reply = Reply(ReplyType.VOICE, fileName)
|
||||||
else:
|
else:
|
||||||
logger.error('[Azure] textToVoice error, result={}'.format(result))
|
logger.error("[Azure] textToVoice error, result={}".format(result))
|
||||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||||
return reply
|
return reply
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
|
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
|
||||||
"speech_recognition_language": "zh-CN"
|
"speech_recognition_language": "zh-CN"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ dev_pid 必填 语言选择,填写语言对应的dev_pid值
|
|||||||
|
|
||||||
2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
|
2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
|
||||||
参数 可需 描述
|
参数 可需 描述
|
||||||
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
|
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
|
||||||
lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
|
lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
|
||||||
spd 选填 语速,取值0-15,默认为5中语速
|
spd 选填 语速,取值0-15,默认为5中语速
|
||||||
pit 选填 音调,取值0-15,默认为5中语调
|
pit 选填 音调,取值0-15,默认为5中语调
|
||||||
@@ -40,14 +40,14 @@ aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav
|
|||||||
|
|
||||||
关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
|
关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
|
||||||
### 配置文件
|
### 配置文件
|
||||||
|
|
||||||
将文件夹中`config.json.template`复制为`config.json`。
|
将文件夹中`config.json.template`复制为`config.json`。
|
||||||
|
|
||||||
``` json
|
``` json
|
||||||
{
|
{
|
||||||
"lang": "zh",
|
"lang": "zh",
|
||||||
"ctp": 1,
|
"ctp": 1,
|
||||||
"spd": 5,
|
"spd": 5,
|
||||||
"pit": 5,
|
"pit": 5,
|
||||||
"vol": 5,
|
"vol": 5,
|
||||||
"per": 0
|
"per": 0
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
baidu voice service
|
baidu voice service
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from aip import AipSpeech
|
from aip import AipSpeech
|
||||||
|
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.tmp_dir import TmpDir
|
from common.tmp_dir import TmpDir
|
||||||
from voice.voice import Voice
|
|
||||||
from voice.audio_convert import get_pcm_from_wav
|
|
||||||
from config import conf
|
from config import conf
|
||||||
|
from voice.audio_convert import get_pcm_from_wav
|
||||||
|
from voice.voice import Voice
|
||||||
|
|
||||||
"""
|
"""
|
||||||
百度的语音识别API.
|
百度的语音识别API.
|
||||||
dev_pid:
|
dev_pid:
|
||||||
@@ -28,40 +30,37 @@ from config import conf
|
|||||||
|
|
||||||
|
|
||||||
class BaiduVoice(Voice):
|
class BaiduVoice(Voice):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
curdir = os.path.dirname(__file__)
|
curdir = os.path.dirname(__file__)
|
||||||
config_path = os.path.join(curdir, "config.json")
|
config_path = os.path.join(curdir, "config.json")
|
||||||
bconf = None
|
bconf = None
|
||||||
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
|
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
|
||||||
bconf = { "lang": "zh", "ctp": 1, "spd": 5,
|
bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
|
||||||
"pit": 5, "vol": 5, "per": 0}
|
|
||||||
with open(config_path, "w") as fw:
|
with open(config_path, "w") as fw:
|
||||||
json.dump(bconf, fw, indent=4)
|
json.dump(bconf, fw, indent=4)
|
||||||
else:
|
else:
|
||||||
with open(config_path, "r") as fr:
|
with open(config_path, "r") as fr:
|
||||||
bconf = json.load(fr)
|
bconf = json.load(fr)
|
||||||
|
|
||||||
self.app_id = conf().get('baidu_app_id')
|
self.app_id = conf().get("baidu_app_id")
|
||||||
self.api_key = conf().get('baidu_api_key')
|
self.api_key = conf().get("baidu_api_key")
|
||||||
self.secret_key = conf().get('baidu_secret_key')
|
self.secret_key = conf().get("baidu_secret_key")
|
||||||
self.dev_id = conf().get('baidu_dev_pid')
|
self.dev_id = conf().get("baidu_dev_pid")
|
||||||
self.lang = bconf["lang"]
|
self.lang = bconf["lang"]
|
||||||
self.ctp = bconf["ctp"]
|
self.ctp = bconf["ctp"]
|
||||||
self.spd = bconf["spd"]
|
self.spd = bconf["spd"]
|
||||||
self.pit = bconf["pit"]
|
self.pit = bconf["pit"]
|
||||||
self.vol = bconf["vol"]
|
self.vol = bconf["vol"]
|
||||||
self.per = bconf["per"]
|
self.per = bconf["per"]
|
||||||
|
|
||||||
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
|
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("BaiduVoice init failed: %s, ignore " % e)
|
logger.warn("BaiduVoice init failed: %s, ignore " % e)
|
||||||
|
|
||||||
|
|
||||||
def voiceToText(self, voice_file):
|
def voiceToText(self, voice_file):
|
||||||
# 识别本地文件
|
# 识别本地文件
|
||||||
logger.debug('[Baidu] voice file name={}'.format(voice_file))
|
logger.debug("[Baidu] voice file name={}".format(voice_file))
|
||||||
pcm = get_pcm_from_wav(voice_file)
|
pcm = get_pcm_from_wav(voice_file)
|
||||||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
|
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
|
||||||
if res["err_no"] == 0:
|
if res["err_no"] == 0:
|
||||||
@@ -72,21 +71,25 @@ class BaiduVoice(Voice):
|
|||||||
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
|
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
|
||||||
if res["err_msg"] == "request pv too much":
|
if res["err_msg"] == "request pv too much":
|
||||||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
|
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
|
||||||
reply = Reply(ReplyType.ERROR,
|
reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
|
||||||
"百度语音识别出错了;{0}".format(res["err_msg"]))
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def textToVoice(self, text):
|
def textToVoice(self, text):
|
||||||
result = self.client.synthesis(text, self.lang, self.ctp, {
|
result = self.client.synthesis(
|
||||||
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
|
text,
|
||||||
|
self.lang,
|
||||||
|
self.ctp,
|
||||||
|
{"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
|
||||||
|
)
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
|
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
|
||||||
with open(fileName, 'wb') as f:
|
with open(fileName, "wb") as f:
|
||||||
f.write(result)
|
f.write(result)
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
|
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.VOICE, fileName)
|
reply = Reply(ReplyType.VOICE, fileName)
|
||||||
else:
|
else:
|
||||||
logger.error('[Baidu] textToVoice error={}'.format(result))
|
logger.error("[Baidu] textToVoice error={}".format(result))
|
||||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||||
return reply
|
return reply
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"lang": "zh",
|
"lang": "zh",
|
||||||
"ctp": 1,
|
"ctp": 1,
|
||||||
"spd": 5,
|
"spd": 5,
|
||||||
"pit": 5,
|
"pit": 5,
|
||||||
"vol": 5,
|
"vol": 5,
|
||||||
"per": 0
|
"per": 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
google voice service
|
google voice service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import speech_recognition
|
import speech_recognition
|
||||||
from gtts import gTTS
|
from gtts import gTTS
|
||||||
|
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.tmp_dir import TmpDir
|
from common.tmp_dir import TmpDir
|
||||||
@@ -22,9 +23,12 @@ class GoogleVoice(Voice):
|
|||||||
with speech_recognition.AudioFile(voice_file) as source:
|
with speech_recognition.AudioFile(voice_file) as source:
|
||||||
audio = self.recognizer.record(source)
|
audio = self.recognizer.record(source)
|
||||||
try:
|
try:
|
||||||
text = self.recognizer.recognize_google(audio, language='zh-CN')
|
text = self.recognizer.recognize_google(audio, language="zh-CN")
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
|
"[Google] voiceToText text={} voice file name={}".format(
|
||||||
|
text, voice_file
|
||||||
|
)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.TEXT, text)
|
reply = Reply(ReplyType.TEXT, text)
|
||||||
except speech_recognition.UnknownValueError:
|
except speech_recognition.UnknownValueError:
|
||||||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
|
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
|
||||||
@@ -32,13 +36,15 @@ class GoogleVoice(Voice):
|
|||||||
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
|
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
|
||||||
finally:
|
finally:
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def textToVoice(self, text):
|
def textToVoice(self, text):
|
||||||
try:
|
try:
|
||||||
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
|
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
|
||||||
tts = gTTS(text=text, lang='zh')
|
tts = gTTS(text=text, lang="zh")
|
||||||
tts.save(mp3File)
|
tts.save(mp3File)
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
|
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.VOICE, mp3File)
|
reply = Reply(ReplyType.VOICE, mp3File)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
reply = Reply(ReplyType.ERROR, str(e))
|
reply = Reply(ReplyType.ERROR, str(e))
|
||||||
|
|||||||
@@ -1,29 +1,32 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
google voice service
|
google voice service
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from config import conf
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
from voice.voice import Voice
|
from voice.voice import Voice
|
||||||
|
|
||||||
|
|
||||||
class OpenaiVoice(Voice):
|
class OpenaiVoice(Voice):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
openai.api_key = conf().get('open_ai_api_key')
|
openai.api_key = conf().get("open_ai_api_key")
|
||||||
|
|
||||||
def voiceToText(self, voice_file):
|
def voiceToText(self, voice_file):
|
||||||
logger.debug(
|
logger.debug("[Openai] voice file name={}".format(voice_file))
|
||||||
'[Openai] voice file name={}'.format(voice_file))
|
|
||||||
try:
|
try:
|
||||||
file = open(voice_file, "rb")
|
file = open(voice_file, "rb")
|
||||||
result = openai.Audio.transcribe("whisper-1", file)
|
result = openai.Audio.transcribe("whisper-1", file)
|
||||||
text = result["text"]
|
text = result["text"]
|
||||||
reply = Reply(ReplyType.TEXT, text)
|
reply = Reply(ReplyType.TEXT, text)
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
|
"[Openai] voiceToText text={} voice file name={}".format(
|
||||||
|
text, voice_file
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
reply = Reply(ReplyType.ERROR, str(e))
|
reply = Reply(ReplyType.ERROR, str(e))
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
pytts voice service (offline)
|
pytts voice service (offline)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import pyttsx3
|
import pyttsx3
|
||||||
|
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.tmp_dir import TmpDir
|
from common.tmp_dir import TmpDir
|
||||||
@@ -16,20 +17,21 @@ class PyttsVoice(Voice):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 语速
|
# 语速
|
||||||
self.engine.setProperty('rate', 125)
|
self.engine.setProperty("rate", 125)
|
||||||
# 音量
|
# 音量
|
||||||
self.engine.setProperty('volume', 1.0)
|
self.engine.setProperty("volume", 1.0)
|
||||||
for voice in self.engine.getProperty('voices'):
|
for voice in self.engine.getProperty("voices"):
|
||||||
if "Chinese" in voice.name:
|
if "Chinese" in voice.name:
|
||||||
self.engine.setProperty('voice', voice.id)
|
self.engine.setProperty("voice", voice.id)
|
||||||
|
|
||||||
def textToVoice(self, text):
|
def textToVoice(self, text):
|
||||||
try:
|
try:
|
||||||
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
|
wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
|
||||||
self.engine.save_to_file(text, wavFile)
|
self.engine.save_to_file(text, wavFile)
|
||||||
self.engine.runAndWait()
|
self.engine.runAndWait()
|
||||||
logger.info(
|
logger.info(
|
||||||
'[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
|
"[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
|
||||||
|
)
|
||||||
reply = Reply(ReplyType.VOICE, wavFile)
|
reply = Reply(ReplyType.VOICE, wavFile)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
reply = Reply(ReplyType.ERROR, str(e))
|
reply = Reply(ReplyType.ERROR, str(e))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
Voice service abstract class
|
Voice service abstract class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Voice(object):
|
class Voice(object):
|
||||||
def voiceToText(self, voice_file):
|
def voiceToText(self, voice_file):
|
||||||
"""
|
"""
|
||||||
@@ -13,4 +14,4 @@ class Voice(object):
|
|||||||
"""
|
"""
|
||||||
Send text to voice service and get voice
|
Send text to voice service and get voice
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -2,25 +2,31 @@
|
|||||||
voice factory
|
voice factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_voice(voice_type):
|
def create_voice(voice_type):
|
||||||
"""
|
"""
|
||||||
create a voice instance
|
create a voice instance
|
||||||
:param voice_type: voice type code
|
:param voice_type: voice type code
|
||||||
:return: voice instance
|
:return: voice instance
|
||||||
"""
|
"""
|
||||||
if voice_type == 'baidu':
|
if voice_type == "baidu":
|
||||||
from voice.baidu.baidu_voice import BaiduVoice
|
from voice.baidu.baidu_voice import BaiduVoice
|
||||||
|
|
||||||
return BaiduVoice()
|
return BaiduVoice()
|
||||||
elif voice_type == 'google':
|
elif voice_type == "google":
|
||||||
from voice.google.google_voice import GoogleVoice
|
from voice.google.google_voice import GoogleVoice
|
||||||
|
|
||||||
return GoogleVoice()
|
return GoogleVoice()
|
||||||
elif voice_type == 'openai':
|
elif voice_type == "openai":
|
||||||
from voice.openai.openai_voice import OpenaiVoice
|
from voice.openai.openai_voice import OpenaiVoice
|
||||||
|
|
||||||
return OpenaiVoice()
|
return OpenaiVoice()
|
||||||
elif voice_type == 'pytts':
|
elif voice_type == "pytts":
|
||||||
from voice.pytts.pytts_voice import PyttsVoice
|
from voice.pytts.pytts_voice import PyttsVoice
|
||||||
|
|
||||||
return PyttsVoice()
|
return PyttsVoice()
|
||||||
elif voice_type == 'azure':
|
elif voice_type == "azure":
|
||||||
from voice.azure.azure_voice import AzureVoice
|
from voice.azure.azure_voice import AzureVoice
|
||||||
|
|
||||||
return AzureVoice()
|
return AzureVoice()
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|||||||
Reference in New Issue
Block a user