Files
chatgpt-on-wechat/models/gemini/google_gemini_bot.py
SgtPepper114 05a33042c8 fix(gemini): support dingtalk image markers as multimodal input
- parse [图片: path] markers in text and convert to Gemini inlineData parts

- unify reply path via call_with_tools to reuse multimodal conversion

- keep legacy safety behavior (BLOCK_NONE) and restore safety ratings logging on empty response

- add multimodal request image-part count log for debugging
2026-02-16 13:26:57 +00:00

838 lines
35 KiB
Python

"""
Google gemini bot
@author zhayujie
@Date 2023/12/15
"""
# encoding:utf-8
import base64
import json
import mimetypes
import os
import re
import time
import requests
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from models.chatgpt.chat_gpt_session import ChatGPTSession
from models.baidu.baidu_wenxin_session import BaiduWenxinSession
# OpenAI对话模型API (可用)
class GoogleGeminiBot(Bot):
def __init__(self):
    """Set up the API key, session manager, model name and API base URL."""
    super().__init__()
    self.api_key = conf().get("gemini_api_key")
    # Session token accounting is borrowed from the ChatGPT implementation.
    self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
    chosen = conf().get("model") or "gemini-pro"
    # "gemini" is accepted as an alias for the default concrete model.
    self.model = "gemini-pro" if chosen == "gemini" else chosen
    # Optional custom API base (e.g. a proxy); trailing slashes are stripped
    # so endpoint paths can be joined safely.
    custom_base = conf().get("gemini_api_base", "").strip()
    if not custom_base:
        self.api_base = "https://generativelanguage.googleapis.com"
    else:
        self.api_base = custom_base.rstrip('/')
        logger.info(f"[Gemini] Using custom API base: {self.api_base}")
def reply(self, query, context: Context = None) -> Reply:
    """Generate a reply for a plain-text query via the Gemini REST API.

    Args:
        query: The user's text input; may contain [图片: path] markers that
            call_with_tools converts into multimodal inline image parts.
        context: Bridge Context; only ContextType.TEXT is handled.

    Returns:
        Reply: a TEXT reply on success, otherwise an ERROR reply.
    """
    session_id = None
    try:
        if context.type != ContextType.TEXT:
            # Fix: logger.warn is a deprecated alias of logger.warning.
            logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
            return Reply(ReplyType.TEXT, None)
        logger.info(f"[Gemini] query={query}")
        session_id = context["session_id"]
        session = self.sessions.session_query(query, session_id)
        filtered_messages = self.filter_messages(session.messages)
        logger.debug(f"[Gemini] messages={filtered_messages}")
        # Route through the REST tool-call path so [图片:] markers are
        # converted to inline image parts even for plain replies.
        response = self.call_with_tools(
            messages=filtered_messages,
            tools=None,
            stream=False,
            model=self.model
        )
        if isinstance(response, dict) and response.get("error"):
            error_message = response.get("message", "Failed to invoke [Gemini] api!")
            logger.error(f"[Gemini] API error: {error_message}")
            self.sessions.session_reply(error_message, session_id)
            return Reply(ReplyType.ERROR, error_message)
        choices = response.get("choices", []) if isinstance(response, dict) else []
        if choices and choices[0].get("message"):
            reply_text = choices[0]["message"].get("content")
            if reply_text:
                logger.info(f"[Gemini] reply={reply_text}")
                self.sessions.session_reply(reply_text, session_id)
                return Reply(ReplyType.TEXT, reply_text)
        # No usable text came back; surface any safety ratings for debugging.
        logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
        safety_ratings = response.get("safety_ratings", []) if isinstance(response, dict) else []
        if safety_ratings:
            for rating in safety_ratings:
                category = rating.get("category", "UNKNOWN")
                probability = rating.get("probability", "UNKNOWN")
                logger.warning(f"[Gemini] Safety rating: {category} - {probability}")
        error_message = "No valid response generated due to safety constraints."
        self.sessions.session_reply(error_message, session_id)
        return Reply(ReplyType.ERROR, error_message)
    except Exception as e:
        logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
        error_message = "Failed to invoke [Gemini] api!"
        # session_id may still be None if the failure happened before lookup.
        if session_id:
            self.sessions.session_reply(error_message, session_id)
        return Reply(ReplyType.ERROR, error_message)
