mirror of
https://github.com/Zippland/Snap-Solver.git
synced 2026-02-06 16:02:01 +08:00
修复gemini接口,添加豆包接口
This commit is contained in:
@@ -4,6 +4,7 @@ from .openai import OpenAIModel
|
||||
from .deepseek import DeepSeekModel
|
||||
from .alibaba import AlibabaModel
|
||||
from .google import GoogleModel
|
||||
from .doubao import DoubaoModel
|
||||
from .factory import ModelFactory
|
||||
|
||||
__all__ = [
|
||||
@@ -13,5 +14,6 @@ __all__ = [
|
||||
'DeepSeekModel',
|
||||
'AlibabaModel',
|
||||
'GoogleModel',
|
||||
'DoubaoModel',
|
||||
'ModelFactory'
|
||||
]
|
||||
|
||||
@@ -4,12 +4,13 @@ from openai import OpenAI
|
||||
from .base import BaseModel
|
||||
|
||||
class AlibabaModel(BaseModel):
|
||||
def __init__(self, api_key: str, temperature: float = 0.7, system_prompt: str = None, language: str = None, model_name: str = None):
|
||||
def __init__(self, api_key: str, temperature: float = 0.7, system_prompt: str = None, language: str = None, model_name: str = None, api_base_url: str = None):
|
||||
# 如果没有提供模型名称,才使用默认值
|
||||
self.model_name = model_name if model_name else "QVQ-Max-2025-03-25"
|
||||
print(f"初始化阿里巴巴模型: {self.model_name}")
|
||||
# 在super().__init__之前设置model_name,这样get_default_system_prompt能使用它
|
||||
super().__init__(api_key, temperature, system_prompt, language)
|
||||
self.api_base_url = api_base_url # 存储API基础URL
|
||||
|
||||
def get_default_system_prompt(self) -> str:
|
||||
"""根据模型名称返回不同的默认系统提示词"""
|
||||
|
||||
@@ -6,9 +6,10 @@ from openai import OpenAI
|
||||
from .base import BaseModel
|
||||
|
||||
class DeepSeekModel(BaseModel):
|
||||
def __init__(self, api_key: str, temperature: float = 0.7, system_prompt: str = None, language: str = None, model_name: str = "deepseek-reasoner"):
|
||||
def __init__(self, api_key: str, temperature: float = 0.7, system_prompt: str = None, language: str = None, model_name: str = "deepseek-reasoner", api_base_url: str = None):
|
||||
super().__init__(api_key, temperature, system_prompt, language)
|
||||
self.model_name = model_name
|
||||
self.api_base_url = api_base_url # 存储API基础URL
|
||||
|
||||
def get_default_system_prompt(self) -> str:
|
||||
return """You are an expert at analyzing questions and providing detailed solutions. When presented with an image of a question:
|
||||
|
||||
312
models/doubao.py
Normal file
312
models/doubao.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from typing import Generator, Dict, Any, Optional
|
||||
import requests
|
||||
from .base import BaseModel
|
||||
|
||||
class DoubaoModel(BaseModel):
    """
    Doubao (ByteDance) API model implementation.

    Talks to ByteDance's Volcano Ark OpenAI-compatible
    ``/chat/completions`` endpoint and supports both plain-text and
    image (base64 data-URI) inputs with streamed responses.
    """

    def __init__(self, api_key: str, temperature: float = 0.7, system_prompt: str = None, language: str = None, model_name: str = None, api_base_url: str = None):
        """
        Initialize the Doubao model.

        Args:
            api_key: Doubao API key.
            temperature: Sampling temperature.
            system_prompt: System prompt override.
            language: Preferred answer language ('auto' or None disables forcing).
            model_name: Concrete model name; falls back to the default identifier.
            api_base_url: Custom API endpoint; defaults to the public Ark endpoint.
        """
        super().__init__(api_key, temperature, system_prompt, language)
        self.model_name = model_name or self.get_model_identifier()
        # Default to the public Volcano Ark endpoint when no custom URL is supplied.
        self.base_url = api_base_url or "https://ark.cn-beijing.volces.com/api/v3"
        self.max_tokens = 4096  # default cap on generated output tokens

    def get_default_system_prompt(self) -> str:
        """Return the default (Chinese) system prompt for question analysis."""
        return """你是一个专业的问题分析专家。当看到问题图片时:
1. 仔细阅读并理解问题
2. 分解问题的关键组成部分
3. 提供清晰的分步解决方案
4. 如果相关,解释涉及的概念或理论
5. 如果有多种方法,优先解释最有效的方法"""

    def get_model_identifier(self) -> str:
        """Return the default model identifier."""
        return "doubao-seed-1-6-250615"  # Doubao-Seed-1.6

    def get_actual_model_name(self) -> str:
        """Map the configured model name to the identifier the API expects."""
        # Mapping of configured names to actual Doubao API model names.
        model_mapping = {
            "doubao-seed-1-6-250615": "doubao-seed-1-6-250615"
        }
        # Unknown names fall back to the default identifier.
        return model_mapping.get(self.model_name, "doubao-seed-1-6-250615")

    def _apply_proxy_env(self, proxies: dict) -> dict:
        """
        Export proxy settings into the process environment.

        Returns the previous values of ``http_proxy``/``https_proxy`` so the
        caller can restore them afterwards.  NOTE(review): mutating
        ``os.environ`` is process-global and not thread-safe — kept for
        backward compatibility with the original behavior.
        """
        original = {
            'http_proxy': os.environ.get('http_proxy'),
            'https_proxy': os.environ.get('https_proxy')
        }
        if 'http' in proxies:
            os.environ['http_proxy'] = proxies['http']
        if 'https' in proxies:
            os.environ['https_proxy'] = proxies['https']
        return original

    @staticmethod
    def _restore_proxy_env(original: dict) -> None:
        """Restore proxy environment variables saved by ``_apply_proxy_env``."""
        for key, value in original.items():
            if value is None:
                # Variable was unset before — remove it again if present.
                if key in os.environ:
                    del os.environ[key]
            else:
                os.environ[key] = value

    def _stream_chat(self, messages: list, proxies: Optional[dict]) -> Generator[dict, None, None]:
        """
        POST a streaming chat-completion request and yield progress dicts.

        Yields ``{"status": "streaming", "content": <accumulated text>}`` for
        each content delta and a final
        ``{"status": "completed", "content": <full text>}``.

        Raises:
            Exception: on any non-200 HTTP response (message includes status
                and body text).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "model": self.get_actual_model_name(),
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": True
        }

        response = requests.post(
            f"{self.base_url}/chat/completions",
            headers=headers,
            json=data,
            stream=True,
            proxies=proxies if proxies else None,
            timeout=60
        )

        if response.status_code != 200:
            raise Exception(f"HTTP {response.status_code}: {response.text}")

        # Accumulate the full response so every yield carries the text so far.
        response_buffer = ""

        # Parse the SSE stream: each payload line is 'data: <json>'.
        for raw_line in response.iter_lines():
            if not raw_line:
                continue

            line = raw_line.decode('utf-8')
            if not line.startswith('data: '):
                continue

            payload = line[6:]  # strip the 'data: ' prefix
            if payload == '[DONE]':
                break

            try:
                chunk_data = json.loads(payload)
            except json.JSONDecodeError:
                # Skip malformed/partial chunks rather than aborting the stream.
                continue

            choices = chunk_data.get('choices', [])
            if choices:
                content = choices[0].get('delta', {}).get('content', '')
                if content:
                    response_buffer += content
                    yield {
                        "status": "streaming",
                        "content": response_buffer
                    }

        # Always emit the final accumulated content.
        yield {
            "status": "completed",
            "content": response_buffer
        }

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[dict, None, None]:
        """Stream a text-only completion for *text*; yields status dicts."""
        try:
            yield {"status": "started"}

            # Export proxies into the environment (restored in finally).
            original_proxies = self._apply_proxy_env(proxies) if proxies else None

            try:
                # Per the official API docs, no system prompt is sent for now.
                user_content = text
                if self.language and self.language != 'auto':
                    user_content = f"请使用{self.language}回答以下问题: {text}"

                messages = [{
                    "role": "user",
                    "content": user_content
                }]

                yield from self._stream_chat(messages, proxies)

            finally:
                if original_proxies:
                    self._restore_proxy_env(original_proxies)

        except Exception as e:
            yield {
                "status": "error",
                "error": f"豆包API错误: {str(e)}"
            }

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[dict, None, None]:
        """
        Analyze a base64 (or data-URI) encoded image and stream the response.

        Args:
            image_data: Base64-encoded image, optionally as a full data URI.
            proxies: Optional requests-style proxy mapping.
        """
        try:
            yield {"status": "started"}

            # Export proxies into the environment (restored in finally).
            original_proxies = self._apply_proxy_env(proxies) if proxies else None

            try:
                # Strip the data-URI header, keeping only the base64 payload.
                if image_data.startswith('data:image'):
                    image_data = image_data.split(',', 1)[1]

                # Sniff the image format from the base64 magic bytes.
                image_format = "jpeg"  # default
                if image_data.startswith('/9j/'):  # JPEG magic number in base64
                    image_format = "jpeg"
                elif image_data.startswith('iVBORw0KGgo'):  # PNG magic number in base64
                    image_format = "png"

                prompt = (
                    f"请使用{self.language}分析这张图片并提供详细解答。"
                    if self.language and self.language != 'auto'
                    else "请分析这张图片并提供详细解答?"
                )

                # Multimodal message following the official Doubao API example.
                user_content = [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{image_format};base64,{image_data}"
                        }
                    }
                ]

                messages = [{
                    "role": "user",
                    "content": user_content
                }]

                yield from self._stream_chat(messages, proxies)

            finally:
                if original_proxies:
                    self._restore_proxy_env(original_proxies)

        except Exception as e:
            yield {
                "status": "error",
                "error": f"豆包图像分析错误: {str(e)}"
            }
