mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-26 16:16:21 +08:00
private openai_api_key
This commit is contained in:
@@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
import redis
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot,OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# set the default api_key
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
@@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
@@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get('openai_api_key')
|
||||
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
reply_content = self.reply_text(session, session_id, api_key, 0)
|
||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
@@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
|
||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
@@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=session.messages, **self.compose_args()
|
||||
api_key=api_key, messages=session.messages, **self.compose_args()
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
@@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result['content'] = "我没有收到你的消息"
|
||||
result['content'] = "服务器出现问题"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result['content'] = "我连接不到你的网络"
|
||||
result['content'] = "网络连接出现问题"
|
||||
else:
|
||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
result['content'] = str(e)
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, retry_count+1)
|
||||
return self.reply_text(session, session_id, api_key, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from plugins import *
|
||||
import traceback
|
||||
import redis
|
||||
|
||||
class WechatMPServer():
|
||||
def __init__(self):
|
||||
@@ -82,7 +83,6 @@ class WechatMPChannel(Channel):
|
||||
global cache_dict
|
||||
try:
|
||||
reply = Reply()
|
||||
|
||||
logger.debug('[wechatmp] ready to handle context: {}'.format(context))
|
||||
|
||||
# reply的构建步骤
|
||||
@@ -134,6 +134,8 @@ class WechatMPChannel(Channel):
|
||||
self.send(reply, context['receiver'])
|
||||
else:
|
||||
cache_dict[context['receiver']] = (1, "No reply")
|
||||
|
||||
logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content))
|
||||
except Exception as exc:
|
||||
print(traceback.format_exc())
|
||||
cache_dict[context['receiver']] = (1, "ERROR")
|
||||
@@ -171,6 +173,14 @@ class WechatMPChannel(Channel):
|
||||
|
||||
context = Context()
|
||||
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser}
|
||||
|
||||
R = redis.Redis(host='localhost', port=6379, db=0)
|
||||
user_openai_api_key = "openai_api_key_" + fromUser
|
||||
api_key = R.get(user_openai_api_key)
|
||||
if api_key != None:
|
||||
api_key = api_key.decode("utf-8")
|
||||
context['openai_api_key'] = api_key # None or user openai_api_key
|
||||
|
||||
img_match_prefix = check_prefix(message, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
message = message.replace(img_match_prefix, '', 1).strip()
|
||||
@@ -240,7 +250,7 @@ class WechatMPChannel(Channel):
|
||||
if cnt == 45:
|
||||
# Have waiting for 3x5 seconds
|
||||
# return timeout message
|
||||
reply_text = "【服务器有点忙,回复任意文字再次尝试】"
|
||||
reply_text = "【正在响应中,回复任意文字尝试获取回复】"
|
||||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id))
|
||||
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send()
|
||||
return replyPost
|
||||
|
||||
@@ -29,6 +29,15 @@ COMMANDS = {
|
||||
"args": ["口令"],
|
||||
"desc": "管理员认证",
|
||||
},
|
||||
"set_openai_api_key": {
|
||||
"alias": ["set_openai_api_key"],
|
||||
"args": ["api_key"],
|
||||
"desc": "设置你的OpenAI私有api_key",
|
||||
},
|
||||
"reset_openai_api_key": {
|
||||
"alias": ["reset_openai_api_key"],
|
||||
"desc": "重置为默认的api_key",
|
||||
},
|
||||
# "id": {
|
||||
# "alias": ["id", "用户"],
|
||||
# "desc": "获取用户id", #目前无实际意义
|
||||
@@ -99,7 +108,7 @@ def get_help_text(isadmin, isgroup):
|
||||
alias=["#"+a for a in info['alias']]
|
||||
help_text += f"{','.join(alias)} "
|
||||
if 'args' in info:
|
||||
args=["{"+a+"}" for a in info['args']]
|
||||
args=["'"+a+"'" for a in info['args']]
|
||||
help_text += f"{' '.join(args)} "
|
||||
help_text += f": {info['desc']}\n"
|
||||
|
||||
@@ -162,7 +171,7 @@ class Godcmd(Plugin):
|
||||
bottype = Bridge().get_bot_type("chat")
|
||||
bot = Bridge().get_bot("chat")
|
||||
# 将命令和参数分割
|
||||
command_parts = content[1:].split(" ")
|
||||
command_parts = content[1:].strip().split(" ")
|
||||
cmd = command_parts[0]
|
||||
args = command_parts[1:]
|
||||
isadmin=False
|
||||
@@ -184,6 +193,22 @@ class Godcmd(Plugin):
|
||||
ok, result = True, PluginManager().instances[name].get_help_text(verbose=True)
|
||||
else:
|
||||
ok, result = False, "unknown args"
|
||||
elif cmd == "set_openai_api_key":
|
||||
if len(args) == 1:
|
||||
import redis
|
||||
R = redis.Redis(host='localhost', port=6379, db=0)
|
||||
user_openai_api_key = "openai_api_key_" + user
|
||||
R.set(user_openai_api_key, args[0])
|
||||
# R.sadd("openai_api_key", args[0])
|
||||
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
|
||||
else:
|
||||
ok, result = False, "请提供一个api_key"
|
||||
elif cmd == "reset_openai_api_key":
|
||||
import redis
|
||||
R = redis.Redis(host='localhost', port=6379, db=0)
|
||||
user_openai_api_key = "openai_api_key_" + user
|
||||
R.delete(user_openai_api_key)
|
||||
ok, result = True, "OpenAI的api_key已重置"
|
||||
# elif cmd == "helpp":
|
||||
# if len(args) != 1:
|
||||
# ok, result = False, "请提供插件名"
|
||||
|
||||
Reference in New Issue
Block a user