def _convert_to_gemini_messages(self, messages: list):
res = []
for msg in messages:
if msg.get("role") == "user":
role = "user"
elif msg.get("role") == "assistant":
role = "model"
elif msg.get("role") == "system":
role = "user"
else:
continue
res.append({
"role": role,
"parts": [{"text": msg.get("content")}]
})
return res
@staticmethod
def filter_messages(messages: list):
res = []
turn = "user"
if not messages:
return res
for i in range(len(messages) - 1, -1, -1):
message = messages[i]
role = message.get("role")
if role == "system":
res.insert(0, message)
continue
if role != turn:
continue
res.insert(0, message)
if turn == "user":
turn = "assistant"
elif turn == "assistant":
turn = "user"
return res
@staticmethod
def _extract_image_paths_from_text(content: str):
if not isinstance(content, str):
return "", []
pattern = r"\[图片:\s*([^\]]+)\]"
image_paths = [m.strip().strip("'\"") for m in re.findall(pattern, content) if m.strip()]
cleaned_text = re.sub(pattern, "", content)
cleaned_text = re.sub(r"\n{3,}", "\n\n", cleaned_text).strip()
return cleaned_text, image_paths
@staticmethod
def _build_image_inline_part(image_path: str):
if not image_path:
return None
try:
if image_path.startswith("file://"):
image_path = image_path[7:]
image_path = os.path.expanduser(image_path)
if not os.path.exists(image_path):
logger.warning(f"[Gemini] Image file not found: {image_path}")
return None
with open(image_path, "rb") as f:
image_bytes = f.read()
mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
if not mime_type.startswith("image/"):
mime_type = "image/png"
return {
"inlineData": {
"mimeType": mime_type,
"data": base64.b64encode(image_bytes).decode("utf-8")
}
}
except Exception as e:
logger.warning(f"[Gemini] Failed to build inline image part from path={image_path}, err={e}")
return None
@staticmethod
def _build_inline_part_from_image_url(image_url):
if not image_url:
return None
if isinstance(image_url, dict):
image_url = image_url.get("url")
if not image_url or not isinstance(image_url, str):
return None
if image_url.startswith("data:"):
match = re.match(r"^data:([^;]+);base64,(.+)$", image_url, re.DOTALL)
if not match:
logger.warning("[Gemini] Invalid data URL for image block")
return None
return {
"inlineData": {
"mimeType": match.group(1),
"data": match.group(2).strip()
}
}
if image_url.startswith("file://") or os.path.exists(os.path.expanduser(image_url)):
return GoogleGeminiBot._build_image_inline_part(image_url)
if image_url.startswith("http://") or image_url.startswith("https://"):
try:
response = requests.get(image_url, timeout=20)
if response.status_code != 200:
logger.warning(f"[Gemini] Failed to fetch remote image: status={response.status_code}, url={image_url}")
return None
mime_type = response.headers.get("Content-Type", "image/png").split(";")[0].strip()
if not mime_type.startswith("image/"):
mime_type = "image/png"
return {
"inlineData": {
"mimeType": mime_type,
"data": base64.b64encode(response.content).decode("utf-8")
}
}
except Exception as e:
logger.warning(f"[Gemini] Failed to download remote image: url={image_url}, err={e}")
return None
logger.warning(f"[Gemini] Unsupported image URL format: {image_url[:120]}")
return None
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
    """
    Call Gemini API with tool support using REST API (following official docs)

    Args:
        messages: List of messages (OpenAI format); string contents may embed
            [图片: path] markers which are converted to inline image parts.
        tools: List of tool definitions (OpenAI/Claude format)
        stream: Whether to use streaming
        **kwargs: Additional parameters (system, model, temperature, etc.)
    Returns:
        Formatted response compatible with OpenAI format, or a generator of
        chunks for streaming; on failure a dict (or a generator yielding a
        dict) with "error": True.
    """
    try:
        model_name = kwargs.get("model", self.model or "gemini-1.5-flash")
        # Build REST API payload
        payload = {"contents": []}
        inline_image_count = 0
        # Keep legacy behavior: disable Gemini safety blocking like old SDK path.
        payload["safetySettings"] = [
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        # Extract and set system instruction; an explicit kwarg takes
        # precedence over a system-role message found in the history.
        system_prompt = kwargs.get("system", "")
        if not system_prompt:
            for msg in messages:
                if msg.get("role") == "system":
                    system_prompt = msg["content"]
                    break
        if system_prompt:
            payload["system_instruction"] = {
                "parts": [{"text": system_prompt}]
            }
        # Convert messages to Gemini format
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")
            if role == "system":
                # Already lifted into system_instruction above.
                continue
            # Gemini only knows "user"/"model"; tool results go back as user.
            gemini_role = "user" if role in ["user", "tool"] else "model"
            parts = []
            if isinstance(content, str):
                # Text with optional [图片: /path/to/file] markers
                cleaned_text, image_paths = self._extract_image_paths_from_text(content)
                if cleaned_text:
                    parts.append({"text": cleaned_text})
                image_added = False
                for image_path in image_paths:
                    image_part = self._build_image_inline_part(image_path)
                    if image_part:
                        parts.append(image_part)
                        image_added = True
                        inline_image_count += 1
                # Marker parsing stripped everything but produced no image:
                # fall back to the raw text so the turn isn't lost.
                if not cleaned_text and not image_added and content:
                    parts.append({"text": content})
            elif isinstance(content, list):
                # List of content blocks (Claude format)
                for block in content:
                    if not isinstance(block, dict):
                        if isinstance(block, str):
                            parts.append({"text": block})
                        continue
                    block_type = block.get("type")
                    if block_type == "text":
                        # Text block with optional image markers
                        block_text = block.get("text", "")
                        cleaned_text, image_paths = self._extract_image_paths_from_text(block_text)
                        if cleaned_text:
                            parts.append({"text": cleaned_text})
                        for image_path in image_paths:
                            image_part = self._build_image_inline_part(image_path)
                            if image_part:
                                parts.append(image_part)
                                # Fix: count marker-derived images here too so
                                # the multimodal debug log below is accurate.
                                inline_image_count += 1
                    elif block_type in ["image", "image_url"]:
                        # OpenAI format: {"type":"image_url","image_url":{"url":"..."}}
                        # Claude format: {"type":"image","source":{"type":"base64","media_type":"...","data":"..."}}
                        image_part = None
                        if block_type == "image":
                            source = block.get("source", {})
                            if isinstance(source, dict) and source.get("type") == "base64" and source.get("data"):
                                image_part = {
                                    "inlineData": {
                                        "mimeType": source.get("media_type", "image/png"),
                                        "data": source.get("data")
                                    }
                                }
                        else:
                            # image_url block (the two original branches were
                            # identical, collapsed into one).
                            image_part = self._build_inline_part_from_image_url(block.get("image_url"))
                        if image_part:
                            parts.append(image_part)
                            inline_image_count += 1
                        else:
                            logger.warning(f"[Gemini] Skip invalid image block: {str(block)[:200]}")
                    elif block_type == "tool_result":
                        # Convert Claude tool_result to Gemini functionResponse
                        tool_use_id = block.get("tool_use_id")
                        tool_content = block.get("content", "")
                        # Tool output is ideally JSON; wrap raw text otherwise.
                        try:
                            if isinstance(tool_content, str):
                                tool_result_data = json.loads(tool_content)
                            else:
                                tool_result_data = tool_content
                        except Exception:
                            # Fix: was a bare except (would swallow SystemExit etc.)
                            tool_result_data = {"result": tool_content}
                        # Gemini needs the function *name*; recover it from the
                        # matching tool_use block in a prior assistant message.
                        tool_name = None
                        for prev_msg in reversed(messages):
                            if prev_msg.get("role") == "assistant":
                                prev_content = prev_msg.get("content", [])
                                if isinstance(prev_content, list):
                                    for prev_block in prev_content:
                                        if isinstance(prev_block, dict) and prev_block.get("type") == "tool_use":
                                            if prev_block.get("id") == tool_use_id:
                                                tool_name = prev_block.get("name")
                                                break
                                if tool_name:
                                    break
                        # Gemini functionResponse format
                        parts.append({
                            "functionResponse": {
                                "name": tool_name or "unknown",
                                "response": tool_result_data
                            }
                        })
                    elif "text" in block:
                        # Generic text field without an explicit type
                        parts.append({"text": block["text"]})
            if parts:
                payload["contents"].append({
                    "role": gemini_role,
                    "parts": parts
                })
        if inline_image_count > 0:
            logger.info(f"[Gemini] Multimodal request includes {inline_image_count} image part(s)")
        # Generation config
        gen_config = {}
        if kwargs.get("temperature") is not None:
            gen_config["temperature"] = kwargs["temperature"]
        if gen_config:
            payload["generationConfig"] = gen_config
        # Convert tools to Gemini format (REST API style)
        if tools:
            gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
        # Make REST API call
        base_url = f"{self.api_base}/v1beta"
        endpoint = f"{base_url}/models/{model_name}:generateContent"
        if stream:
            endpoint = f"{base_url}/models/{model_name}:streamGenerateContent?alt=sse"
        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json"
        }
        response = requests.post(
            endpoint,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=60
        )
        # Check HTTP status for stream mode (for non-stream, it's checked in handler)
        if stream and response.status_code != 200:
            error_text = response.text
            logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")

            def error_generator():
                yield {
                    "error": True,
                    "message": f"Gemini API error: {error_text}",
                    "status_code": response.status_code
                }
            return error_generator()
        if stream:
            return self._handle_gemini_rest_stream_response(response, model_name)
        else:
            return self._handle_gemini_rest_sync_response(response, model_name)
    except Exception as e:
        logger.error(f"[Gemini] call_with_tools error: {e}", exc_info=True)
        error_msg = str(e)  # Capture error message before creating generator
        if stream:
            def error_generator():
                yield {
                    "error": True,
                    "message": error_msg,
                    "status_code": 500
                }
            return error_generator()
        else:
            return {
                "error": True,
                "message": str(e),
                "status_code": 500
            }
