mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-20 09:40:36 +08:00
feat: add linkai plugin to support midjourney and distinguish app between groups
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -24,4 +24,5 @@ plugins/**/
|
||||
!plugins/banwords/**/
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
@@ -52,7 +52,7 @@ class LinkAIBot(Bot, OpenAIImage):
|
||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
|
||||
app_code = None
|
||||
else:
|
||||
app_code = conf().get("linkai_app_code")
|
||||
app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
|
||||
1
plugins/linkai/__init__.py
Normal file
1
plugins/linkai/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .linkai import *
|
||||
93
plugins/linkai/linkai.py
Normal file
93
plugins/linkai/linkai.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from plugins import *
|
||||
from .midjourney import MJBot, TaskType
|
||||
|
||||
# 任务线程池
|
||||
task_thread_pool = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="linkai",
|
||||
desc="A plugin that supports knowledge base and midjourney drawing.",
|
||||
version="0.1.0",
|
||||
author="https://link-ai.tech",
|
||||
)
|
||||
class LinkAI(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
self.config = super().load_config()
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
||||
logger.info("[LinkAI] inited")
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""
|
||||
消息处理逻辑
|
||||
:param e_context: 消息上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE]:
|
||||
# filter content no need solve
|
||||
return
|
||||
|
||||
mj_type = self.mj_bot.judge_mj_task_type(e_context)
|
||||
if mj_type:
|
||||
# MJ作图任务处理
|
||||
self.mj_bot.process_mj_task(mj_type, e_context)
|
||||
return
|
||||
|
||||
if self._is_chat_task(e_context):
|
||||
self._process_chat_task(e_context)
|
||||
|
||||
# LinkAI 对话任务处理
|
||||
def _is_chat_task(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
return self.config.get("knowledge_base") and context.kwargs.get("isgroup")
|
||||
|
||||
def _process_chat_task(self, e_context: EventContext):
|
||||
"""
|
||||
处理LinkAI对话任务
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
context.kwargs['app_code'] = app_code
|
||||
|
||||
def _fetch_group_app_code(self, group_name: str) -> str:
|
||||
"""
|
||||
根据群聊名称获取对应的应用code
|
||||
:param group_name: 群聊名称
|
||||
:return: 应用code
|
||||
"""
|
||||
knowledge_base_config = self.config.get("knowledge_base")
|
||||
if knowledge_base_config and knowledge_base_config.get("group_mapping"):
|
||||
app_code = knowledge_base_config.get("group_mapping").get(group_name) \
|
||||
or knowledge_base_config.get("group_mapping").get("ALL_GROUP")
|
||||
return app_code
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = "利用midjourney来画图。\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
|
||||
return help_text
|
||||
|
||||
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
256
plugins/linkai/midjourney.py
Normal file
256
plugins/linkai/midjourney.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from enum import Enum
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import requests
|
||||
import threading
|
||||
import time
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from bridge.context import ContextType
|
||||
from plugins import EventContext, EventAction
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
GENERATE = "generate"
|
||||
UPSCALE = "upscale"
|
||||
VARIATION = "variation"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
FINISHED = "finished"
|
||||
EXPIRED = "expired"
|
||||
ABORTED = "aborted"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MJTask:
|
||||
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
|
||||
self.id = id
|
||||
self.user_id = user_id
|
||||
self.task_type = task_type
|
||||
self.raw_prompt = raw_prompt
|
||||
self.send_func = None # send_func(img_url)
|
||||
self.expiry_time = time.time() + expires
|
||||
self.status = status
|
||||
self.img_url = None # url
|
||||
self.img_id = None
|
||||
|
||||
def __str__(self):
|
||||
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
|
||||
|
||||
# midjourney bot
|
||||
class MJBot:
|
||||
def __init__(self, config):
|
||||
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
|
||||
# self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
|
||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
self.config = config
|
||||
self.tasks = {}
|
||||
self.temp_dict = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
|
||||
|
||||
def judge_mj_task_type(self, e_context: EventContext) -> TaskType:
|
||||
"""
|
||||
判断MJ任务的类型
|
||||
:param e_context: 上下文
|
||||
:return: 任务类型枚举
|
||||
"""
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
context = e_context['context']
|
||||
if context.type == ContextType.TEXT:
|
||||
if self.config and self.config.get("enabled"):
|
||||
cmd_list = context.content.split(maxsplit=1)
|
||||
if cmd_list[0].lower() == f"{trigger_prefix}mj":
|
||||
return TaskType.GENERATE
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
||||
return TaskType.UPSCALE
|
||||
# elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
|
||||
# return TaskType.VARIATION
|
||||
# elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||
# return TaskType.RESET
|
||||
|
||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||
"""
|
||||
处理mj任务
|
||||
:param mj_type: mj任务类型
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
session_id = context["session_id"]
|
||||
cmd = context.content.split(maxsplit=1)
|
||||
if len(cmd) == 1:
|
||||
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
|
||||
if mj_type == TaskType.GENERATE:
|
||||
# 图片生成
|
||||
raw_prompt = cmd[1]
|
||||
reply = self.generate(raw_prompt, session_id, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
elif mj_type == TaskType.UPSCALE:
|
||||
# 图片放大
|
||||
clist = cmd[1].split()
|
||||
if len(clist) < 2:
|
||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
|
||||
return
|
||||
img_id = clist[0]
|
||||
index = int(clist[1])
|
||||
if index < 1 or index > 4:
|
||||
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
|
||||
return
|
||||
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
|
||||
if self.temp_dict.get(key):
|
||||
self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context)
|
||||
return
|
||||
# 图片放大操作
|
||||
reply = self.upscale(session_id, img_id, index, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
else:
|
||||
self._set_reply_text(f"暂不支持该命令", e_context)
|
||||
|
||||
def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
|
||||
"""
|
||||
图片生成
|
||||
:param prompt: 提示词
|
||||
:param user_id: 用户id
|
||||
:return: 任务ID
|
||||
"""
|
||||
logger.info(f"[MJ] image generate, prompt={prompt}")
|
||||
body = {"prompt": prompt}
|
||||
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[MJ] image generate, res={res}")
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("taskId")
|
||||
real_prompt = res.get("data").get("realPrompt")
|
||||
content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
|
||||
if real_prompt:
|
||||
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
||||
else:
|
||||
content += f"prompt: {prompt}"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
return reply
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
|
||||
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
|
||||
body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
|
||||
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers)
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.info(res)
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("taskId")
|
||||
content = f"🔎图片正在放大中,请耐心等待"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
|
||||
self.temp_dict[key] = True
|
||||
asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
return reply
|
||||
else:
|
||||
error_msg = ""
|
||||
if res.status_code == 461:
|
||||
error_msg = "请输入正确的图片ID"
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
async def check_task(self, task: MJTask, e_context: EventContext):
|
||||
max_retry_time = 80
|
||||
while max_retry_time > 0:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = f"{self.base_url}/tasks/{task.id}"
|
||||
async with session.get(url, headers=self.headers) as res:
|
||||
if res.status == 200:
|
||||
res_json = await res.json()
|
||||
logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, "
|
||||
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
||||
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
||||
# process success res
|
||||
self._process_success_task(task, res_json.get("data"), e_context)
|
||||
return
|
||||
else:
|
||||
logger.warn(f"[MJ] image check error, status_code={res.status}")
|
||||
max_retry_time -= 20
|
||||
await asyncio.sleep(10)
|
||||
max_retry_time -= 1
|
||||
logger.warn("[MJ] end from poll")
|
||||
|
||||
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
||||
"""
|
||||
处理任务成功的结果
|
||||
:param task: MJ任务
|
||||
:param res: 请求结果
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
# channel send img
|
||||
task.status = Status.FINISHED
|
||||
task.img_id = res.get("imgId")
|
||||
task.img_url = res.get("imgUrl")
|
||||
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
|
||||
|
||||
# send img
|
||||
reply = Reply(ReplyType.IMAGE_URL, task.img_url)
|
||||
channel = e_context["channel"]
|
||||
channel._send(reply, e_context["context"])
|
||||
|
||||
# send info
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
text = ""
|
||||
if task.task_type == TaskType.GENERATE:
|
||||
text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}"
|
||||
text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
|
||||
reply = Reply(ReplyType.INFO, text)
|
||||
channel._send(reply, e_context["context"])
|
||||
|
||||
self._print_tasks()
|
||||
return
|
||||
|
||||
def _run_loop(self, loop: asyncio.BaseEventLoop):
|
||||
loop.run_forever()
|
||||
loop.stop()
|
||||
|
||||
def _print_tasks(self):
|
||||
for id in self.tasks:
|
||||
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
||||
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = "利用midjourney来画图。\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
|
||||
return help_text
|
||||
|
||||
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
Reference in New Issue
Block a user