diff --git a/README.md b/README.md index 20e167b..e78cd3e 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,11 @@ - [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单 - [x] **多账号:** 支持多微信账号同时运行 - [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊 +- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下文会话 # 更新日志 +>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话 >**2022.12.19:** 引入 [itchat-uos](https://github.com/why2lyj/ItChat-UOS) 替换 itchat,解决由于不能登录网页微信而无法使用的问题,且解决Python3.9的兼容问题 @@ -85,7 +87,8 @@ cp config-template.json config.json "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 - "image_create_prefix": ["画", "看", "找"] # 开启图片回复的前缀 + "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 + "conversation_max_tokens": 3000 # 支持上下文记忆的最多字符数 } ``` **配置说明:** @@ -105,6 +108,7 @@ cp config-template.json config.json + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix` + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/images) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。 ++ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话) ## 运行 diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index dc554d7..a6858c0 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -5,6 +5,7 @@ from config import conf from common.log import logger import openai +user_session = dict() # OpenAI对话模型API (可用) class OpenAIBot(Bot): @@ -12,17 +13,26 @@ class OpenAIBot(Bot): openai.api_key = conf().get('open_ai_api_key') def reply(self, query, context=None): - # auto append question mark - query = self.append_question_mark(query) # acquire reply content if not context or not context.get('type') or context.get('type') == 'TEXT': - return 
self.reply_text(query) + logger.info("[OPEN_AI] query={}".format(query)) + from_user_id = context['from_user_id'] + if query == '#清除记忆': + Session.clear_session(from_user_id) + return '记忆已清除' + + new_query = Session.build_session_query(query, from_user_id) + logger.debug("[OPEN_AI] session query={}".format(new_query)) + + reply_content = self.reply_text(new_query, query) + Session.save_session(query, reply_content, from_user_id) + return reply_content + elif context.get('type', None) == 'IMAGE_CREATE': return self.create_img(query) - def reply_text(self, query): - logger.info("[OPEN_AI] query={}".format(query)) + def reply_text(self, query, origin_query): try: response = openai.Completion.create( model="text-davinci-003", # 对话模型的名称 @@ -34,7 +44,7 @@ class OpenAIBot(Bot): presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 stop=["#"] ) - res_content = response.choices[0]["text"].strip() + res_content = response.choices[0]["text"].strip().replace("<|im_end|>", "") except Exception as e: logger.exception(e) return None @@ -93,3 +103,68 @@ class OpenAIBot(Bot): if query.endswith(symbol): return query return query + "?" + + +class Session(object): + @staticmethod + def build_session_query(query, user_id): + ''' + build query with conversation history + e.g. 
Q: xxx + A: xxx + Q: xxx + :param query: query content + :param user_id: from user id + :return: query content with conversation + ''' + new_query = "" + session = user_session.get(user_id, None) + if session: + for conversation in session: + new_query += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|im_end|>\n" + new_query += "Q: " + query + "\nA: " + return new_query + else: + return "Q: " + query + "\nA: " + + @staticmethod + def save_session(query, answer, user_id): + max_tokens = conf().get("conversation_max_tokens") + if not max_tokens: + # default 3000 + max_tokens = 3000 + conversation = dict() + conversation["question"] = query + conversation["answer"] = answer + session = user_session.get(user_id) + if session: + # append conversation + session.append(conversation) + else: + # create session + queue = list() + queue.append(conversation) + user_session[user_id] = queue + + # discard exceed limit conversation + Session.discard_exceed_conversation(user_session[user_id], max_tokens) + + + @staticmethod + def discard_exceed_conversation(session, max_tokens): + count = 0 + count_list = list() + for i in range(len(session)-1, -1, -1): + # count tokens of conversation list + history_conv = session[i] + count += len(history_conv["question"]) + len(history_conv["answer"]) + count_list.append(count) + + for c in count_list: + if c > max_tokens: + # pop first conversation + session.pop(0) + + @staticmethod + def clear_session(user_id): + user_session[user_id] = [] diff --git a/config-template.json b/config-template.json index 3af5d23..84d660c 100644 --- a/config-template.json +++ b/config-template.json @@ -4,5 +4,6 @@ "single_chat_reply_prefix": "[bot] ", "group_chat_prefix": ["@bot"], "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], - "image_create_prefix": ["画", "看", "找"] + "image_create_prefix": ["画", "看", "找"], + "conversation_max_tokens": 3000 }