def _convert_tools_to_gemini_rest_format(self, tools_list):
"""
Convert tools to Gemini REST API format
Handles both OpenAI and Claude/Agent formats.
Returns: [{"functionDeclarations": [...]}]
"""
function_declarations = []
for tool in tools_list:
# Extract name, description, and parameters based on format
if tool.get("type") == "function":
# OpenAI format: {"type": "function", "function": {...}}
func = tool.get("function", {})
name = func.get("name")
description = func.get("description", "")
parameters = func.get("parameters", {})
else:
# Claude/Agent format: {"name": "...", "description": "...", "input_schema": {...}}
name = tool.get("name")
description = tool.get("description", "")
parameters = tool.get("input_schema", {})
if not name:
logger.warning(f"[Gemini] Skipping tool without name: {tool}")
continue
function_declarations.append({
"name": name,
"description": description,
"parameters": parameters
})
# All functionDeclarations must be in a single tools object (per Gemini REST API spec)
return [{
"functionDeclarations": function_declarations
}] if function_declarations else []
def _handle_gemini_rest_sync_response(self, response, model_name):
    """Handle Gemini REST API sync response and convert to OpenAI format.

    Args:
        response: HTTP response object from the :generateContent endpoint
            (presumably a requests.Response — confirm at call site).
        model_name: Model id echoed back in the formatted reply.

    Returns:
        An OpenAI chat.completion-style dict on success; on any failure a
        dict with "error": True, "message" and "status_code".
    """
    try:
        if response.status_code != 200:
            error_text = response.text
            logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")
            return {
                "error": True,
                "message": f"Gemini API error: {error_text}",
                "status_code": response.status_code
            }
        data = response.json()
        logger.debug(f"[Gemini] Response data: {json.dumps(data, ensure_ascii=False)[:500]}")
        # Extract from Gemini response format
        candidates = data.get("candidates", [])
        if not candidates:
            # No candidates usually means the prompt itself was blocked;
            # surface promptFeedback safety ratings to the caller.
            logger.warning("[Gemini] No candidates in response")
            prompt_feedback = data.get("promptFeedback", {})
            return {
                "error": True,
                "message": "No candidates in response",
                "status_code": 500,
                "safety_ratings": prompt_feedback.get("safetyRatings", [])
            }
        # Only the first candidate is consumed.
        candidate = candidates[0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])
        safety_ratings = candidate.get("safetyRatings", [])
        logger.debug(f"[Gemini] Candidate parts count: {len(parts)}")
        # Extract text and function calls; a single part may carry both keys,
        # so the two checks are deliberately not exclusive.
        text_content = ""
        tool_calls = []
        for part in parts:
            # Check for text
            if "text" in part:
                text_content += part["text"]
                logger.debug(f"[Gemini] Text part: {part['text'][:100]}...")
            # Check for functionCall (per REST API docs)
            if "functionCall" in part:
                fc = part["functionCall"]
                logger.info(f"[Gemini] Function call detected: {fc.get('name')}")
                tool_calls.append({
                    # Microsecond timestamp as a pseudo-unique call id.
                    "id": f"call_{int(time.time() * 1000000)}",
                    "type": "function",
                    "function": {
                        "name": fc.get("name"),
                        "arguments": json.dumps(fc.get("args", {}))
                    }
                })
        logger.info(f"[Gemini] Response: text={len(text_content)} chars, tool_calls={len(tool_calls)}")
        # Build OpenAI format response
        message_dict = {
            "role": "assistant",
            # Empty text becomes None, matching OpenAI tool-call messages.
            "content": text_content or None
        }
        if tool_calls:
            message_dict["tool_calls"] = tool_calls
        return {
            "id": f"chatcmpl-{time.time()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": message_dict,
                "finish_reason": "tool_calls" if tool_calls else "stop"
            }],
            # Gemini's usageMetadata is passed through as-is.
            "usage": data.get("usageMetadata", {}),
            "safety_ratings": safety_ratings
        }
    except Exception as e:
        logger.error(f"[Gemini] sync response error: {e}", exc_info=True)
        return {
            "error": True,
            "message": str(e),
            "status_code": 500
        }
