From f20d7043900791cc256fa8cf64d647dad6de41a7 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Fri, 20 Sep 2024 09:10:21 +0100 Subject: [PATCH] fix: gemini doesn't receive system messages; change session to gpt method, add system messages as user messages to the gemini, and logging historical messages --- bot/gemini/google_gemini_bot.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py index 8a4100a..cd5eacc 100644 --- a/bot/gemini/google_gemini_bot.py +++ b/bot/gemini/google_gemini_bot.py @@ -13,7 +13,7 @@ from bridge.context import ContextType, Context from bridge.reply import Reply, ReplyType from common.log import logger from config import conf -from bot.baidu.baidu_wenxin_session import BaiduWenxinSession +from bot.chatgpt.chat_gpt_session import ChatGPTSession # OpenAI对话模型API (可用) @@ -23,7 +23,7 @@ class GoogleGeminiBot(Bot): super().__init__() self.api_key = conf().get("gemini_api_key") # 复用文心的token计算方式 - self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") + self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") self.model = conf().get("model") or "gemini-pro" if self.model == "gemini": self.model = "gemini-pro" @@ -36,6 +36,7 @@ class GoogleGeminiBot(Bot): session_id = context["session_id"] session = self.sessions.session_query(query, session_id) gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages)) + logger.info(f"[Gemini] messages={gemini_messages}") genai.configure(api_key=self.api_key) model = genai.GenerativeModel(self.model) response = model.generate_content(gemini_messages) @@ -55,6 +56,8 @@ class GoogleGeminiBot(Bot): role = "user" elif msg.get("role") == "assistant": role = "model" + elif msg.get("role") == "system": + role = "user" else: continue res.append({ @@ -71,7 +74,11 @@ class GoogleGeminiBot(Bot): return res for i in range(len(messages) - 1, -1, -1): message = messages[i] - if message.get("role") != turn: + role = message.get("role") + if role == "system": + res.insert(0, message) + continue + if role != turn: continue res.insert(0, message) if turn == "user":