mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-24 08:19:49 +08:00
openai 接口返回token数量来修剪会话长度
This commit is contained in:
@@ -6,7 +6,6 @@ from common.log import logger
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import time
|
||||
import json
|
||||
|
||||
if conf().get('expires_in_seconds'):
|
||||
user_session = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
@@ -44,12 +43,19 @@ class ChatGPTBot(Bot):
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
return reply_content
|
||||
return reply_content[1]
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param query: query content
|
||||
:param user_id: from user id
|
||||
:param retry_count: retry count
|
||||
:return: [0]-tokens used and [1]-answer
|
||||
'''
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo", # 对话模型的名称
|
||||
@@ -62,8 +68,8 @@ class ChatGPTBot(Bot):
|
||||
)
|
||||
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
logger.info(response.choices[0]['message']['content'])
|
||||
# log.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return response.choices[0]['message']['content']
|
||||
|
||||
return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
@@ -72,21 +78,21 @@ class ChatGPTBot(Bot):
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
return 0,"提问太快啦,请休息一下再问我吧"
|
||||
except openai.error.APIConnectionError as e:
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return "我连接不到你的网络"
|
||||
return 0,"我连接不到你的网络"
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return "我没有收到你的消息"
|
||||
return 0,"我没有收到你的消息"
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
return 0,"请再问我一次吧"
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
@@ -142,31 +148,27 @@ class Session(object):
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens=int(max_tokens)
|
||||
|
||||
session = user_session.get(user_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
gpt_item = {'role': 'assistant', 'content': answer[1]}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
||||
used_tokens=int(answer[0])
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens):
|
||||
count = 0
|
||||
count_list = list()
|
||||
for i in range(len(session)-1, -1, -1):
|
||||
# count tokens of conversation list
|
||||
history_conv = session[i]
|
||||
tokens=json.dumps(history_conv).split()
|
||||
count += len(tokens)
|
||||
count_list.append(count)
|
||||
while used_tokens > max_tokens:
|
||||
# pop first conversation
|
||||
if len(session) > 0:
|
||||
session.pop(0)
|
||||
else:
|
||||
break
|
||||
|
||||
used_tokens=used_tokens-max_tokens
|
||||
|
||||
for c in count_list:
|
||||
if c > max_tokens:
|
||||
# pop first conversation
|
||||
session.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(user_id):
|
||||
|
||||
Reference in New Issue
Block a user