plugins: add sdwebui(stable diffusion) plugin

This commit is contained in:
lanvent
2023-03-14 00:49:28 +08:00
parent dce9c4dccb
commit e6d148e729
4 changed files with 218 additions and 0 deletions

View File

@@ -0,0 +1,88 @@
# encoding:utf-8
import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger
import webuiapi
import io
@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
class SDWebUI(Plugin):
def __init__(self):
super().__init__()
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
self.rules = config["rules"]
defaults = config["defaults"]
self.default_params = defaults["params"]
self.default_options = defaults["options"]
self.start_args = config["start"]
self.api = webuiapi.WebUIApi(**self.start_args)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[SD] inited")
except FileNotFoundError:
logger.error(f"[SD] init failed, {config_path} not found")
except Exception as e:
logger.error("[SD] init failed, exception: %s" % e)
def on_handle_context(self, e_context: EventContext):
if e_context['context'].type != ContextType.IMAGE_CREATE:
return
logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
logger.info("[SD] image_query={}".format(e_context['context'].content))
reply = Reply()
try:
content = e_context['context'].content[:]
# 解析用户输入 如"横版 高清 二次元:cat"
keywords, prompt = content.split(":", 1)
keywords = keywords.split()
rule_params = {}
rule_options = {}
for keyword in keywords:
matched = False
for rule in self.rules:
if keyword in rule["keywords"]:
for key in rule["params"]:
rule_params[key] = rule["params"][key]
if "options" in rule:
for key in rule["options"]:
rule_options[key] = rule["options"][key]
matched = True
break # 一个关键词只匹配一个规则
if not matched:
logger.warning("[SD] keyword not matched: %s" % keyword)
params = {**self.default_params, **rule_params}
options = {**self.default_options, **rule_options}
params["prompt"] = params.get("prompt", "")+f", {prompt}"
if len(options) > 0:
logger.info("[SD] cover rule_options={}".format(rule_options))
self.api.set_options(options)
logger.info("[SD] params={}".format(params))
result = self.api.txt2img(
**params
)
reply.type = ReplyType.IMAGE
b_img = io.BytesIO()
result.image.save(b_img, format="PNG")
reply.content = b_img
e_context.action = EventAction.BREAK_PASS # 事件结束后不跳过处理context的默认逻辑
except Exception as e:
reply.type = ReplyType.ERROR
reply.content = "[SD] "+str(e)
logger.error("[SD] exception: %s" % e)
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
finally:
e_context['reply'] = reply