From fa6689ec55151b949dbab4191ec0856766ec3ea9 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 19 Mar 2023 00:46:56 +0800 Subject: [PATCH] fix: acquire complete reply content --- channel/wechat/wechat_channel.py | 1 - model/baidu/yiyan_model.py | 33 +++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 5a215c3..f4a5585 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -13,7 +13,6 @@ from common.log import logger from common import const from config import channel_conf_val import requests -from urllib.parse import urlencode from common.sensitive_word import SensitiveWord diff --git a/model/baidu/yiyan_model.py b/model/baidu/yiyan_model.py index 6d0004a..3029523 100644 --- a/model/baidu/yiyan_model.py +++ b/model/baidu/yiyan_model.py @@ -3,10 +3,9 @@ from model.model import Model from config import model_conf from common import const -from common import log +from common.log import logger import requests import time -import json sessions = {} @@ -17,7 +16,9 @@ class YiyanModel(Model): self.base_url = 'https://yiyan.baidu.com/eb' def reply(self, query, context=None): + logger.info("[BAIDU] query={}".format(query)) user_id = context.get('session_id') or context.get('from_user_id') + context['query'] = query # 1.create session chat_session_id = sessions.get(user_id) @@ -28,8 +29,9 @@ class YiyanModel(Model): context['chat_session_id'] = chat_session_id # 2.create chat - context['query'] = query - self.new_chat(context) + flag = self.new_chat(context) + if not flag: + return "创建会话失败,请稍后再试" # 3.query context['reply'] = '' @@ -40,13 +42,14 @@ class YiyanModel(Model): def new_session(self, context): data = { - "sessionName": "test session", + "sessionName": context['query'], "timestamp": int(time.time() * 1000), "deviceType": "pc" } res = requests.post(url=self.base_url+'/session/new', headers=self._create_header(), json=data) # print(res.headers) context['chat_session_id'] = res.json()['data']['sessionId'] + logger.info("[BAIDU] newSession: id={}".format(context['chat_session_id'])) def new_chat(self, context): @@ -62,10 +65,13 @@ class YiyanModel(Model): "code": 0, "msg": "" } - res = requests.post(url=self.base_url+'/chat/new', headers=headers, json=data) - print(res.text) - context['chat_id'] = res.json()['data']['botChat']['id'] - context['parent_chat_id'] = res.json()['data']['botChat']['parent'] + res = requests.post(url=self.base_url+'/chat/new', headers=headers, json=data).json() + if res['code'] != 0: + logger.error("[BAIDU] New chat error, msg={}", res['msg']) + return False + context['chat_id'] = res['data']['botChat']['id'] + context['parent_chat_id'] = res['data']['botChat']['parent'] + return True def query(self, context, sentence_id, count): @@ -79,10 +85,13 @@ class YiyanModel(Model): "timestamp": 1679068791405, "deviceType": "pc" } - res = requests.post(url=self.base_url + '/chat/query', headers=headers, json=data).json() + res = requests.post(url=self.base_url + '/chat/query', headers=headers, json=data) + logger.debug("[BAIDU] query: sent_id={}, count={}, res={}".format(sentence_id, count, res.text)) + res = res.json() if res['data']['text'] != '': context['reply'] += res['data']['text'] + # logger.debug("[BAIDU] query: sent_id={}, reply={}".format(sentence_id, res['data']['text'])) if res['data']['is_end'] == 1: return @@ -90,9 +99,11 @@ class YiyanModel(Model): if count > 10: return + time.sleep(1) if not res['data']['text']: - time.sleep(1) return self.query(context, sentence_id, count+1) + else: + return self.query(context, sentence_id+1, count+1) def _create_header(self):