def _handle_gemini_rest_stream_response(self, response, model_name):
    """Handle Gemini REST API stream response.

    Iterates SSE lines from the streaming endpoint and yields OpenAI
    chat.completion.chunk-style dicts: text deltas as they arrive, one
    aggregated tool_calls chunk after the stream ends, then a final chunk
    carrying the finish_reason. On error yields a single error dict.
    """
    try:
        all_tool_calls = []
        has_sent_tool_calls = False
        has_content = False  # Track if any content was sent
        chunk_count = 0
        last_finish_reason = None
        last_safety_ratings = None
        for line in response.iter_lines():
            if not line:
                continue
            line = line.decode('utf-8')
            # Skip SSE prefixes
            if line.startswith('data: '):
                line = line[6:]
            if not line or line == '[DONE]':
                continue
            try:
                chunk_data = json.loads(line)
                chunk_count += 1
                candidates = chunk_data.get("candidates", [])
                if not candidates:
                    logger.debug("[Gemini] No candidates in chunk")
                    continue
                candidate = candidates[0]
                # Record finish_reason and safety_ratings (last chunk wins)
                if "finishReason" in candidate:
                    last_finish_reason = candidate["finishReason"]
                if "safetyRatings" in candidate:
                    last_safety_ratings = candidate["safetyRatings"]
                content = candidate.get("content", {})
                parts = content.get("parts", [])
                if not parts:
                    logger.debug("[Gemini] No parts in candidate content")
                # Stream text content
                for part in parts:
                    if "text" in part and part["text"]:
                        has_content = True
                        yield {
                            "id": f"chatcmpl-{time.time()}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model_name,
                            "choices": [{
                                "index": 0,
                                "delta": {"content": part["text"]},
                                "finish_reason": None
                            }]
                        }
                    # Collect function calls (emitted together after the loop)
                    if "functionCall" in part:
                        fc = part["functionCall"]
                        logger.info(f"[Gemini] Function call: {fc.get('name')}")
                        all_tool_calls.append({
                            "index": len(all_tool_calls),  # Add index to differentiate multiple tool calls
                            "id": f"call_{int(time.time() * 1000000)}_{len(all_tool_calls)}",
                            "type": "function",
                            "function": {
                                "name": fc.get("name"),
                                "arguments": json.dumps(fc.get("args", {}))
                            }
                        })
            except json.JSONDecodeError as je:
                # Non-JSON SSE noise is skipped, not fatal.
                logger.debug(f"[Gemini] JSON decode error: {je}")
                continue
        # Send tool calls if any were collected
        if all_tool_calls and not has_sent_tool_calls:
            yield {
                "id": f"chatcmpl-{time.time()}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {"tool_calls": all_tool_calls},
                    "finish_reason": None
                }]
            }
            has_sent_tool_calls = True
        # Log a detailed warning when the stream produced nothing at all
        if not has_content and not all_tool_calls:
            logger.warning(f"[Gemini] ⚠️ Empty response detected!")
        # Final chunk
        yield {
            "id": f"chatcmpl-{time.time()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "tool_calls" if all_tool_calls else "stop"
            }]
        }
    except Exception as e:
        logger.error(f"[Gemini] stream response error: {e}", exc_info=True)
        error_msg = str(e)
        yield {
            "error": True,
            "message": error_msg,
            "status_code": 500
        }
