fix: acquire complete reply content

This commit is contained in:
zhayujie
2023-03-19 00:46:56 +08:00
parent 100aa27587
commit fa6689ec55
2 changed files with 22 additions and 12 deletions

View File

@@ -13,7 +13,6 @@ from common.log import logger
from common import const
from config import channel_conf_val
import requests
from urllib.parse import urlencode
from common.sensitive_word import SensitiveWord

View File

@@ -3,10 +3,9 @@
from model.model import Model
from config import model_conf
from common import const
from common import log
from common.log import logger
import requests
import time
import json
sessions = {}
@@ -17,7 +16,9 @@ class YiyanModel(Model):
self.base_url = 'https://yiyan.baidu.com/eb'
def reply(self, query, context=None):
logger.info("[BAIDU] query={}".format(query))
user_id = context.get('session_id') or context.get('from_user_id')
context['query'] = query
# 1.create session
chat_session_id = sessions.get(user_id)
@@ -28,8 +29,9 @@ class YiyanModel(Model):
context['chat_session_id'] = chat_session_id
# 2.create chat
context['query'] = query
self.new_chat(context)
flag = self.new_chat(context)
if not flag:
return "创建会话失败,请稍后再试"
# 3.query
context['reply'] = ''
@@ -40,13 +42,14 @@ class YiyanModel(Model):
def new_session(self, context):
data = {
"sessionName": "test session",
"sessionName": context['query'],
"timestamp": int(time.time() * 1000),
"deviceType": "pc"
}
res = requests.post(url=self.base_url+'/session/new', headers=self._create_header(), json=data)
# print(res.headers)
context['chat_session_id'] = res.json()['data']['sessionId']
logger.info("[BAIDU] newSession: id={}".format(context['chat_session_id']))
def new_chat(self, context):
@@ -62,10 +65,13 @@ class YiyanModel(Model):
"code": 0,
"msg": ""
}
res = requests.post(url=self.base_url+'/chat/new', headers=headers, json=data)
print(res.text)
context['chat_id'] = res.json()['data']['botChat']['id']
context['parent_chat_id'] = res.json()['data']['botChat']['parent']
res = requests.post(url=self.base_url+'/chat/new', headers=headers, json=data).json()
if res['code'] != 0:
logger.error("[BAIDU] New chat error, msg={}", res['msg'])
return False
context['chat_id'] = res['data']['botChat']['id']
context['parent_chat_id'] = res['data']['botChat']['parent']
return True
def query(self, context, sentence_id, count):
@@ -79,10 +85,13 @@ class YiyanModel(Model):
"timestamp": 1679068791405,
"deviceType": "pc"
}
res = requests.post(url=self.base_url + '/chat/query', headers=headers, json=data).json()
res = requests.post(url=self.base_url + '/chat/query', headers=headers, json=data)
logger.debug("[BAIDU] query: sent_id={}, count={}, res={}".format(sentence_id, count, res.text))
res = res.json()
if res['data']['text'] != '':
context['reply'] += res['data']['text']
# logger.debug("[BAIDU] query: sent_id={}, reply={}".format(sentence_id, res['data']['text']))
if res['data']['is_end'] == 1:
return
@@ -90,9 +99,11 @@ class YiyanModel(Model):
if count > 10:
return
time.sleep(1)
if not res['data']['text']:
time.sleep(1)
return self.query(context, sentence_id, count+1)
else:
return self.query(context, sentence_id+1, count+1)
def _create_header(self):