mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-02-25 08:57:51 +08:00
Merge pull request #2670 from SgtPepper114/fix/gemini-dingtalk-image-inline
fix(gemini): 修复钉钉图片标记未转多模态导致的识图失效
This commit is contained in:
@@ -6,11 +6,14 @@ Google gemini bot
|
||||
"""
|
||||
# encoding:utf-8
|
||||
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import requests
|
||||
from models.bot import Bot
|
||||
import google.generativeai as genai
|
||||
from models.session_manager import SessionManager
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
@@ -18,7 +21,6 @@ from common.log import logger
|
||||
from config import conf
|
||||
from models.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from models.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
||||
|
||||
|
||||
# Google Gemini chat model API (available)
|
||||
@@ -43,6 +45,7 @@ class GoogleGeminiBot(Bot):
|
||||
self.api_base = "https://generativelanguage.googleapis.com"
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
session_id = None
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
||||
@@ -50,43 +53,47 @@ class GoogleGeminiBot(Bot):
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
|
||||
logger.debug(f"[Gemini] messages={gemini_messages}")
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel(self.model)
|
||||
|
||||
# 添加安全设置
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
|
||||
# 生成回复,包含安全设置
|
||||
response = model.generate_content(
|
||||
gemini_messages,
|
||||
safety_settings=safety_settings
|
||||
filtered_messages = self.filter_messages(session.messages)
|
||||
logger.debug(f"[Gemini] messages={filtered_messages}")
|
||||
|
||||
response = self.call_with_tools(
|
||||
messages=filtered_messages,
|
||||
tools=None,
|
||||
stream=False,
|
||||
model=self.model
|
||||
)
|
||||
if response.candidates and response.candidates[0].content:
|
||||
reply_text = response.candidates[0].content.parts[0].text
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
else:
|
||||
# 没有有效响应内容,可能内容被屏蔽,输出安全评分
|
||||
logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
for rating in response.candidates[0].safety_ratings:
|
||||
logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
|
||||
error_message = "No valid response generated due to safety constraints."
|
||||
|
||||
if isinstance(response, dict) and response.get("error"):
|
||||
error_message = response.get("message", "Failed to invoke [Gemini] api!")
|
||||
logger.error(f"[Gemini] API error: {error_message}")
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
choices = response.get("choices", []) if isinstance(response, dict) else []
|
||||
if choices and choices[0].get("message"):
|
||||
reply_text = choices[0]["message"].get("content")
|
||||
if reply_text:
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
|
||||
logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
|
||||
safety_ratings = response.get("safety_ratings", []) if isinstance(response, dict) else []
|
||||
if safety_ratings:
|
||||
for rating in safety_ratings:
|
||||
category = rating.get("category", "UNKNOWN")
|
||||
probability = rating.get("probability", "UNKNOWN")
|
||||
logger.warning(f"[Gemini] Safety rating: {category} - {probability}")
|
||||
|
||||
error_message = "No valid response generated due to safety constraints."
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
|
||||
error_message = "Failed to invoke [Gemini] api!"
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
if session_id:
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
@@ -127,6 +134,93 @@ class GoogleGeminiBot(Bot):
|
||||
turn = "user"
|
||||
return res
|
||||
|
||||
@staticmethod
def _extract_image_paths_from_text(content: str):
    """Split ``[图片: <path>]`` markers out of a text message.

    Args:
        content: Message text that may embed one or more image markers.

    Returns:
        A ``(cleaned_text, image_paths)`` tuple. ``cleaned_text`` is the
        input with every marker removed, runs of three or more newlines
        collapsed to two, and surrounding whitespace stripped.
        ``image_paths`` lists the marker payloads in order of appearance,
        trimmed of whitespace and of wrapping quote characters. A non-string
        input yields ``("", [])``.
    """
    if not isinstance(content, str):
        return "", []

    marker = re.compile(r"\[图片:\s*([^\]]+)\]")

    found = []
    for hit in marker.finditer(content):
        raw = hit.group(1).strip()
        # Whitespace-only payloads carry no path and are ignored.
        if raw:
            found.append(raw.strip("'\""))

    remainder = marker.sub("", content)
    # Removing markers can leave large vertical gaps; squeeze them.
    remainder = re.sub(r"\n{3,}", "\n\n", remainder)
    return remainder.strip(), found
|
||||
|
||||
@staticmethod
def _build_image_inline_part(image_path: str):
    """Read a local image file and wrap it as an ``inlineData`` part.

    Args:
        image_path: Filesystem path to the image. A leading ``file://``
            prefix is dropped and ``~`` is expanded before the file is read.

    Returns:
        ``{"inlineData": {"mimeType": ..., "data": <base64 text>}}`` on
        success. The MIME type is guessed from the file name and falls back
        to ``image/png`` when the guess is missing or is not an image type.
        Returns ``None`` when the path is empty, the file does not exist, or
        any error occurs while reading/encoding (the failure is logged as a
        warning rather than raised).
    """
    if not image_path:
        return None

    try:
        file_scheme = "file://"
        if image_path.startswith(file_scheme):
            image_path = image_path[len(file_scheme):]
        image_path = os.path.expanduser(image_path)

        if os.path.exists(image_path):
            with open(image_path, "rb") as handle:
                encoded = base64.b64encode(handle.read()).decode("utf-8")

            guessed = mimetypes.guess_type(image_path)[0]
            # Unknown or non-image guesses are reported as PNG.
            if not guessed or not guessed.startswith("image/"):
                guessed = "image/png"

            return {"inlineData": {"mimeType": guessed, "data": encoded}}

        logger.warning(f"[Gemini] Image file not found: {image_path}")
        return None
    except Exception as e:
        # Best effort: an unreadable image must not break the whole request.
        logger.warning(f"[Gemini] Failed to build inline image part from path={image_path}, err={e}")
        return None
|
||||
|
||||
@staticmethod
def _build_inline_part_from_image_url(image_url):
    """Turn an image reference into an ``inlineData`` part.

    Args:
        image_url: Either a string or a dict carrying the string under the
            ``"url"`` key. Recognised string forms, checked in this order:
            a base64 ``data:`` URL, a ``file://`` URL or existing local
            path, and an ``http://`` / ``https://`` URL.

    Returns:
        ``{"inlineData": {"mimeType": ..., "data": <base64 text>}}`` when
        the reference can be resolved, otherwise ``None``. Malformed data
        URLs, non-200 downloads, download errors and unsupported formats are
        logged as warnings and produce ``None``. For remote images the MIME
        type comes from the ``Content-Type`` header and falls back to
        ``image/png`` when the header is absent or not an image type.
    """
    if not image_url:
        return None

    if isinstance(image_url, dict):
        image_url = image_url.get("url")
    if not (image_url and isinstance(image_url, str)):
        return None

    # Case 1: base64 data URL -- payload is already encoded, pass it through.
    if image_url.startswith("data:"):
        parsed = re.match(r"^data:([^;]+);base64,(.+)$", image_url, re.DOTALL)
        if parsed:
            media_type, encoded = parsed.groups()
            return {"inlineData": {"mimeType": media_type, "data": encoded.strip()}}
        logger.warning("[Gemini] Invalid data URL for image block")
        return None

    # Case 2: local file, delegated to the path-based builder.
    is_local = image_url.startswith("file://") or os.path.exists(os.path.expanduser(image_url))
    if is_local:
        return GoogleGeminiBot._build_image_inline_part(image_url)

    # Anything that is not an HTTP(S) URL at this point is unsupported.
    if not image_url.startswith(("http://", "https://")):
        logger.warning(f"[Gemini] Unsupported image URL format: {image_url[:120]}")
        return None

    # Case 3: remote image, downloaded and base64-encoded.
    try:
        response = requests.get(image_url, timeout=20)
        if response.status_code == 200:
            content_type = response.headers.get("Content-Type", "image/png")
            mime_type = content_type.split(";")[0].strip()
            if not mime_type.startswith("image/"):
                mime_type = "image/png"
            encoded_body = base64.b64encode(response.content).decode("utf-8")
            return {"inlineData": {"mimeType": mime_type, "data": encoded_body}}

        logger.warning(f"[Gemini] Failed to fetch remote image: status={response.status_code}, url={image_url}")
        return None
    except Exception as e:
        logger.warning(f"[Gemini] Failed to download remote image: url={image_url}, err={e}")
        return None
|
||||
|
||||
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
"""
|
||||
Call Gemini API with tool support using REST API (following official docs)
|
||||
@@ -145,6 +239,15 @@ class GoogleGeminiBot(Bot):
|
||||
|
||||
# Build REST API payload
|
||||
payload = {"contents": []}
|
||||
inline_image_count = 0
|
||||
|
||||
# Keep legacy behavior: disable Gemini safety blocking like old SDK path.
|
||||
payload["safetySettings"] = [
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
||||
]
|
||||
|
||||
# Extract and set system instruction
|
||||
system_prompt = kwargs.get("system", "")
|
||||
@@ -174,8 +277,19 @@ class GoogleGeminiBot(Bot):
|
||||
parts = []
|
||||
|
||||
if isinstance(content, str):
|
||||
# Simple text content
|
||||
parts.append({"text": content})
|
||||
# Text with optional [图片: /path/to/file] markers
|
||||
cleaned_text, image_paths = self._extract_image_paths_from_text(content)
|
||||
if cleaned_text:
|
||||
parts.append({"text": cleaned_text})
|
||||
image_added = False
|
||||
for image_path in image_paths:
|
||||
image_part = self._build_image_inline_part(image_path)
|
||||
if image_part:
|
||||
parts.append(image_part)
|
||||
image_added = True
|
||||
inline_image_count += 1
|
||||
if not cleaned_text and not image_added and content:
|
||||
parts.append({"text": content})
|
||||
|
||||
elif isinstance(content, list):
|
||||
# List of content blocks (Claude format)
|
||||
@@ -188,8 +302,39 @@ class GoogleGeminiBot(Bot):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == "text":
|
||||
# Text block
|
||||
parts.append({"text": block.get("text", "")})
|
||||
# Text block with optional image markers
|
||||
block_text = block.get("text", "")
|
||||
cleaned_text, image_paths = self._extract_image_paths_from_text(block_text)
|
||||
if cleaned_text:
|
||||
parts.append({"text": cleaned_text})
|
||||
for image_path in image_paths:
|
||||
image_part = self._build_image_inline_part(image_path)
|
||||
if image_part:
|
||||
parts.append(image_part)
|
||||
|
||||
elif block_type in ["image", "image_url"]:
|
||||
# OpenAI format: {"type":"image_url","image_url":{"url":"..."}}
|
||||
# Claude format: {"type":"image","source":{"type":"base64","media_type":"...","data":"..."}}
|
||||
image_part = None
|
||||
if block_type == "image":
|
||||
source = block.get("source", {})
|
||||
if isinstance(source, dict) and source.get("type") == "base64" and source.get("data"):
|
||||
image_part = {
|
||||
"inlineData": {
|
||||
"mimeType": source.get("media_type", "image/png"),
|
||||
"data": source.get("data")
|
||||
}
|
||||
}
|
||||
elif block.get("image_url"):
|
||||
image_part = self._build_inline_part_from_image_url(block.get("image_url"))
|
||||
else:
|
||||
image_part = self._build_inline_part_from_image_url(block.get("image_url"))
|
||||
|
||||
if image_part:
|
||||
parts.append(image_part)
|
||||
inline_image_count += 1
|
||||
else:
|
||||
logger.warning(f"[Gemini] Skip invalid image block: {str(block)[:200]}")
|
||||
|
||||
elif block_type == "tool_result":
|
||||
# Convert Claude tool_result to Gemini functionResponse
|
||||
@@ -237,6 +382,9 @@ class GoogleGeminiBot(Bot):
|
||||
"role": gemini_role,
|
||||
"parts": parts
|
||||
})
|
||||
|
||||
if inline_image_count > 0:
|
||||
logger.info(f"[Gemini] Multimodal request includes {inline_image_count} image part(s)")
|
||||
|
||||
# Generation config
|
||||
gen_config = {}
|
||||
@@ -363,15 +511,18 @@ class GoogleGeminiBot(Bot):
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates:
|
||||
logger.warning("[Gemini] No candidates in response")
|
||||
prompt_feedback = data.get("promptFeedback", {})
|
||||
return {
|
||||
"error": True,
|
||||
"message": "No candidates in response",
|
||||
"status_code": 500
|
||||
"status_code": 500,
|
||||
"safety_ratings": prompt_feedback.get("safetyRatings", [])
|
||||
}
|
||||
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
safety_ratings = candidate.get("safetyRatings", [])
|
||||
|
||||
logger.debug(f"[Gemini] Candidate parts count: {len(parts)}")
|
||||
|
||||
@@ -419,7 +570,8 @@ class GoogleGeminiBot(Bot):
|
||||
"message": message_dict,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}],
|
||||
"usage": data.get("usageMetadata", {})
|
||||
"usage": data.get("usageMetadata", {}),
|
||||
"safety_ratings": safety_ratings
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user