def _convert_tools_to_gemini_format(self, openai_tools):
    """Convert OpenAI tool definitions into legacy-SDK Tool protos.

    Only {"type": "function"} entries are converted. Returns a
    single-element list of genai.protos.Tool, or None when no function
    tools are present. Requires the google-generativeai package.
    """
    import google.generativeai as genai
    declarations = []
    for tool in openai_tools:
        if tool.get("type") != "function":
            continue
        spec = tool.get("function", {})
        declarations.append(
            genai.protos.FunctionDeclaration(
                name=spec.get("name"),
                description=spec.get("description", ""),
                parameters=spec.get("parameters", {})
            )
        )
    if not declarations:
        return None
    return [genai.protos.Tool(function_declarations=declarations)]
def _handle_gemini_sync_response(self, model, messages, request_params, model_name):
    """Run a blocking legacy-SDK generate_content call and adapt the result.

    Collects text parts and function calls from the first candidate and
    packages them as an OpenAI-style chat.completion dict. Token usage is
    reported as zeros because the legacy SDK path does not expose counts.
    """
    import json
    result = model.generate_content(messages, **request_params)
    collected_text = ""
    calls = []
    candidate_content = result.candidates[0].content if result.candidates else None
    if candidate_content:
        for part in candidate_content.parts:
            if hasattr(part, 'text') and part.text:
                collected_text += part.text
            elif hasattr(part, 'function_call') and part.function_call:
                # Convert Gemini function call to OpenAI tool_call format.
                fc = part.function_call
                calls.append({
                    "id": f"call_{hash(fc.name)}",
                    "type": "function",
                    "function": {
                        "name": fc.name,
                        "arguments": json.dumps(dict(fc.args))
                    }
                })
    message = {
        "role": "assistant",
        "content": collected_text
    }
    if calls:
        message["tool_calls"] = calls
    formatted = {
        "id": f"gemini_{int(time.time())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": message,
                "finish_reason": "stop" if not calls else "tool_calls"
            }
        ],
        "usage": {
            # Gemini doesn't provide token counts in the same way.
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
    }
    logger.info(f"[Gemini] call_with_tools reply, model={model_name}")
    return formatted
def _handle_gemini_stream_response(self, model, messages, request_params, model_name):
"""Handle streaming Gemini API response"""
import json
try:
response_stream = model.generate_content(messages, stream=True, **request_params)
for chunk in response_stream:
if chunk.candidates and chunk.candidates[0].content:
for part in chunk.candidates[0].content.parts:
if hasattr(part, 'text') and part.text:
# Text content
yield {
"id": f"gemini_{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": part.text},
"finish_reason": None
}]
}
elif hasattr(part, 'function_call') and part.function_call:
# Function call
func_call = part.function_call
yield {
"id": f"gemini_{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {
"tool_calls": [{
"index": 0,
"id": f"call_{hash(func_call.name)}",
"type": "function",
"function": {
"name": func_call.name,
"arguments": json.dumps(dict(func_call.args))
}
}]
},
"finish_reason": None
}]
}
except Exception as e:
logger.error(f"[Gemini] stream response error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}