mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-01-19 01:21:01 +08:00
116 lines
4.7 KiB
Python
116 lines
4.7 KiB
Python
"""
|
|
Google gemini bot
|
|
|
|
@author zhayujie
|
|
@Date 2023/12/15
|
|
"""
|
|
# encoding:utf-8
|
|
|
|
from bot.bot import Bot
|
|
import google.generativeai as genai
|
|
from bot.session_manager import SessionManager
|
|
from bridge.context import ContextType, Context
|
|
from bridge.reply import Reply, ReplyType
|
|
from common.log import logger
|
|
from config import conf
|
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
|
|
|
# Google Gemini chat model API
|
|
class GoogleGeminiBot(Bot):
    """Chat bot that answers text messages through the Google Gemini API.

    Conversation history is stored via ``SessionManager`` using
    ``ChatGPTSession`` objects and converted to Gemini ``contents`` entries
    on every request.
    """

    def __init__(self):
        super().__init__()
        self.api_key = conf().get("gemini_api_key")
        # Reuse ChatGPT's token counting for the session history.
        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.model = conf().get("model") or "gemini-pro"
        # The bare name "gemini" is treated as an alias for "gemini-pro".
        if self.model == "gemini":
            self.model = "gemini-pro"

    def reply(self, query, context: Context = None) -> Reply:
        """Send ``query`` plus the session history to Gemini and wrap the answer.

        :param query: the user's text message.
        :param context: message context; only ``ContextType.TEXT`` is handled,
            and it must carry a ``"session_id"`` entry.
        :return: a TEXT reply with the model's answer; a TEXT reply with
            ``None`` content for unsupported message types; or an ERROR reply
            when no usable text came back (possibly blocked by the safety
            filters) or the API call raised.
        """
        # Bound before the try block so the error handler below can always
        # read it, even when ``context`` is None or lacks a session id.
        session_id = None
        try:
            if context.type != ContextType.TEXT:
                logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
                return Reply(ReplyType.TEXT, None)
            logger.info(f"[Gemini] query={query}")
            session_id = context["session_id"]
            session = self.sessions.session_query(query, session_id)
            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
            logger.debug(f"[Gemini] messages={gemini_messages}")
            genai.configure(api_key=self.api_key)
            model = genai.GenerativeModel(self.model)

            # Safety settings: disable blocking for every listed harm category.
            safety_settings = {
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            }

            # Generate the reply with the safety settings above.
            response = model.generate_content(
                gemini_messages,
                safety_settings=safety_settings
            )
            reply_text = self._extract_text(response)
            if reply_text:
                logger.info(f"[Gemini] reply={reply_text}")
                self.sessions.session_reply(reply_text, session_id)
                return Reply(ReplyType.TEXT, reply_text)

            # No usable content (the answer may have been blocked): log the safety ratings.
            logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
            if response.candidates:
                for rating in response.candidates[0].safety_ratings:
                    logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
            error_message = "No valid response generated due to safety constraints."
            self.sessions.session_reply(error_message, session_id)
            return Reply(ReplyType.ERROR, error_message)

        except Exception as e:
            logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
            error_message = "Failed to invoke [Gemini] api!"
            # Only record the error in a session that was actually identified.
            if session_id is not None:
                self.sessions.session_reply(error_message, session_id)
            return Reply(ReplyType.ERROR, error_message)

    @staticmethod
    def _extract_text(response):
        """Return the text of the first candidate, or None if it has no content.

        All parts of the candidate are concatenated, so an answer split across
        several parts is not truncated and an empty ``parts`` list does not
        raise ``IndexError``.
        """
        if not response.candidates:
            return None
        content = response.candidates[0].content
        if not content or not content.parts:
            return None
        return "".join(part.text for part in content.parts)

    def _convert_to_gemini_messages(self, messages: list):
        """Convert OpenAI-style chat messages into Gemini ``contents`` entries.

        ``assistant`` becomes ``model``; ``user`` and ``system`` both become
        ``user``; any other role is dropped. Neighbouring messages that end up
        with the same role are merged into one entry with several parts. This
        keeps user/model turns alternating; otherwise a mapped system prompt
        would sit directly before the first user message as a second ``user``
        turn.

        :param messages: list of dicts with ``"role"`` and ``"content"`` keys.
        :return: list of ``{"role": ..., "parts": [{"text": ...}, ...]}`` dicts.
        """
        role_map = {"user": "user", "assistant": "model", "system": "user"}
        res = []
        for msg in messages:
            role = role_map.get(msg.get("role"))
            if role is None:
                continue
            part = {"text": msg.get("content")}
            if res and res[-1]["role"] == role:
                res[-1]["parts"].append(part)
            else:
                res.append({"role": role, "parts": [part]})
        return res

    @staticmethod
    def filter_messages(messages: list):
        """Select the messages to send: system prompts plus an alternating user/assistant chain.

        The history is walked from newest to oldest. The walk starts by looking
        for a ``user`` message and then alternately accepts ``assistant`` and
        ``user`` messages, skipping any that are out of turn. Messages with role
        ``system`` are always kept. Kept messages stay in their original order.

        :param messages: list of dicts with a ``"role"`` key; may be empty or None.
        :return: a new list of the kept messages.
        """
        kept = []
        turn = "user"
        for message in reversed(messages or []):
            role = message.get("role")
            if role == "system":
                kept.append(message)
                continue
            if role != turn:
                continue
            kept.append(message)
            turn = "assistant" if turn == "user" else "user"
        # Collected newest-first; restore chronological order (O(n) instead of repeated insert(0)).
        kept.reverse()
        return kept
|