Files
chatgpt-on-wechat/channel/web/web_channel.py
2024-11-27 08:53:13 +08:00

205 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
import time
import web
import json
from queue import Queue
from bridge.context import *
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.chat_message import ChatMessage
from common.log import logger
from common.singleton import singleton
from config import conf
import os
class WebMessage(ChatMessage):
def __init__(
self,
msg_id,
content,
ctype=ContextType.TEXT,
from_user_id="User",
to_user_id="Chatgpt",
other_user_id="Chatgpt",
):
self.msg_id = msg_id
self.ctype = ctype
self.content = content
self.from_user_id = from_user_id
self.to_user_id = to_user_id
self.other_user_id = other_user_id
@singleton
class WebChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
_instance = None
# def __new__(cls):
# if cls._instance is None:
# cls._instance = super(WebChannel, cls).__new__(cls)
# return cls._instance
def __init__(self):
super().__init__()
self.message_queues = {} # 为每个用户存储一个消息队列
self.msg_id_counter = 0 # 添加消息ID计数器
def _generate_msg_id(self):
"""生成唯一的消息ID"""
self.msg_id_counter += 1
return str(int(time.time())) + str(self.msg_id_counter)
def send(self, reply: Reply, context: Context):
try:
if reply.type == ReplyType.IMAGE:
from PIL import Image
image_storage = reply.content
image_storage.seek(0)
img = Image.open(image_storage)
print("<IMAGE>")
img.show()
elif reply.type == ReplyType.IMAGE_URL:
import io
import requests
from PIL import Image
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
img = Image.open(image_storage)
print(img_url)
img.show()
else:
print(reply.content)
# 获取用户ID如果没有则使用默认值
# user_id = getattr(context.get("session", None), "session_id", "default_user")
user_id = context["receiver"]
# 确保用户有对应的消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()
# 将消息放入对应用户的队列
message_data = {
"type": str(reply.type),
"content": reply.content,
"timestamp": time.time()
}
self.message_queues[user_id].put(message_data)
logger.debug(f"Message queued for user {user_id}")
except Exception as e:
logger.error(f"Error in send method: {e}")
raise
def sse_handler(self, user_id):
"""
Handle Server-Sent Events (SSE) for real-time communication.
"""
web.header('Content-Type', 'text/event-stream')
web.header('Cache-Control', 'no-cache')
web.header('Connection', 'keep-alive')
# 确保用户有消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()
try:
while True:
try:
# 发送心跳
yield f": heartbeat\n\n"
# 非阻塞方式获取消息
if not self.message_queues[user_id].empty():
message = self.message_queues[user_id].get_nowait()
yield f"data: {json.dumps(message)}\n\n"
time.sleep(0.5)
except Exception as e:
logger.error(f"SSE Error: {e}")
break
finally:
# 清理资源
if user_id in self.message_queues:
# 只有当队列为空时才删除
if self.message_queues[user_id].empty():
del self.message_queues[user_id]
def post_message(self):
"""
Handle incoming messages from users via POST request.
"""
try:
data = web.data() # 获取原始POST数据
json_data = json.loads(data)
user_id = json_data.get('user_id', 'default_user')
prompt = json_data.get('message', '')
except json.JSONDecodeError:
return json.dumps({"status": "error", "message": "Invalid JSON"})
except Exception as e:
return json.dumps({"status": "error", "message": str(e)})
if not prompt:
return json.dumps({"status": "error", "message": "No message provided"})
try:
msg_id = self._generate_msg_id()
context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id,
prompt,
from_user_id=user_id,
other_user_id = user_id
))
context["isgroup"] = False
# context["session"] = web.storage(session_id=user_id)
if not context:
return json.dumps({"status": "error", "message": "Failed to process message"})
self.produce(context)
return json.dumps({"status": "success", "message": "Message received"})
except Exception as e:
logger.error(f"Error processing message: {e}")
return json.dumps({"status": "error", "message": "Internal server error"})
def chat_page(self):
"""Serve the chat HTML page."""
file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def startup(self):
logger.setLevel("WARN")
print("\nWeb Channel is running. Send POST requests to /message to send messages.")
urls = (
'/sse/(.+)', 'SSEHandler', # 修改路由以接收用户ID
'/message', 'MessageHandler',
'/chat', 'ChatHandler',
)
port = conf().get("web_port", 9899)
app = web.application(urls, globals(), autoreload=False)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
class SSEHandler:
def GET(self, user_id):
return WebChannel().sse_handler(user_id)
class MessageHandler:
def POST(self):
return WebChannel().post_message()
class ChatHandler:
def GET(self):
return WebChannel().chat_page()