Mirror of https://github.com/zhayujie/chatgpt-on-wechat.git
Synced 2026-02-19 09:07:02 +08:00

Merge pull request #442 from lanvent/dev

Add basic plugin support, with three initial plugins: sdwebui (NovelAI image generation), godcmd (enhanced admin commands) and Banwords (sensitive-word filtering).

This commit is contained in:
1 .gitignore vendored
@@ -7,3 +7,4 @@ config.json
QR.png
nohup.out
tmp
plugins.json
7 app.py
@@ -4,14 +4,17 @@ import config
from channel import channel_factory
from common.log import logger

from plugins import *

if __name__ == '__main__':
    try:
        # load config
        config.load_config()

        # create channel
        channel = channel_factory.create_channel("wx")
        channel_name='wx'
        channel = channel_factory.create_channel(channel_name)
        if channel_name=='wx':
            PluginManager().load_plugins()

        # startup channel
        channel.startup()
@@ -2,6 +2,7 @@
import requests
from bot.bot import Bot
from bridge.reply import Reply, ReplyType


# Baidu Unit对话接口 (可用, 但能力较弱)
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot):
        headers = {'content-type': 'application/x-www-form-urlencoded'}
        response = requests.post(url, data=post_data.encode(), headers=headers)
        if response:
            return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
            return reply

    def get_token(self):
        access_key = 'YOUR_ACCESS_KEY'
@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class
"""

from bridge.context import Context
from bridge.reply import Reply


class Bot(object):
    def reply(self, query, context=None):
    def reply(self, query, context : Context =None) -> Reply:
        """
        bot auto-reply content
        :param req: received message
@@ -1,41 +1,42 @@
# encoding:utf-8

from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config
from common.log import logger
from common.expired_dict import ExpiredDict
import openai
import time

if conf().get('expires_in_seconds'):
    all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
else:
    all_sessions = dict()

# OpenAI对话模型API (可用)
class ChatGPTBot(Bot):
    def __init__(self):
        openai.api_key = conf().get('open_ai_api_key')
        proxy = conf().get('proxy')
        self.sessions = SessionManager()
        if proxy:
            openai.proxy = proxy

    def reply(self, query, context=None):
        # acquire reply content
        if not context or not context.get('type') or context.get('type') == 'TEXT':
        if context.type == ContextType.TEXT:
            logger.info("[OPEN_AI] query={}".format(query))
            session_id = context.get('session_id') or context.get('from_user_id')
            session_id = context['session_id']
            reply = None
            if query == '#清除记忆':
                Session.clear_session(session_id)
                return '记忆已清除'
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, '记忆已清除')
            elif query == '#清除所有':
                Session.clear_all_session()
                return '所有人记忆已清除'
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, '所有人记忆已清除')
            elif query == '#更新配置':
                load_config()
                return '配置已更新'

            session = Session.build_session_query(query, session_id)
                reply = Reply(ReplyType.INFO, '配置已更新')
            if reply:
                return reply
            session = self.sessions.build_session_query(query, session_id)
            logger.debug("[OPEN_AI] session query={}".format(session))

            # if context.get('stream'):
@@ -44,14 +45,29 @@ class ChatGPTBot(Bot):
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
if reply_content["completion_tokens"] > 0:
|
||||
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
return reply_content["content"]
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session, session_id, retry_count=0) ->dict:
|
||||
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
@@ -70,8 +86,8 @@ class ChatGPTBot(Bot):
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
@@ -86,15 +102,15 @@ class ChatGPTBot(Bot):
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
|
||||
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
|
||||
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(session_id)
|
||||
self.sessions.clear_session(session_id)
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
@@ -107,7 +123,7 @@ class ChatGPTBot(Bot):
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
@@ -115,14 +131,21 @@ class ChatGPTBot(Bot):
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
return False, str(e)
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, session_id):
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
|
||||
def build_session_query(self, query, session_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. [
|
||||
@@ -135,36 +158,33 @@ class Session(object):
|
||||
:param session_id: session id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
session = all_sessions.get(session_id, [])
|
||||
session = self.sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
all_sessions[session_id] = session
|
||||
self.sessions[session_id] = session
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
session.append(user_item)
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def save_session(answer, session_id, total_tokens):
|
||||
def save_session(self, answer, session_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens=int(max_tokens)
|
||||
max_tokens = int(max_tokens)
|
||||
|
||||
session = all_sessions.get(session_id)
|
||||
session = self.sessions.get(session_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
self.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens, total_tokens):
|
||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
|
||||
dec_tokens = int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
@@ -173,13 +193,11 @@ class Session(object):
|
||||
session.pop(1)
|
||||
session.pop(1)
|
||||
else:
|
||||
break
|
||||
break
|
||||
dec_tokens = dec_tokens - max_tokens
|
||||
|
||||
@staticmethod
|
||||
def clear_session(session_id):
|
||||
all_sessions[session_id] = []
|
||||
def clear_session(self, session_id):
|
||||
self.sessions[session_id] = []
|
||||
|
||||
@staticmethod
|
||||
def clear_all_session():
|
||||
all_sessions.clear()
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
|
||||
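The SessionManager introduced above keeps one message list per session_id, always seeded with the configured system prompt, and trims old turns once the reported total_tokens exceeds conversation_max_tokens. A hedged usage sketch follows; the module path and the ids, queries and token counts are illustrative assumptions, not values from this commit.

# Illustrative sketch only; module path and values are assumptions.
from bot.chatgpt.chat_gpt_bot import SessionManager

sessions = SessionManager()
session = sessions.build_session_query("你好", "user-123")
# session is now:
# [{'role': 'system', 'content': <character_desc from config>},
#  {'role': 'user',   'content': '你好'}]

# After the model answers, persist the reply together with the token usage so
# discard_exceed_conversation can drop the oldest turns when the limit is hit.
sessions.save_session("你好!有什么可以帮你?", "user-123", total_tokens=42)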
@@ -1,6 +1,8 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
@@ -13,30 +15,31 @@ class OpenAIBot(Bot):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context.get('from_user_id') or context.get('session_id')
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
return '记忆已清除'
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
return '所有人记忆已清除'
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context['session_id']
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
else:
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
return reply_content
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
return self.create_img(query, 0)
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
try:
|
||||
|
||||
@@ -1,16 +1,42 @@
from bridge.context import Context
from bridge.reply import Reply
from common.log import logger
from bot import bot_factory
from common.singleton import singleton
from voice import voice_factory


@singleton
class Bridge(object):
    def __init__(self):
        pass
        self.btype={
            "chat": "chatGPT",
            "voice_to_text": "openai",
            "text_to_voice": "baidu"
        }
        self.bots={}

    def fetch_reply_content(self, query, context):
        return bot_factory.create_bot("chatGPT").reply(query, context)
    def get_bot(self,typename):
        if self.bots.get(typename) is None:
            logger.info("create bot {} for {}".format(self.btype[typename],typename))
            if typename == "text_to_voice":
                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
            elif typename == "voice_to_text":
                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
            elif typename == "chat":
                self.bots[typename] = bot_factory.create_bot(self.btype[typename])
        return self.bots[typename]

    def get_bot_type(self,typename):
        return self.btype[typename]

    def fetch_voice_to_text(self, voiceFile):
        return voice_factory.create_voice("openai").voiceToText(voiceFile)

    def fetch_text_to_voice(self, text):
        return voice_factory.create_voice("baidu").textToVoice(text)
    def fetch_reply_content(self, query, context : Context) -> Reply:
        return self.get_bot("chat").reply(query, context)

    def fetch_voice_to_text(self, voiceFile) -> Reply:
        return self.get_bot("voice_to_text").voiceToText(voiceFile)

    def fetch_text_to_voice(self, text) -> Reply:
        return self.get_bot("text_to_voice").textToVoice(text)
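The new Bridge lazily creates one bot per capability ("chat", "voice_to_text", "text_to_voice") and reuses it on later calls thanks to the @singleton decorator. A minimal usage sketch, assuming a configured chat bot exists; the query text and session id are made-up illustration values.

# Illustrative sketch; Bridge/Context/Reply come from this commit, the values do not.
from bridge.bridge import Bridge
from bridge.context import Context, ContextType

context = Context(ContextType.TEXT, "hello")
context['session_id'] = 'demo-user'          # stored in context.kwargs

bridge = Bridge()                             # @singleton: same instance everywhere
reply = bridge.fetch_reply_content(context.content, context)
print(reply.type, reply.content)              # e.g. ReplyType.TEXT and the bot's answer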
42 bridge/context.py Normal file
@@ -0,0 +1,42 @@
# encoding:utf-8

from enum import Enum

class ContextType (Enum):
    TEXT = 1 # 文本消息
    VOICE = 2 # 音频消息
    IMAGE_CREATE = 3 # 创建图片命令

    def __str__(self):
        return self.name

class Context:
    def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
        self.type = type
        self.content = content
        self.kwargs = kwargs

    def __getitem__(self, key):
        if key == 'type':
            return self.type
        elif key == 'content':
            return self.content
        else:
            return self.kwargs[key]

    def __setitem__(self, key, value):
        if key == 'type':
            self.type = value
        elif key == 'content':
            self.content = value
        else:
            self.kwargs[key] = value

    def __delitem__(self, key):
        if key == 'type':
            self.type = None
        elif key == 'content':
            self.content = None
        else:
            del self.kwargs[key]

    def __str__(self):
        return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
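Context mixes attribute access (type, content) with dict-style access for everything else, which is routed to kwargs; that is what keeps older code such as context['session_id'] working. Note that the kwargs=dict() default is a shared mutable default in standard Python fashion. A short sketch with made-up values:

from bridge.context import Context, ContextType

ctx = Context(ContextType.TEXT, "draw me a cat")
ctx['session_id'] = 'user-42'              # goes into ctx.kwargs
ctx['type'] = ContextType.IMAGE_CREATE     # special-cased: sets ctx.type

print(ctx.type)         # ContextType.IMAGE_CREATE
print(ctx['content'])   # "draw me a cat"
print(ctx.kwargs)       # {'session_id': 'user-42'}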
22 bridge/reply.py Normal file
@@ -0,0 +1,22 @@
# encoding:utf-8

from enum import Enum

class ReplyType(Enum):
    TEXT = 1 # 文本
    VOICE = 2 # 音频文件
    IMAGE = 3 # 图片文件
    IMAGE_URL = 4 # 图片URL

    INFO = 9
    ERROR = 10
    def __str__(self):
        return self.name

class Reply:
    def __init__(self, type : ReplyType = None , content = None):
        self.type = type
        self.content = content
    def __str__(self):
        return "Reply(type={}, content={})".format(self.type, self.content)
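Reply pairs a ReplyType with a payload whose meaning depends on the type: plain text, a voice or image file, an image URL, or an INFO/ERROR notice. A small sketch of how a channel might branch on it, loosely mirroring the WechatChannel.send logic further down in this diff; the helper itself is invented for illustration.

from bridge.reply import Reply, ReplyType

def describe(reply: Reply) -> str:
    # Illustrative helper, not part of the commit.
    if reply.type == ReplyType.TEXT:
        return "send text: " + reply.content
    if reply.type in (ReplyType.INFO, ReplyType.ERROR):
        return "send notice: " + str(reply.type) + ":\n" + reply.content
    if reply.type == ReplyType.IMAGE_URL:
        return "download then send image from " + reply.content
    return "unsupported reply type"

print(describe(Reply(ReplyType.TEXT, "hello")))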
@@ -3,6 +3,8 @@ Message sending channel abstract class
"""

from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply

class Channel(object):
    def startup(self):
@@ -27,11 +29,11 @@ class Channel(object):
        """
        raise NotImplementedError

    def build_reply_content(self, query, context=None):
    def build_reply_content(self, query, context : Context=None) -> Reply:
        return Bridge().fetch_reply_content(query, context)

    def build_voice_to_text(self, voice_file):
    def build_voice_to_text(self, voice_file) -> Reply:
        return Bridge().fetch_voice_to_text(voice_file)

    def build_text_to_voice(self, text):
    def build_text_to_voice(self, text) -> Reply:
        return Bridge().fetch_text_to_voice(text)
@@ -7,16 +7,24 @@ wechat channel
import itchat
import json
from itchat.content import *
from bridge.reply import *
from bridge.context import *
from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor
from common.log import logger
from common.tmp_dir import TmpDir
from config import conf
from plugins import *

import requests
import io

thread_pool = ThreadPoolExecutor(max_workers=8)

thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
    worker_exception = worker.exception()
    if worker_exception:
        logger.exception("Worker return exception: {}".format(worker_exception))

@itchat.msg_register(TEXT)
def handler_single_msg(msg):
@@ -47,62 +55,52 @@ class WechatChannel(Channel):
        # start message listener
        itchat.run()

    # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
    # context是一个字典,包含了消息的所有信息,包括以下key
    # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE
    # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
    # session_id: 会话id
    # isgroup: 是否是群聊
    # msg: 原始消息对象
    # receiver: 需要回复的对象
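The comment block above documents the keys every handle_* method is expected to place on the context before submitting it to handle(). A minimal sketch of that construction for a private text message, modelled on _handle_single_msg below; the helper name is invented, and the msg fields assume itchat's message dict layout.

# Illustrative sketch only; everything except Context/ContextType is an assumption.
from bridge.context import Context, ContextType

def build_private_text_context(msg, content):
    other_user_id = msg['User']['UserName']   # the peer we must reply to
    context = Context(ContextType.TEXT, content)
    context.kwargs = {
        'isgroup': False,          # private chat, not a group
        'msg': msg,                # original itchat message object
        'receiver': other_user_id,
        'session_id': other_user_id,
    }
    return context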
def handle_voice(self, msg):
|
||||
if conf().get('speech_recognition') != True :
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: " + msg['FileName'])
|
||||
thread_pool.submit(self._do_handle_voice, msg)
|
||||
|
||||
def _do_handle_voice(self, msg):
|
||||
from_user_id = msg['FromUserName']
|
||||
other_user_id = msg['User']['UserName']
|
||||
if from_user_id == other_user_id:
|
||||
file_name = TmpDir().path() + msg['FileName']
|
||||
msg.download(file_name)
|
||||
query = super().build_voice_to_text(file_name)
|
||||
if conf().get('voice_reply_voice'):
|
||||
self._do_send_voice(query, from_user_id)
|
||||
else:
|
||||
self._do_send_text(query, from_user_id)
|
||||
context = Context(ContextType.VOICE,msg['FileName'])
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
def handle_text(self, msg):
|
||||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
content = msg['Text']
|
||||
self._handle_single_msg(msg, content)
|
||||
|
||||
def _handle_single_msg(self, msg, content):
|
||||
from_user_id = msg['FromUserName']
|
||||
to_user_id = msg['ToUserName'] # 接收人id
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return
|
||||
if from_user_id == other_user_id and match_prefix is not None:
|
||||
# 好友向自己发送消息
|
||||
if match_prefix != '':
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
else:
|
||||
return
|
||||
context = Context()
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, from_user_id)
|
||||
else :
|
||||
thread_pool.submit(self._do_send_text, content, from_user_id)
|
||||
elif to_user_id == other_user_id and match_prefix:
|
||||
# 自己给好友发送消息
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, to_user_id)
|
||||
else:
|
||||
thread_pool.submit(self._do_send_text, content, to_user_id)
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
|
||||
context.content = content
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
def handle_group(self, msg):
|
||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
@@ -122,100 +120,128 @@ class WechatChannel(Channel):
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return ""
|
||||
config = conf()
|
||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(origin_content, config.get('group_chat_keyword'))
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
|
||||
or check_contain(origin_content, config.get('group_chat_keyword'))
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
|
||||
context = Context()
|
||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
thread_pool.submit(self._do_send_img, content, group_id)
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
thread_pool.submit(self._do_send_group, content, msg)
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
|
||||
def send(self, msg, receiver):
|
||||
itchat.send(msg, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or
|
||||
group_name in group_chat_in_one_session or
|
||||
check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = group_id
|
||||
else:
|
||||
context['session_id'] = msg['ActualUserName']
|
||||
|
||||
def _do_send_voice(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['from_user_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
if reply_text:
|
||||
replyFile = super().build_text_to_voice(reply_text)
|
||||
itchat.send_file(replyFile, toUserName=reply_user_id)
|
||||
logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
def _do_send_text(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
if reply_text:
|
||||
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _do_send_img(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(query, context)
|
||||
if not img_url:
|
||||
return
|
||||
|
||||
# 图片下载
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply : Reply, receiver):
|
||||
if reply.type == ReplyType.TEXT:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
itchat.send_file(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||
|
||||
# 图片发送
|
||||
itchat.send_image(image_storage, reply_user_id)
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def handle(self, context):
|
||||
reply = Reply()
|
||||
|
||||
def _do_send_group(self, query, msg):
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
group_name = msg['User']['NickName']
|
||||
group_id = msg['User']['UserName']
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
self.check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = group_id
|
||||
else:
|
||||
context['session_id'] = msg['ActualUserName']
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
if reply_text:
|
||||
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
|
||||
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
|
||||
# reply的构建步骤
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE:
|
||||
msg = context['msg']
|
||||
file_name = TmpDir().path() + context.content
|
||||
msg.download(file_name)
|
||||
reply = super().build_voice_to_text(file_name)
|
||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
|
||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context
|
||||
context.type = ContextType.TEXT
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
if reply.type == ReplyType.TEXT:
|
||||
if conf().get('voice_reply_voice'):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
return
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
|
||||
# reply的包装步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = str(reply.type)+":\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
return
|
||||
|
||||
# reply的发送步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
|
||||
self.send(reply, context['receiver'])
|
||||
|
||||
|
||||
def check_prefix(self, content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(self, content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
|
||||
@@ -11,6 +11,7 @@ import time
|
||||
import asyncio
|
||||
import requests
|
||||
from typing import Optional, Union
|
||||
from bridge.context import Context, ContextType
|
||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
|
||||
from wechaty import Wechaty, Contact
|
||||
from wechaty.user import Message, Room, MiniProgram, UrlLink
|
||||
@@ -127,9 +128,9 @@ class WechatyChannel(Channel):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
@@ -139,9 +140,8 @@ class WechatyChannel(Channel):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(query, context)
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片下载
|
||||
@@ -162,7 +162,7 @@ class WechatyChannel(Channel):
|
||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context = Context(ContextType.TEXT, query)
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
@@ -170,7 +170,7 @@ class WechatyChannel(Channel):
|
||||
context['session_id'] = str(group_id)
|
||||
else:
|
||||
context['session_id'] = str(group_id) + '-' + str(group_user_id)
|
||||
reply_text = super().build_reply_content(query, context)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
|
||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
|
||||
@@ -179,9 +179,8 @@ class WechatyChannel(Channel):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = dict()
|
||||
context['type'] = 'IMAGE_CREATE'
|
||||
img_url = super().build_reply_content(query, context)
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片发送
|
||||
|
||||
9 common/singleton.py Normal file
@@ -0,0 +1,9 @@
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
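A quick sketch of how this decorator behaves; the Counter class is a made-up example, not part of the commit.

from common.singleton import singleton

@singleton
class Counter:
    def __init__(self):
        self.n = 0

a = Counter()
b = Counter()          # returns the cached instance instead of building a new one
a.n += 1
assert a is b and b.n == 1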
65 common/sorted_dict.py Normal file
@@ -0,0 +1,65 @@
|
||||
import heapq
|
||||
|
||||
|
||||
class SortedDict(dict):
|
||||
def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
|
||||
if init_dict is None:
|
||||
init_dict = []
|
||||
if isinstance(init_dict, dict):
|
||||
init_dict = init_dict.items()
|
||||
self.sort_func = sort_func
|
||||
self.sorted_keys = None
|
||||
self.reverse = reverse
|
||||
self.heap = []
|
||||
for k, v in init_dict:
|
||||
self[k] = v
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key in self:
|
||||
super().__setitem__(key, value)
|
||||
for i, (priority, k) in enumerate(self.heap):
|
||||
if k == key:
|
||||
self.heap[i] = (self.sort_func(key, value), key)
|
||||
heapq.heapify(self.heap)
|
||||
break
|
||||
self.sorted_keys = None
|
||||
else:
|
||||
super().__setitem__(key, value)
|
||||
heapq.heappush(self.heap, (self.sort_func(key, value), key))
|
||||
self.sorted_keys = None
|
||||
|
||||
def __delitem__(self, key):
|
||||
super().__delitem__(key)
|
||||
for i, (priority, k) in enumerate(self.heap):
|
||||
if k == key:
|
||||
del self.heap[i]
|
||||
heapq.heapify(self.heap)
|
||||
break
|
||||
self.sorted_keys = None
|
||||
|
||||
def keys(self):
|
||||
if self.sorted_keys is None:
|
||||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
|
||||
return self.sorted_keys
|
||||
|
||||
def items(self):
|
||||
if self.sorted_keys is None:
|
||||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
|
||||
sorted_items = [(k, self[k]) for k in self.sorted_keys]
|
||||
return sorted_items
|
||||
|
||||
def _update_heap(self, key):
|
||||
for i, (priority, k) in enumerate(self.heap):
|
||||
if k == key:
|
||||
new_priority = self.sort_func(key, self[key])
|
||||
if new_priority != priority:
|
||||
self.heap[i] = (new_priority, key)
|
||||
heapq.heapify(self.heap)
|
||||
self.sorted_keys = None
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.keys())
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
|
||||
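SortedDict keeps a plain dict whose keys() and items() iterate in the order defined by sort_func, backed by a heap that is re-heapified on updates and deletions. A small usage sketch; sorting plugins by priority is only an assumption about how the new plugin manager might use it, the entries below are illustrative.

from common.sorted_dict import SortedDict

# Sort entries by descending priority (highest first).
plugins = SortedDict(sort_func=lambda name, info: info['priority'], reverse=True)
plugins['Hello'] = {'priority': -1}
plugins['Godcmd'] = {'priority': 999}
plugins['Banwords'] = {'priority': 100}

print(list(plugins.keys()))   # ['Godcmd', 'Banwords', 'Hello']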
9 plugins/__init__.py Normal file
@@ -0,0 +1,9 @@
from .plugin_manager import PluginManager
from .event import *
from .plugin import *

instance = PluginManager()

register = instance.register
# load_plugins = instance.load_plugins
# emit_event = instance.emit_event
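The package exposes register as a module-level decorator bound to the global PluginManager, which is how the plugins later in this diff (Hello, Godcmd, Banwords) declare themselves. A minimal registration sketch modelled on those plugins; the EchoDemo plugin itself is invented for illustration.

# Illustrative plugin; mirrors the Hello plugin shown later in this commit.
import plugins
from plugins import Plugin, Event, EventContext, EventAction
from bridge.reply import Reply, ReplyType

@plugins.register(name="EchoDemo", desc="Echo the incoming text", version="0.1",
                  author="example", desire_priority=0)
class EchoDemo(Plugin):
    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context

    def on_handle_context(self, e_context: EventContext):
        e_context['reply'] = Reply(ReplyType.TEXT, e_context['context'].content)
        e_context.action = EventAction.BREAK_PASS   # skip the default bot reply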
1 plugins/banwords/.gitignore vendored Normal file
@@ -0,0 +1 @@
banwords.txt
9 plugins/banwords/README.md Normal file
@@ -0,0 +1,9 @@
### Description
A simple sensitive-word plugin. Word segmentation is not supported yet; import your own word list into `banwords.txt` in the plugin folder, one word per line. A reference word list is [1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt).

`config.json` sets the default handling behavior. Currently supported actions:
- `ignore`: silently drop the message.
- `replace`: replace the sensitive words in the message with "*" and reply that the message violated the rules.

### Credits
The search implementation comes from https://github.com/toolgood/ToolGood.Words
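A standalone sketch of the search class this plugin builds on, using the two actions described above; the file path and the sample message are illustrative.

from plugins.banwords.WordsSearch import WordsSearch

search = WordsSearch()
with open("plugins/banwords/banwords.txt", encoding="utf-8") as f:
    search.SetKeywords([line.strip() for line in f if line.strip()])

text = "some incoming message"
if search.ContainsAny(text):        # "ignore" action: simply drop the message
    masked = search.Replace(text)   # "replace" action: mask hits with '*'
    print("blocked:", masked)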
250
plugins/banwords/WordsSearch.py
Normal file
250
plugins/banwords/WordsSearch.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# ToolGood.Words.WordsSearch.py
|
||||
# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words
|
||||
# Licensed under the Apache License 2.0
|
||||
# 更新日志
|
||||
# 2020.04.06 第一次提交
|
||||
# 2020.05.16 修改,支持大于0xffff的字符
|
||||
|
||||
__all__ = ['WordsSearch']
|
||||
__author__ = 'Lin Zhijun'
|
||||
__date__ = '2020.05.16'
|
||||
|
||||
class TrieNode():
|
||||
def __init__(self):
|
||||
self.Index = 0
|
||||
self.Index = 0
|
||||
self.Layer = 0
|
||||
self.End = False
|
||||
self.Char = ''
|
||||
self.Results = []
|
||||
self.m_values = {}
|
||||
self.Failure = None
|
||||
self.Parent = None
|
||||
|
||||
def Add(self,c):
|
||||
if c in self.m_values :
|
||||
return self.m_values[c]
|
||||
node = TrieNode()
|
||||
node.Parent = self
|
||||
node.Char = c
|
||||
self.m_values[c] = node
|
||||
return node
|
||||
|
||||
def SetResults(self,index):
|
||||
if (self.End == False):
|
||||
self.End = True
|
||||
self.Results.append(index)
|
||||
|
||||
class TrieNode2():
|
||||
def __init__(self):
|
||||
self.End = False
|
||||
self.Results = []
|
||||
self.m_values = {}
|
||||
self.minflag = 0xffff
|
||||
self.maxflag = 0
|
||||
|
||||
def Add(self,c,node3):
|
||||
if (self.minflag > c):
|
||||
self.minflag = c
|
||||
if (self.maxflag < c):
|
||||
self.maxflag = c
|
||||
self.m_values[c] = node3
|
||||
|
||||
def SetResults(self,index):
|
||||
if (self.End == False) :
|
||||
self.End = True
|
||||
if (index in self.Results )==False :
|
||||
self.Results.append(index)
|
||||
|
||||
def HasKey(self,c):
|
||||
return c in self.m_values
|
||||
|
||||
|
||||
def TryGetValue(self,c):
|
||||
if (self.minflag <= c and self.maxflag >= c):
|
||||
if c in self.m_values:
|
||||
return self.m_values[c]
|
||||
return None
|
||||
|
||||
|
||||
class WordsSearch():
|
||||
def __init__(self):
|
||||
self._first = {}
|
||||
self._keywords = []
|
||||
self._indexs=[]
|
||||
|
||||
def SetKeywords(self,keywords):
|
||||
self._keywords = keywords
|
||||
self._indexs=[]
|
||||
for i in range(len(keywords)):
|
||||
self._indexs.append(i)
|
||||
|
||||
root = TrieNode()
|
||||
allNodeLayer={}
|
||||
|
||||
for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++)
|
||||
p = self._keywords[i]
|
||||
nd = root
|
||||
for j in range(len(p)): # for (j = 0; j < p.length; j++)
|
||||
nd = nd.Add(ord(p[j]))
|
||||
if (nd.Layer == 0):
|
||||
nd.Layer = j + 1
|
||||
if nd.Layer in allNodeLayer:
|
||||
allNodeLayer[nd.Layer].append(nd)
|
||||
else:
|
||||
allNodeLayer[nd.Layer]=[]
|
||||
allNodeLayer[nd.Layer].append(nd)
|
||||
nd.SetResults(i)
|
||||
|
||||
|
||||
allNode = []
|
||||
allNode.append(root)
|
||||
for key in allNodeLayer.keys():
|
||||
for nd in allNodeLayer[key]:
|
||||
allNode.append(nd)
|
||||
allNodeLayer=None
|
||||
|
||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
|
||||
if i==0 :
|
||||
continue
|
||||
nd=allNode[i]
|
||||
nd.Index = i
|
||||
r = nd.Parent.Failure
|
||||
c = nd.Char
|
||||
while (r != None and (c in r.m_values)==False):
|
||||
r = r.Failure
|
||||
if (r == None):
|
||||
nd.Failure = root
|
||||
else:
|
||||
nd.Failure = r.m_values[c]
|
||||
for key2 in nd.Failure.Results :
|
||||
nd.SetResults(key2)
|
||||
root.Failure = root
|
||||
|
||||
allNode2 = []
|
||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
|
||||
allNode2.append( TrieNode2())
|
||||
|
||||
for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++)
|
||||
oldNode = allNode[i]
|
||||
newNode = allNode2[i]
|
||||
|
||||
for key in oldNode.m_values :
|
||||
index = oldNode.m_values[key].Index
|
||||
newNode.Add(key, allNode2[index])
|
||||
|
||||
for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++)
|
||||
item = oldNode.Results[index]
|
||||
newNode.SetResults(item)
|
||||
|
||||
oldNode=oldNode.Failure
|
||||
while oldNode != root:
|
||||
for key in oldNode.m_values :
|
||||
if (newNode.HasKey(key) == False):
|
||||
index = oldNode.m_values[key].Index
|
||||
newNode.Add(key, allNode2[index])
|
||||
for index in range(len(oldNode.Results)):
|
||||
item = oldNode.Results[index]
|
||||
newNode.SetResults(item)
|
||||
oldNode=oldNode.Failure
|
||||
allNode = None
|
||||
root = None
|
||||
|
||||
# first = []
|
||||
# for index in range(65535):# for (index = 0; index < 0xffff; index++)
|
||||
# first.append(None)
|
||||
|
||||
# for key in allNode2[0].m_values :
|
||||
# first[key] = allNode2[0].m_values[key]
|
||||
|
||||
self._first = allNode2[0]
|
||||
|
||||
|
||||
def FindFirst(self,text):
|
||||
ptr = None
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
item = tn.Results[0]
|
||||
keyword = self._keywords[item]
|
||||
return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }
|
||||
ptr = tn
|
||||
return None
|
||||
|
||||
def FindAll(self,text):
|
||||
ptr = None
|
||||
list = []
|
||||
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++)
|
||||
item = tn.Results[j]
|
||||
keyword = self._keywords[item]
|
||||
list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] })
|
||||
ptr = tn
|
||||
return list
|
||||
|
||||
|
||||
def ContainsAny(self,text):
|
||||
ptr = None
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
return True
|
||||
ptr = tn
|
||||
return False
|
||||
|
||||
def Replace(self,text, replaceChar = '*'):
|
||||
result = list(text)
|
||||
|
||||
ptr = None
|
||||
for i in range(len(text)): # for (i = 0; i < text.length; i++)
|
||||
t =ord(text[i]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
maxLength = len( self._keywords[tn.Results[0]])
|
||||
start = i + 1 - maxLength
|
||||
for j in range(start,i+1): # for (j = start; j <= i; j++)
|
||||
result[j] = replaceChar
|
||||
ptr = tn
|
||||
return ''.join(result)
|
||||
0
plugins/banwords/__init__.py
Normal file
0
plugins/banwords/__init__.py
Normal file
63
plugins/banwords/banwords.py
Normal file
63
plugins/banwords/banwords.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import os
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common.log import logger
|
||||
from .WordsSearch import WordsSearch
|
||||
|
||||
|
||||
@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100)
|
||||
class Banwords(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
curdir=os.path.dirname(__file__)
|
||||
config_path=os.path.join(curdir,"config.json")
|
||||
conf=None
|
||||
if not os.path.exists(config_path):
|
||||
conf={"action":"ignore"}
|
||||
with open(config_path,"w") as f:
|
||||
json.dump(conf,f,indent=4)
|
||||
else:
|
||||
with open(config_path,"r") as f:
|
||||
conf=json.load(f)
|
||||
self.searchr = WordsSearch()
|
||||
self.action = conf["action"]
|
||||
banwords_path = os.path.join(curdir,"banwords.txt")
|
||||
with open(banwords_path, 'r', encoding='utf-8') as f:
|
||||
words=[]
|
||||
for line in f:
|
||||
word = line.strip()
|
||||
if word:
|
||||
words.append(word)
|
||||
self.searchr.SetKeywords(words)
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[Banwords] inited")
|
||||
except Exception as e:
|
||||
logger.error("Banwords init failed: %s" % e)
|
||||
|
||||
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
|
||||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
|
||||
return
|
||||
|
||||
content = e_context['context'].content
|
||||
logger.debug("[Banwords] on_handle_context. content: %s" % content)
|
||||
if self.action == "ignore":
|
||||
f = self.searchr.FindFirst(content)
|
||||
if f:
|
||||
logger.info("Banwords: %s" % f["Keyword"])
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
elif self.action == "replace":
|
||||
if self.searchr.ContainsAny(content):
|
||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
3 plugins/banwords/banwords.txt.template Normal file
@@ -0,0 +1,3 @@
nipples
pennis
法轮功

3 plugins/banwords/config.json.template Normal file
@@ -0,0 +1,3 @@
{
    "action": "ignore"
}
49 plugins/event.py Normal file
@@ -0,0 +1,49 @@
# encoding:utf-8

from enum import Enum


class Event(Enum):
    # ON_RECEIVE_MESSAGE = 1 # 收到消息

    ON_HANDLE_CONTEXT = 2 # 处理消息前
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
    """

    ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
    """

    ON_SEND_REPLY = 4 # 发送回复前
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
    """

    # AFTER_SEND_REPLY = 5 # 发送回复后


class EventAction(Enum):
    CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
    BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
    BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑


class EventContext:
    def __init__(self, event, econtext=dict()):
        self.event = event
        self.econtext = econtext
        self.action = EventAction.CONTINUE

    def __getitem__(self, key):
        return self.econtext[key]

    def __setitem__(self, key, value):
        self.econtext[key] = value

    def __delitem__(self, key):
        del self.econtext[key]

    def is_pass(self):
        return self.action == EventAction.BREAK_PASS
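PluginManager.emit_event itself is not part of this excerpt; the following is a hedged sketch of how a dispatcher could interpret the EventAction values, consistent with the comments above, not the committed implementation.

# Assumed dispatcher sketch; only Event/EventAction/EventContext come from plugins/event.py.
from plugins.event import EventAction, EventContext

def emit_event(handlers, e_context: EventContext):
    for handler in handlers:                  # handlers ordered by plugin priority
        handler(e_context)
        if e_context.action == EventAction.CONTINUE:
            continue                          # let the next plugin see the event
        break                                 # BREAK / BREAK_PASS: stop the chain
    # BREAK_PASS additionally tells the caller (via is_pass()) to skip the default
    # handling, e.g. not to ask the chat bot for a reply at all.
    return e_context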
0 plugins/godcmd/__init__.py Normal file

4 plugins/godcmd/config.json.template Normal file
@@ -0,0 +1,4 @@
{
    "password": "",
    "admin_users": []
}
289
plugins/godcmd/godcmd.py
Normal file
289
plugins/godcmd/godcmd.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import load_config
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common.log import logger
|
||||
|
||||
# 定义指令集
|
||||
COMMANDS = {
|
||||
"help": {
|
||||
"alias": ["help", "帮助"],
|
||||
"desc": "打印指令集合",
|
||||
},
|
||||
"auth": {
|
||||
"alias": ["auth", "认证"],
|
||||
"args": ["口令"],
|
||||
"desc": "管理员认证",
|
||||
},
|
||||
# "id": {
|
||||
# "alias": ["id", "用户"],
|
||||
# "desc": "获取用户id", #目前无实际意义
|
||||
# },
|
||||
"reset": {
|
||||
"alias": ["reset", "重置会话"],
|
||||
"desc": "重置会话",
|
||||
},
|
||||
}
|
||||
|
||||
ADMIN_COMMANDS = {
|
||||
"resume": {
|
||||
"alias": ["resume", "恢复服务"],
|
||||
"desc": "恢复服务",
|
||||
},
|
||||
"stop": {
|
||||
"alias": ["stop", "暂停服务"],
|
||||
"desc": "暂停服务",
|
||||
},
|
||||
"reconf": {
|
||||
"alias": ["reconf", "重载配置"],
|
||||
"desc": "重载配置(不包含插件配置)",
|
||||
},
|
||||
"resetall": {
|
||||
"alias": ["resetall", "重置所有会话"],
|
||||
"desc": "重置所有会话",
|
||||
},
|
||||
"scanp": {
|
||||
"alias": ["scanp", "扫描插件"],
|
||||
"desc": "扫描插件目录是否有新插件",
|
||||
},
|
||||
"plist": {
|
||||
"alias": ["plist", "插件"],
|
||||
"desc": "打印当前插件列表",
|
||||
},
|
||||
"setpri": {
|
||||
"alias": ["setpri", "设置插件优先级"],
|
||||
"args": ["插件名", "优先级"],
|
||||
"desc": "设置指定插件的优先级,越大越优先",
|
||||
},
|
||||
"reloadp": {
|
||||
"alias": ["reloadp", "重载插件"],
|
||||
"args": ["插件名"],
|
||||
"desc": "重载指定插件配置",
|
||||
},
|
||||
"enablep": {
|
||||
"alias": ["enablep", "启用插件"],
|
||||
"args": ["插件名"],
|
||||
"desc": "启用指定插件",
|
||||
},
|
||||
"disablep": {
|
||||
"alias": ["disablep", "禁用插件"],
|
||||
"args": ["插件名"],
|
||||
"desc": "禁用指定插件",
|
||||
},
|
||||
"debug": {
|
||||
"alias": ["debug", "调试模式", "DEBUG"],
|
||||
"desc": "开启机器调试日志",
|
||||
},
|
||||
}
|
||||
# 定义帮助函数
|
||||
def get_help_text(isadmin, isgroup):
|
||||
help_text = "可用指令:\n"
|
||||
for cmd, info in COMMANDS.items():
|
||||
if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证
|
||||
continue
|
||||
|
||||
alias=["#"+a for a in info['alias']]
|
||||
help_text += f"{','.join(alias)} "
|
||||
if 'args' in info:
|
||||
args=["{"+a+"}" for a in info['args']]
|
||||
help_text += f"{' '.join(args)} "
|
||||
help_text += f": {info['desc']}\n"
|
||||
if ADMIN_COMMANDS and isadmin:
|
||||
help_text += "\n管理员指令:\n"
|
||||
for cmd, info in ADMIN_COMMANDS.items():
|
||||
alias=["#"+a for a in info['alias']]
|
||||
help_text += f"{','.join(alias)} "
|
||||
help_text += f": {info['desc']}\n"
|
||||
return help_text
|
||||
|
||||
@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999)
|
||||
class Godcmd(Plugin):

    def __init__(self):
        super().__init__()

        curdir = os.path.dirname(__file__)
        config_path = os.path.join(curdir, "config.json")
        gconf = None
        if not os.path.exists(config_path):
            gconf = {"password": "", "admin_users": []}
            with open(config_path, "w") as f:
                json.dump(gconf, f, indent=4)
        else:
            with open(config_path, "r") as f:
                gconf = json.load(f)

        self.password = gconf["password"]
        self.admin_users = gconf["admin_users"]  # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
        self.isrunning = True  # 机器人是否运行中

        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Godcmd] inited")

    def on_handle_context(self, e_context: EventContext):
        context_type = e_context['context'].type
        if context_type != ContextType.TEXT:
            if not self.isrunning:
                e_context.action = EventAction.BREAK_PASS
            return

        content = e_context['context'].content
        logger.debug("[Godcmd] on_handle_context. content: %s" % content)
        if content.startswith("#"):
            # msg = e_context['context']['msg']
            user = e_context['context']['receiver']
            session_id = e_context['context']['session_id']
            isgroup = e_context['context']['isgroup']
            bottype = Bridge().get_bot_type("chat")
            bot = Bridge().get_bot("chat")
            # 将命令和参数分割
            command_parts = content[1:].split(" ")
            cmd = command_parts[0]
            args = command_parts[1:]
            isadmin = False
            if user in self.admin_users:
                isadmin = True
            ok = False
            result = "string"
            if any(cmd in info['alias'] for info in COMMANDS.values()):
                cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
                if cmd == "auth":
                    ok, result = self.authenticate(user, args, isadmin, isgroup)
                elif cmd == "help":
                    ok, result = True, get_help_text(isadmin, isgroup)
                elif cmd == "id":
                    ok, result = True, f"用户id=\n{user}"
                elif cmd == "reset":
                    if bottype == "chatGPT":
                        bot.sessions.clear_session(session_id)
                        ok, result = True, "会话已重置"
                    else:
                        ok, result = False, "当前对话机器人不支持重置会话"
                logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
            elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
                if isadmin:
                    if isgroup:
                        ok, result = False, "群聊不可执行管理员指令"
                    else:
                        cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
                        if cmd == "stop":
                            self.isrunning = False
                            ok, result = True, "服务已暂停"
                        elif cmd == "resume":
                            self.isrunning = True
                            ok, result = True, "服务已恢复"
                        elif cmd == "reconf":
                            load_config()
                            ok, result = True, "配置已重载"
                        elif cmd == "resetall":
                            if bottype == "chatGPT":
                                bot.sessions.clear_all_session()
                                ok, result = True, "重置所有会话成功"
                            else:
                                ok, result = False, "当前对话机器人不支持重置会话"
                        elif cmd == "debug":
                            logger.setLevel('DEBUG')
                            ok, result = True, "DEBUG模式已开启"
                        elif cmd == "plist":
                            plugins = PluginManager().list_plugins()
                            ok = True
                            result = "插件列表:\n"
                            for name, plugincls in plugins.items():
                                result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
                                if plugincls.enabled:
                                    result += "已启用\n"
                                else:
                                    result += "未启用\n"
                        elif cmd == "scanp":
                            new_plugins = PluginManager().scan_plugins()
                            ok, result = True, "插件扫描完成"
                            PluginManager().activate_plugins()
                            if len(new_plugins) > 0:
                                result += "\n发现新插件:\n"
                                result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
                            else:
                                result += ", 未发现新插件"
                        elif cmd == "setpri":
                            if len(args) != 2:
                                ok, result = False, "请提供插件名和优先级"
                            else:
                                ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
                                if ok:
                                    result = "插件" + args[0] + "优先级已设置为" + args[1]
                                else:
                                    result = "插件不存在"
                        elif cmd == "reloadp":
                            if len(args) != 1:
                                ok, result = False, "请提供插件名"
                            else:
                                ok = PluginManager().reload_plugin(args[0])
                                if ok:
                                    result = "插件配置已重载"
                                else:
                                    result = "插件不存在"
                        elif cmd == "enablep":
                            if len(args) != 1:
                                ok, result = False, "请提供插件名"
                            else:
                                ok = PluginManager().enable_plugin(args[0])
                                if ok:
                                    result = "插件已启用"
                                else:
                                    result = "插件不存在"
                        elif cmd == "disablep":
                            if len(args) != 1:
                                ok, result = False, "请提供插件名"
                            else:
                                ok = PluginManager().disable_plugin(args[0])
                                if ok:
                                    result = "插件已禁用"
                                else:
                                    result = "插件不存在"

                        logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
                else:
                    ok, result = False, "需要管理员权限才能执行该指令"
            else:
                ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"

            reply = Reply()
            if ok:
                reply.type = ReplyType.INFO
            else:
                reply.type = ReplyType.ERROR
            reply.content = result
            e_context['reply'] = reply

            e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑
        elif not self.isrunning:
            e_context.action = EventAction.BREAK_PASS

    def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
        if isgroup:
            return False, "请勿在群聊中认证"

        if isadmin:
            return False, "管理员账号无需认证"

        if len(self.password) == 0:
            return False, "未设置口令,无法认证"

        if len(args) != 1:
            return False, "请提供口令"

        password = args[0]
        if password == self.password:
            self.admin_users.append(userid)
            return True, "认证成功"
        else:
            return False, "认证失败"
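The command dispatch above hinges on one pattern: each canonical command in COMMANDS / ADMIN_COMMANDS carries a list of aliases, and whatever the user typed after `#` is first mapped back to the canonical name before the `if/elif` chain runs. A minimal standalone sketch of that alias resolution, using a hypothetical two-entry table (the real COMMANDS dict is defined earlier in godcmd.py and is larger):

```python
# Standalone sketch of the alias-resolution pattern used by Godcmd.
# The command table below is hypothetical; the real COMMANDS dict in
# godcmd.py has more entries plus usage/description fields.
COMMANDS = {
    "help": {"alias": ["help", "帮助"]},
    "reset": {"alias": ["reset", "重置会话"]},
}

def resolve(cmd: str):
    """Map whatever the user typed after '#' back to the canonical command name."""
    if any(cmd in info["alias"] for info in COMMANDS.values()):
        return next(c for c, info in COMMANDS.items() if cmd in info["alias"])
    return None  # unknown command -> Godcmd replies with an ERROR Reply

if __name__ == "__main__":
    print(resolve("帮助"))   # -> "help"
    print(resolve("reset"))  # -> "reset"
    print(resolve("foo"))    # -> None
```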
0
plugins/hello/__init__.py
Normal file
46
plugins/hello/hello.py
Normal file
@@ -0,0 +1,46 @@
# encoding:utf-8

from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger


@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority=-1)
class Hello(Plugin):
    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Hello] inited")

    def on_handle_context(self, e_context: EventContext):

        if e_context['context'].type != ContextType.TEXT:
            return

        content = e_context['context'].content
        logger.debug("[Hello] on_handle_context. content: %s" % content)
        if content == "Hello":
            reply = Reply()
            reply.type = ReplyType.TEXT
            msg = e_context['context']['msg']
            if e_context['context']['isgroup']:
                reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
            else:
                reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过处理context的默认逻辑

        if content == "Hi":
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = "Hi"
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK  # 事件结束,进入默认处理逻辑,一般会覆写reply

        if content == "End":
            # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
            e_context['context'].type = "IMAGE_CREATE"
            content = "The World"
            e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
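The three branches of the Hello plugin illustrate the three EventAction outcomes: BREAK_PASS ends the event and skips the default handler, BREAK ends the event but still runs the default handler (which usually overwrites the reply), and CONTINUE hands the context on to the next plugin or the default logic. As a rough template, a new plugin only needs the register decorator, a handler registration in `__init__`, and an EventAction decision. The sketch below is a hypothetical "Echo" plugin that is not part of this commit; it assumes it lives at plugins/echo/echo.py inside this repo, where these imports exist:

```python
# encoding:utf-8
# Hypothetical example plugin, not included in this commit.
# It only runs inside the chatgpt-on-wechat repo, where these imports resolve.
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger


@plugins.register(name="Echo", desc="Echo back any text message", version="0.1", author="example", desire_priority=-10)
class Echo(Plugin):
    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Echo] inited")

    def on_handle_context(self, e_context: EventContext):
        if e_context['context'].type != ContextType.TEXT:
            return  # leave non-text contexts to other plugins / default logic
        reply = Reply(ReplyType.TEXT, e_context['context'].content)
        e_context['reply'] = reply
        e_context.action = EventAction.BREAK_PASS  # reply is final, skip default handling
```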
3
plugins/plugin.py
Normal file
@@ -0,0 +1,3 @@
class Plugin:
    def __init__(self):
        self.handlers = {}
171
plugins/plugin_manager.py
Normal file
@@ -0,0 +1,171 @@
# encoding:utf-8

import importlib
import json
import os
from common.singleton import singleton
from common.sorted_dict import SortedDict
from .event import *
from .plugin import *
from common.log import logger


@singleton
class PluginManager:
    def __init__(self):
        self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
        self.listening_plugins = {}
        self.instances = {}
        self.pconf = {}

    def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0):
        def wrapper(plugincls):
            plugincls.name = name
            plugincls.desc = desc
            plugincls.version = version
            plugincls.author = author
            plugincls.priority = desire_priority
            plugincls.enabled = True
            self.plugins[name.upper()] = plugincls
            logger.info("Plugin %s_v%s registered" % (name, version))
            return plugincls
        return wrapper

    def save_config(self):
        with open("plugins/plugins.json", "w", encoding="utf-8") as f:
            json.dump(self.pconf, f, indent=4, ensure_ascii=False)

    def load_config(self):
        logger.info("Loading plugins config...")

        modified = False
        if os.path.exists("plugins/plugins.json"):
            with open("plugins/plugins.json", "r", encoding="utf-8") as f:
                pconf = json.load(f)
                pconf['plugins'] = SortedDict(lambda k, v: v["priority"], pconf['plugins'], reverse=True)
        else:
            modified = True
            pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
        self.pconf = pconf
        if modified:
            self.save_config()
        return pconf

    def scan_plugins(self):
        logger.info("Scaning plugins ...")
        plugins_dir = "plugins"
        for plugin_name in os.listdir(plugins_dir):
            plugin_path = os.path.join(plugins_dir, plugin_name)
            if os.path.isdir(plugin_path):
                # 判断插件是否包含同名.py文件
                main_module_path = os.path.join(plugin_path, plugin_name + ".py")
                if os.path.isfile(main_module_path):
                    # 导入插件
                    import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name)
                    main_module = importlib.import_module(import_path)
        pconf = self.pconf
        new_plugins = []
        modified = False
        for name, plugincls in self.plugins.items():
            rawname = plugincls.name
            if rawname not in pconf["plugins"]:
                new_plugins.append(plugincls)
                modified = True
                logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
                pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
            else:
                self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
                self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
                self.plugins._update_heap(name)  # 更新下plugins中的顺序
        if modified:
            self.save_config()
        return new_plugins

    def refresh_order(self):
        for event in self.listening_plugins.keys():
            self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)

    def activate_plugins(self):  # 生成新开启的插件实例
        for name, plugincls in self.plugins.items():
            if plugincls.enabled:
                if name not in self.instances:
                    instance = plugincls()
                    self.instances[name] = instance
                    for event in instance.handlers:
                        if event not in self.listening_plugins:
                            self.listening_plugins[event] = []
                        self.listening_plugins[event].append(name)
        self.refresh_order()

    def reload_plugin(self, name: str):
        name = name.upper()
        if name in self.instances:
            for event in self.listening_plugins:
                if name in self.listening_plugins[event]:
                    self.listening_plugins[event].remove(name)
            del self.instances[name]
            self.activate_plugins()
            return True
        return False

    def load_plugins(self):
        self.load_config()
        self.scan_plugins()
        pconf = self.pconf
        logger.debug("plugins.json config={}".format(pconf))
        for name, plugin in pconf["plugins"].items():
            if name.upper() not in self.plugins:
                logger.error("Plugin %s not found, but found in plugins.json" % name)
        self.activate_plugins()

    def emit_event(self, e_context: EventContext, *args, **kwargs):
        if e_context.event in self.listening_plugins:
            for name in self.listening_plugins[e_context.event]:
                if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
                    logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
                    instance = self.instances[name]
                    instance.handlers[e_context.event](e_context, *args, **kwargs)
        return e_context

    def set_plugin_priority(self, name: str, priority: int):
        name = name.upper()
        if name not in self.plugins:
            return False
        if self.plugins[name].priority == priority:
            return True
        self.plugins[name].priority = priority
        self.plugins._update_heap(name)
        rawname = self.plugins[name].name
        self.pconf["plugins"][rawname]["priority"] = priority
        self.pconf["plugins"]._update_heap(rawname)
        self.save_config()
        self.refresh_order()
        return True

    def enable_plugin(self, name: str):
        name = name.upper()
        if name not in self.plugins:
            return False
        if not self.plugins[name].enabled:
            self.plugins[name].enabled = True
            rawname = self.plugins[name].name
            self.pconf["plugins"][rawname]["enabled"] = True
            self.save_config()
            self.activate_plugins()
            return True
        return True

    def disable_plugin(self, name: str):
        name = name.upper()
        if name not in self.plugins:
            return False
        if self.plugins[name].enabled:
            self.plugins[name].enabled = False
            rawname = self.plugins[name].name
            self.pconf["plugins"][rawname]["enabled"] = False
            self.save_config()
            return True
        return True

    def list_plugins(self):
        return self.plugins
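The flow through PluginManager is: load_config reads plugins/plugins.json, scan_plugins imports plugins/&lt;name&gt;/&lt;name&gt;.py (which fires the register decorator), activate_plugins instantiates enabled plugins and builds listening_plugins per event, and emit_event walks the listeners in priority order for as long as the action stays CONTINUE. A standalone sketch of that dispatch model, with made-up handlers and no repo imports, just to show the priority ordering and the CONTINUE gate:

```python
# Standalone sketch (plain Python, no repo imports) of the dispatch model
# implemented by PluginManager.emit_event: handlers listening on the same
# event run in descending priority order, each only while action is CONTINUE.
CONTINUE, BREAK, BREAK_PASS = range(3)

class FakeEventContext:
    def __init__(self, data):
        self.action = CONTINUE
        self.data = data

def emit(handlers, e_context):
    # handlers: list of (priority, fn); higher priority runs first
    for _, fn in sorted(handlers, key=lambda h: h[0], reverse=True):
        if e_context.action != CONTINUE:
            break
        fn(e_context)
    return e_context

def banwords(e):   # imagined high-priority filter plugin
    if "forbidden" in e.data["content"]:
        e.data["reply"] = "blocked"
        e.action = BREAK_PASS

def hello(e):      # imagined low-priority responder plugin
    e.data["reply"] = "hello, " + e.data["content"]

ctx = emit([(100, banwords), (-1, hello)], FakeEventContext({"content": "world"}))
print(ctx.data["reply"])  # -> "hello, world"
```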
0
plugins/sdwebui/__init__.py
Normal file
70
plugins/sdwebui/config.json.template
Normal file
@@ -0,0 +1,70 @@
{
    "start": {
        "host": "127.0.0.1",
        "port": 7860
    },
    "defaults": {
        "params": {
            "sampler_name": "DPM++ 2M Karras",
            "steps": 20,
            "width": 512,
            "height": 512,
            "cfg_scale": 7,
            "prompt": "masterpiece, best quality",
            "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
            "enable_hr": false,
            "hr_scale": 2,
            "hr_upscaler": "Latent",
            "hr_second_pass_steps": 15,
            "denoising_strength": 0.7
        },
        "options": {
            "sd_model_checkpoint": "perfectWorld_v2Baked"
        }
    },
    "rules": [
        {
            "keywords": [
                "横版",
                "壁纸"
            ],
            "params": {
                "width": 640,
                "height": 384
            },
            "desc": "分辨率会变成640x384"
        },
        {
            "keywords": [
                "竖版"
            ],
            "params": {
                "width": 384,
                "height": 640
            }
        },
        {
            "keywords": [
                "高清"
            ],
            "params": {
                "enable_hr": true,
                "hr_scale": 1.6
            },
            "desc": "出图分辨率长宽都会提高1.6倍"
        },
        {
            "keywords": [
                "二次元"
            ],
            "params": {
                "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
                "prompt": "masterpiece, best quality"
            },
            "options": {
                "sd_model_checkpoint": "meinamix_meinaV8"
            },
            "desc": "使用二次元风格模型出图"
        }
    ]
}
69
plugins/sdwebui/readme.md
Normal file
@@ -0,0 +1,69 @@
### 插件描述
本插件用于将画图请求转发给stable diffusion webui。

### 环境要求
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。

请**安装**本插件的依赖包`webuiapi`:
```
pip install webuiapi
```
### 使用说明
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。

#### 画图请求格式
用户的画图请求格式为:
```
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt>
```
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次;如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准。
- 关键词中包含`help`或`帮助`时,会打印出帮助文档。

第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后。

例如: 画横版 高清 二次元:cat

会触发三个关键词 "横版", "高清", "二次元",prompt为"cat"。

若默认参数是:
```
    "width": 512,
    "height": 512,
    "enable_hr": false,
    "prompt": "8k",
    "negative_prompt": "nsfw",
    "sd_model_checkpoint": "perfectWorld_v2Baked"
```

"横版"触发的规则参数为:
```
    "width": 640,
    "height": 384,
```
"高清"触发的规则参数为:
```
    "enable_hr": true,
    "hr_scale": 1.6,
```
"二次元"触发的规则参数为:
```
    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
    "steps": 20,
    "prompt": "masterpiece, best quality",
    "sd_model_checkpoint": "meinamix_meinaV8"
```
最后将第一个":"后的内容cat连接在prompt后,得到最终参数为:
```
    "width": 640,
    "height": 384,
    "enable_hr": true,
    "hr_scale": 1.6,
    "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
    "steps": 20,
    "prompt": "masterpiece, best quality, cat",
    "sd_model_checkpoint": "meinamix_meinaV8"
```
PS: 参数分为两部分:
- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。
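The worked example in this readme boils down to ordinary dict merging: the defaults are overlaid with each matched rule's params (later keywords win on conflicts), and the text after the first ":" is appended to the prompt. A small standalone Python sketch of that merge, using a hypothetical reduced rule set rather than the repo's actual code:

```python
# Standalone sketch of the keyword -> parameter merge described above.
# The defaults and rules here are a reduced, hypothetical subset of config.json.
defaults = {"width": 512, "height": 512, "enable_hr": False, "prompt": "8k"}
rules = [
    {"keywords": ["横版"], "params": {"width": 640, "height": 384}},
    {"keywords": ["高清"], "params": {"enable_hr": True, "hr_scale": 1.6}},
]

def build_params(query: str) -> dict:
    keywords, _, prompt = query.partition(":")
    merged = {}
    for kw in keywords.split():
        for rule in rules:
            if kw in rule["keywords"]:
                merged.update(rule["params"])  # later keywords overwrite earlier ones
                break  # each keyword matches at most one rule
    params = {**defaults, **merged}
    if prompt:
        params["prompt"] = params.get("prompt", "") + f", {prompt}"
    return params

print(build_params("横版 高清:cat"))
# -> {'width': 640, 'height': 384, 'enable_hr': True, 'prompt': '8k, cat', 'hr_scale': 1.6}
```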
114
plugins/sdwebui/sdwebui.py
Normal file
@@ -0,0 +1,114 @@
# encoding:utf-8

import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf
import plugins
from plugins import *
from common.log import logger
import webuiapi
import io


@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
class SDWebUI(Plugin):
    def __init__(self):
        super().__init__()
        curdir = os.path.dirname(__file__)
        config_path = os.path.join(curdir, "config.json")
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
                self.rules = config["rules"]
                defaults = config["defaults"]
                self.default_params = defaults["params"]
                self.default_options = defaults["options"]
                self.start_args = config["start"]
                self.api = webuiapi.WebUIApi(**self.start_args)
                self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
                logger.info("[SD] inited")
        except FileNotFoundError:
            logger.error(f"[SD] init failed, {config_path} not found")
        except Exception as e:
            logger.error("[SD] init failed, exception: %s" % e)

    def on_handle_context(self, e_context: EventContext):

        if e_context['context'].type != ContextType.IMAGE_CREATE:
            return

        logger.debug("[SD] on_handle_context. content: %s" % e_context['context'].content)

        logger.info("[SD] image_query={}".format(e_context['context'].content))
        reply = Reply()
        try:
            content = e_context['context'].content[:]
            # 解析用户输入 如"横版 高清 二次元:cat"
            if ":" in content:
                keywords, prompt = content.split(":", 1)
            else:
                keywords = content
                prompt = ""

            keywords = keywords.split()

            if "help" in keywords or "帮助" in keywords:
                reply.type = ReplyType.INFO
                reply.content = self.get_help_text()
            else:
                rule_params = {}
                rule_options = {}
                for keyword in keywords:
                    matched = False
                    for rule in self.rules:
                        if keyword in rule["keywords"]:
                            for key in rule["params"]:
                                rule_params[key] = rule["params"][key]
                            if "options" in rule:
                                for key in rule["options"]:
                                    rule_options[key] = rule["options"][key]
                            matched = True
                            break  # 一个关键词只匹配一个规则
                    if not matched:
                        logger.warning("[SD] keyword not matched: %s" % keyword)

                params = {**self.default_params, **rule_params}
                options = {**self.default_options, **rule_options}
                params["prompt"] = params.get("prompt", "") + f", {prompt}"
                if len(options) > 0:
                    logger.info("[SD] cover options={}".format(options))
                    self.api.set_options(options)
                logger.info("[SD] params={}".format(params))
                result = self.api.txt2img(
                    **params
                )
                reply.type = ReplyType.IMAGE
                b_img = io.BytesIO()
                result.image.save(b_img, format="PNG")
                reply.content = b_img
            e_context.action = EventAction.BREAK_PASS  # 事件结束后,跳过处理context的默认逻辑
        except Exception as e:
            reply.type = ReplyType.ERROR
            reply.content = "[SD] " + str(e)
            logger.error("[SD] exception: %s" % e)
            e_context.action = EventAction.CONTINUE  # 事件继续,交付给下个插件或默认逻辑
        finally:
            e_context['reply'] = reply

    def get_help_text(self):
        if not conf().get('image_create_prefix'):
            return "画图功能未启用"
        else:
            trigger = conf()['image_create_prefix'][0]
        help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n"
        help_text += "目前可用关键词:\n"
        for rule in self.rules:
            keywords = [f"[{keyword}]" for keyword in rule['keywords']]
            help_text += f"{','.join(keywords)}"
            if "desc" in rule:
                help_text += f"-{rule['desc']}\n"
            else:
                help_text += "\n"
        return help_text
@@ -4,6 +4,7 @@ baidu voice service
"""
import time
from aip import AipSpeech
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
@@ -30,7 +31,8 @@ class BaiduVoice(Voice):
            with open(fileName, 'wb') as f:
                f.write(result)
            logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
            return fileName
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Baidu] textToVoice error={}'.format(result))
            return None
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
@@ -6,6 +6,7 @@ google voice service
import pathlib
import subprocess
import time
from bridge.reply import Reply, ReplyType
import speech_recognition
import pyttsx3
from common.log import logger
@@ -36,16 +37,22 @@ class GoogleVoice(Voice):
            text = self.recognizer.recognize_google(audio, language='zh-CN')
            logger.info(
                '[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
            return text
            reply = Reply(ReplyType.TEXT, text)
        except speech_recognition.UnknownValueError:
            return "抱歉,我听不懂。"
            reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
        except speech_recognition.RequestError as e:
            return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)
            reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
        finally:
            return reply

    def textToVoice(self, text):
        textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
        self.engine.save_to_file(text, textFile)
        self.engine.runAndWait()
        logger.info(
            '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
        return textFile
        try:
            textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
            self.engine.save_to_file(text, textFile)
            self.engine.runAndWait()
            logger.info(
                '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
            reply = Reply(ReplyType.VOICE, textFile)
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
        finally:
            return reply
@@ -4,6 +4,7 @@ google voice service
"""
import json
import openai
from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger
from voice.voice import Voice
@@ -16,12 +17,17 @@ class OpenaiVoice(Voice):
    def voiceToText(self, voice_file):
        logger.debug(
            '[Openai] voice file name={}'.format(voice_file))
        file = open(voice_file, "rb")
        reply = openai.Audio.transcribe("whisper-1", file)
        text = reply["text"]
        logger.info(
            '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
        return text
        try:
            file = open(voice_file, "rb")
            result = openai.Audio.transcribe("whisper-1", file)
            text = result["text"]
            reply = Reply(ReplyType.TEXT, text)
            logger.info(
                '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
        finally:
            return reply

    def textToVoice(self, text):
        pass
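The three voice diffs above all converge on the same contract: instead of returning raw strings, file paths, or None, every code path now returns a Reply whose ReplyType tells the channel how to render the result (TEXT, VOICE, or ERROR). A hypothetical minimal Voice implementation following that contract is sketched below; it is not part of this commit and assumes it lives inside this repo, where the Voice base class and Reply exist:

```python
# Hypothetical minimal Voice engine (not part of this commit), shown only to
# illustrate the Reply-based contract the voice modules are converted to.
from bridge.reply import Reply, ReplyType
from voice.voice import Voice


class DummyVoice(Voice):
    def voiceToText(self, voice_file):
        try:
            text = "transcript of " + voice_file  # a real engine calls an ASR API here
            return Reply(ReplyType.TEXT, text)
        except Exception as e:
            return Reply(ReplyType.ERROR, str(e))

    def textToVoice(self, text):
        # A real engine would synthesize audio and return the file path as a VOICE reply.
        return Reply(ReplyType.ERROR, "textToVoice not implemented")
```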