From 9e70809fdcdd625c9476b796aa70d0984d1d1aee Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sat, 18 Mar 2023 12:14:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=87=E5=BF=83=E4=B8=80=E8=A8=80?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=89=88=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- channel/qq/qq_channel.py | 1 + common/const.py | 1 + config-template.json | 4 + model/baidu/yiyan_model.py | 106 +++++++++++++++++++++ model/model_factory.py | 2 +- model/{chatgpt => openai}/chatgpt_model.py | 0 6 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 model/baidu/yiyan_model.py rename model/{chatgpt => openai}/chatgpt_model.py (100%) diff --git a/channel/qq/qq_channel.py b/channel/qq/qq_channel.py index e929fe6..f39a641 100644 --- a/channel/qq/qq_channel.py +++ b/channel/qq/qq_channel.py @@ -46,4 +46,5 @@ class QQChannel(Channel): context['from_user_id'] = str(msg.user_id) reply_text = super().build_reply_content(query, context) reply_text = '[CQ:at,qq=' + str(msg.user_id) + '] ' + reply_text + bot.sync.send_group_msg(group_id=msg['group_id'], message=reply_text) diff --git a/common/const.py b/common/const.py index c2e587c..5c10599 100644 --- a/common/const.py +++ b/common/const.py @@ -12,3 +12,4 @@ HTTP = "http" # model OPEN_AI = "openai" CHATGPT = "chatgpt" +BAIDU = "baidu" \ No newline at end of file diff --git a/config-template.json b/config-template.json index 4476d50..23be7fe 100644 --- a/config-template.json +++ b/config-template.json @@ -7,6 +7,10 @@ "proxy": "", "conversation_max_tokens": 1000, "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" + }, + "baidu": { + "acs-token": "YOUR ACS TOKEN", + "cookie": "YOUR COOKIE" } }, diff --git a/model/baidu/yiyan_model.py b/model/baidu/yiyan_model.py new file mode 100644 index 0000000..545de6f --- /dev/null +++ b/model/baidu/yiyan_model.py @@ -0,0 +1,106 @@ +# encoding:utf-8 + +from model.model import Model +from config import model_conf +from common import const +from common import log +import requests +import time +import json + +sessions = {} + +class YiyanModel(Model): + def __init__(self): + self.acs_token = model_conf(const.BAIDU).get('acs_token') + self.cookie = model_conf(const.BAIDU).get('cookie') + self.base_url = 'https://yiyan.baidu.com/eb' + + def reply(self, query, context=None): + user_id = context.get('session_id') or context.get('from_user_id') + + # 1.create session + chat_session_id = sessions.get(user_id) + if not chat_session_id: + self.new_session(context) + sessions[user_id] = context['chat_session_id'] + else: + context['chat_session_id'] = chat_session_id + + # 2.create chat + context['query'] = query + self.new_chat(context) + + # 3.query + context['reply'] = '' + self.query(context, 0, 0) + + return context['reply'] + + + def new_session(self, context): + data = { + "sessionName": "test session", + "timestamp": int(time.time() * 1000), + "deviceType": "pc" + } + res = requests.post(url=self.base_url+'/session/new', headers=self._create_header(), json=data) + print(res.headers) + context['chat_session_id'] = res.json()['data']['sessionId'] + + + def new_chat(self, context): + headers = self._create_header() + headers['Acs-Token'] = self.acs_token + data = { + "sessionId": context.get('chat_session_id'), + "text": context['query'], + "parentChatId": 0, + "type": 10, + "timestamp": int(time.time() * 1000), + "deviceType": "pc", + "code": 0, + "msg": "" + } + res = requests.post(url=self.base_url+'/chat/new', headers=headers, json=data) + context['chat_id'] = res.json()['data']['botChat']['id'] + context['parent_chat_id'] = res.json()['data']['botChat']['parent'] + + + def query(self, context, sentence_id, count): + headers = self._create_header() + headers['Acs-Token'] = self.acs_token + data = { + "chatId": context['chat_id'], + "parentChatId": context['parent_chat_id'], + "sentenceId": sentence_id, + "stop": 0, + "timestamp": 1679068791405, + "deviceType": "pc" + } + res = requests.post(url=self.base_url + '/chat/query', headers=headers, json=data).json() + + if res['data']['text'] != '': + context['reply'] += res['data']['text'] + + if res['data']['is_end'] == 1: + return + + if count > 10: + return + + if not res['data']['text']: + time.sleep(1) + return self.query(context, sentence_id, count+1) + + + def _create_header(self): + headers = { + 'Host': 'yiyan.baidu.com', + 'Origin': 'https://yiyan.baidu.com', + 'Referer': 'https://yiyan.baidu.com', + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/105.0.0.0 Safari/537.36', + 'Content-Type': 'application/json', + 'Cookie': self.cookie + } + return headers diff --git a/model/model_factory.py b/model/model_factory.py index c52c4f6..87dc468 100644 --- a/model/model_factory.py +++ b/model/model_factory.py @@ -18,7 +18,7 @@ def create_bot(model_type): elif model_type == const.CHATGPT: # ChatGPT API (gpt-3.5-turbo) - from model.chatgpt.chatgpt_model import ChatGPTModel + from model.openai.chatgpt_model import ChatGPTModel return ChatGPTModel() raise RuntimeError diff --git a/model/chatgpt/chatgpt_model.py b/model/openai/chatgpt_model.py similarity index 100% rename from model/chatgpt/chatgpt_model.py rename to model/openai/chatgpt_model.py