From 28e82bc693c998d9b03945422b9e369b0012ec9f Mon Sep 17 00:00:00 2001 From: tercel Date: Mon, 10 Apr 2023 15:23:27 +0800 Subject: [PATCH 1/2] 1.added discord channel 2.added logger and common.certificate_file in config-template.json --- app.py | 1 + channel/channel_factory.py | 4 + channel/discord/discord_channel.py | 167 +++++++++++++++++++++++++++++ common/const.py | 1 + common/log.py | 10 +- config-template.json | 12 ++- requirements.txt | 3 +- 7 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 channel/discord/discord_channel.py diff --git a/app.py b/app.py index 3502d86..5e61069 100644 --- a/app.py +++ b/app.py @@ -18,6 +18,7 @@ def start_process(channel_type, config_path): channel.startup() except Exception as e: log.error("[MultiChannel] Start up failed on {}: {}", channel_type, str(e)) + raise e def main(): diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 1346499..8ff2740 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -53,5 +53,9 @@ def create_channel(channel_type): from channel.feishu.feishu_channel import FeiShuChannel return FeiShuChannel() + elif channel_type == const.DISCORD: + from channel.discord.discord_channel import DiscordChannel + return DiscordChannel() + else: raise RuntimeError("unknown channel_type in config.json: " + channel_type) diff --git a/channel/discord/discord_channel.py b/channel/discord/discord_channel.py new file mode 100644 index 0000000..d23d31f --- /dev/null +++ b/channel/discord/discord_channel.py @@ -0,0 +1,167 @@ +# encoding:utf-8 + +""" +discord channel +Python discord - https://github.com/Rapptz/discord.py.git +""" +from channel.channel import Channel +from common.log import logger +from config import conf, common_conf_val, channel_conf +import ssl +import discord +from discord.ext import commands + +class DiscordChannel(Channel): + + def __init__(self): + config = conf() + + self.token = channel_conf('discord').get('app_token') + self.discord_channel_name = channel_conf('discord').get('channel_name') + self.discord_channel_session = channel_conf('discord').get('channel_session', 'author') + self.voice_enabled = channel_conf('discord').get('voice_enabled', False) + self.cmd_clear_session = common_conf_val('clear_memory_commands', ['#清除记忆'])[0] + self.sessions = [] + self.intents = discord.Intents.default() + self.intents.message_content = True + self.intents.guilds = True + self.intents.members = True + self.intents.messages = True + self.intents.voice_states = True + + context = ssl.create_default_context() + context.load_verify_locations(common_conf_val('certificate_file')) + self.bot = commands.Bot(command_prefix='!', intents=self.intents, ssl=context) + self.bot.add_listener(self.on_ready) + + logger.debug('cmd_clear_session %s', self.cmd_clear_session) + + def startup(self): + self.bot.add_listener(self.on_message) + self.bot.add_listener(self.on_guild_channel_delete) + self.bot.add_listener(self.on_guild_channel_create) + self.bot.add_listener(self.on_private_channel_delete) + self.bot.add_listener(self.on_private_channel_create) + self.bot.add_listener(self.on_channel_delete) + self.bot.add_listener(self.on_channel_create) + self.bot.add_listener(self.on_thread_delete) + self.bot.add_listener(self.on_thread_create) + self.bot.run(self.token) + + async def on_ready(self): + logger.info('Bot is online user:{}'.format(self.bot.user)) + if self.voice_enabled == False: + logger.debug('disable music') + await self.bot.remove_cog("Music") + + async def join(self, ctx): + logger.debug('join %s', repr(ctx)) + channel = ctx.author.voice.channel + await channel.connect() + + async def _do_on_channel_delete(self, channel): + if not self.discord_channel_name or channel.name != self.discord_channel_name: + logger.debug('skip _do_on_channel_delete %s', channel.name) + return + + for name in self.sessions: + try: + response = self.send_text(name, self.cmd_clear_session) + logger.debug('_do_on_channel_delete %s %s', channel.name, response) + except Exception as e: + logger.warn('clear session except, id:%s', name) + + self.sessions.clear() + + async def on_guild_channel_delete(self, channel): + logger.debug('on_guild_channel_delete %s', repr(channel)) + await self._do_on_channel_delete(channel) + + async def on_guild_channel_create(self, channel): + logger.debug('on_guild_channel_create %s', repr(channel)) + + async def on_private_channel_delete(self, channel): + logger.debug('on_channel_delete %s', repr(channel)) + await self._do_on_channel_delete(channel) + + async def on_private_channel_create(self, channel): + logger.debug('on_channel_create %s', repr(channel)) + + async def on_channel_delete(self, channel): + logger.debug('on_channel_delete %s', repr(channel)) + + async def on_channel_create(self, channel): + logger.debug('on_channel_create %s', repr(channel)) + + async def on_thread_delete(self, thread): + print('on_thread_delete', thread) + if self.discord_channel_session != 'thread' or thread.parent.name != self.discord_channel_name: + logger.debug('skip on_thread_delete %s', thread.id) + return + + try: + response = self.send_text(thread.id, self.cmd_clear_session) + if thread.id in self.sessions: + self.sessions.remove(thread.id) + logger.debug('on_thread_delete %s %s', thread.id, response) + except Exception as e: + logger.warn('on_thread_delete except %s', thread.id) + raise e + + + async def on_thread_create(self, thread): + logger.debug('on_thread_create %s', thread.id) + if self.discord_channel_session != 'thread' or thread.parent.name != self.discord_channel_name: + logger.debug('skip on_channel_create %s', repr(thread)) + return + + self.sessions.append(thread.id) + + async def on_message(self, message): + """ + listen for message event + """ + await self.bot.wait_until_ready() + if not self.check_message(message): + return + + prompt = message.content.strip(); + logger.debug('author: %s', message.author) + logger.debug('prompt: %s', prompt) + + session_id = message.author + if self.discord_channel_session == 'thread' and isinstance(message.channel, discord.Thread): + logger.debug('on_message thread id %s', message.channel.id) + session_id = message.channel.id + + await message.channel.send('...') + response = response = self.send_text(session_id, prompt) + await message.channel.send(response) + + + def check_message(self, message): + if message.author == self.bot.user: + return False + + prompt = message.content.strip(); + if not prompt: + logger.debug('no prompt author: %s', message.author) + return False + + if self.discord_channel_name: + if isinstance(message.channel, discord.Thread) and message.channel.parent.name == self.discord_channel_name: + return True + if not isinstance(message.channel, discord.Thread) and self.discord_channel_session != 'thread' and message.channel.name == self.discord_channel_name: + return True + + logger.debug("The accessed channel does not meet the discord channel configuration conditions.") + return False + else: + return True + + def send_text(self, id, content): + context = dict() + context['type'] = 'TEXT' + context['from_user_id'] = id + context['content'] = content + return super().build_reply_content(content, context) \ No newline at end of file diff --git a/common/const.py b/common/const.py index bd28a5c..0529b48 100644 --- a/common/const.py +++ b/common/const.py @@ -10,6 +10,7 @@ SLACK = "slack" HTTP = "http" DINGTALK = "dingtalk" FEISHU = "feishu" +DISCORD = "discord" # model OPEN_AI = "openai" diff --git a/common/log.py b/common/log.py index 2510bcb..2d375c0 100644 --- a/common/log.py +++ b/common/log.py @@ -2,12 +2,18 @@ import logging import sys +import config + -SWITCH = True def _get_logger(): + global SWITCH + config.load_config() + SWITCH = config.conf().get("logger").get("switch", True) + log = logging.getLogger('log') - log.setLevel(logging.INFO) + level = config.conf().get("logger").get("level", logging.INFO) + log.setLevel(level) console_handle = logging.StreamHandler(sys.stdout) console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')) diff --git a/config-template.json b/config-template.json index 8b61dc2..79fa1db 100644 --- a/config-template.json +++ b/config-template.json @@ -1,4 +1,8 @@ { + "logger": { + "switch": true, + "level": "INFO" + }, "model": { "type" : "chatgpt", "openai": { @@ -85,9 +89,15 @@ "app_id": "xxx", "app_secret": "xxx", "verification_token": "xxx" + }, + "discord": { + "app_token": "xxx", + "channel_name": "xxx", + "channel_session": "xxx" } }, "common": { - "clear_memory_commands": ["#清除记忆"] + "clear_memory_commands": ["#清除记忆"], + "certificate_file": "xxx" } } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1a46046..f4bb97f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ flask_socketio itchat-uos==1.5.0.dev0 openai EdgeGPT -requests \ No newline at end of file +requests +discord.py>=2.0.0 \ No newline at end of file From 533fce696098e0d749001a87121f57aff3d8d7d7 Mon Sep 17 00:00:00 2001 From: tercel Date: Mon, 10 Apr 2023 15:26:27 +0800 Subject: [PATCH 2/2] minor changes --- common/log.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/log.py b/common/log.py index 2d375c0..db1f2dd 100644 --- a/common/log.py +++ b/common/log.py @@ -4,13 +4,13 @@ import logging import sys import config - +SWITCH = True def _get_logger(): global SWITCH config.load_config() SWITCH = config.conf().get("logger").get("switch", True) - + log = logging.getLogger('log') level = config.conf().get("logger").get("level", logging.INFO) log.setLevel(level)