|
||||
@@ -114,6 +114,25 @@ class ModelFactory:
|
||||
)
|
||||
# 对于阿里巴巴模型,也需要传递正确的模型名称
|
||||
elif 'qwen' in model_name.lower() or 'qvq' in model_name.lower() or 'alibaba' in model_name.lower():
|
||||
return model_class(
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
system_prompt=system_prompt,
|
||||
language=language,
|
||||
model_name=model_name
|
||||
)
|
||||
# 对于Google模型,也需要传递正确的模型名称
|
||||
elif 'gemini' in model_name.lower() or 'google' in model_name.lower():
|
||||
return model_class(
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
system_prompt=system_prompt,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
api_base_url=api_base_url
|
||||
)
|
||||
# 对于豆包模型,也需要传递正确的模型名称
|
||||
elif 'doubao' in model_name.lower():
|
||||
return model_class(
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -46,7 +46,7 @@ class GoogleModel(BaseModel):
|
||||
|
||||
def get_model_identifier(self) -> str:
|
||||
"""返回默认的模型标识符"""
|
||||
return "gemini-2.5-pro-preview-03-25"
|
||||
return "gemini-2.0-flash" # 使用有免费配额的模型作为默认值
|
||||
|
||||
def analyze_text(self, text: str, proxies: dict = None) -> Generator[dict, None, None]:
|
||||
"""流式生成文本响应"""
|
||||
|
||||
Reference in New Issue
Block a user