mirror of
https://github.com/Zippland/Snap-Solver.git
synced 2026-01-19 01:21:13 +08:00
278 lines
11 KiB
Python
278 lines
11 KiB
Python
from typing import Dict, Type, Any, Optional
|
||
import json
|
||
import os
|
||
import importlib
|
||
from .base import BaseModel
|
||
from .mathpix import MathpixModel # MathpixModel需要直接导入,因为它是特殊OCR工具
|
||
from .baidu_ocr import BaiduOCRModel # 百度OCR也是特殊OCR工具,直接导入
|
||
|
||
class ModelFactory:
    """Registry and factory for the model backends known to the application.

    Model metadata lives in the class-level ``_models`` dict (model id ->
    info dict) and is populated by :meth:`initialize` (from the JSON config)
    and/or :meth:`register_model`. :meth:`create_model` instantiates the class
    stored for a model id with the keyword arguments that model family takes.
    """

    # Model id -> metadata: 'class', capability flags, display strings and,
    # for config-loaded models, 'provider_id'.
    _models: Dict[str, Dict[str, Any]] = {}
    # Provider id -> model class resolved from that provider's module.
    _class_map: Dict[str, Type[BaseModel]] = {}

    # (display name, description) of the OCR-only tools. They are registered
    # in code rather than in the config file, both on the normal path and on
    # the fallback path.
    _MATHPIX_LABELS = ('Mathpix OCR', '数学公式识别工具,适用于复杂数学内容')
    _BAIDU_OCR_LABELS = ('百度OCR', '通用文字识别工具,支持中文识别')

    @staticmethod
    def _ocr_tool_entry(model_class: Any, display_name: str, description: str) -> Dict[str, Any]:
        """Build the registry entry for an OCR-only tool."""
        return {
            'class': model_class,
            'is_multimodal': True,
            'is_reasoning': False,
            'display_name': display_name,
            'description': description,
            'is_ocr_only': True
        }

    @classmethod
    def initialize(cls):
        """Load provider classes and model metadata from ``config/models.json``.

        The config file is looked up relative to this module
        (``../config/models.json``). For every provider with a ``class_name``
        the module ``.<provider_id lowercased>`` of the current package is
        imported and the named class is stored in ``_class_map``. Every model
        whose ``provider`` resolved to a class is then added to ``_models``,
        followed by the two OCR tools.

        A provider whose module cannot be imported, or which lacks the
        configured class, is reported and skipped so the remaining providers
        still load. Any other failure (missing or malformed config, ...) is
        reported and handled by :meth:`_initialize_defaults`. This method does
        not raise.
        """
        try:
            config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'models.json')
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Resolve provider id -> model class.
            for provider_id, provider_info in config.get('providers', {}).items():
                class_name = provider_info.get('class_name')
                if not class_name:
                    continue
                try:
                    # Provider modules live in the current package and are
                    # named after the lower-cased provider id.
                    module = importlib.import_module(f'.{provider_id.lower()}', package=__package__)
                    cls._class_map[provider_id] = getattr(module, class_name)
                except (ImportError, AttributeError) as e:
                    # One broken provider must not hide every other model.
                    print(f"跳过提供商 {provider_id}: {str(e)}")

            # Register every model whose provider class could be resolved.
            for model_id, model_info in config.get('models', {}).items():
                provider_id = model_info.get('provider')
                if provider_id and provider_id in cls._class_map:
                    cls._models[model_id] = {
                        'class': cls._class_map[provider_id],
                        'provider_id': provider_id,
                        'is_multimodal': model_info.get('supportsMultimodal', False),
                        'is_reasoning': model_info.get('isReasoning', False),
                        'display_name': model_info.get('name', model_id),
                        'description': model_info.get('description', '')
                    }

            # OCR tools are not defined in the config file.
            cls._models['mathpix'] = cls._ocr_tool_entry(MathpixModel, *cls._MATHPIX_LABELS)
            cls._models['baidu-ocr'] = cls._ocr_tool_entry(BaiduOCRModel, *cls._BAIDU_OCR_LABELS)

            print(f"已从配置加载 {len(cls._models)} 个模型")
        except Exception as e:
            # Top-level boundary: start-up must survive a bad configuration.
            print(f"加载模型配置失败: {str(e)}")
            cls._initialize_defaults()

    @classmethod
    def _initialize_defaults(cls):
        """Reset the registry after a failed config load.

        ``_models`` is replaced by an empty dict (no models are hard-coded),
        then each OCR tool is imported and registered on a best-effort basis:
        a tool that fails to import is reported and skipped.
        """
        print("配置加载失败,使用空模型列表")

        cls._models = {}

        # Each tool is imported separately so one failure does not block the other.
        try:
            from .mathpix import MathpixModel

            cls._models['mathpix'] = cls._ocr_tool_entry(MathpixModel, *cls._MATHPIX_LABELS)
        except Exception as e:
            print(f"无法加载Mathpix OCR工具: {str(e)}")

        try:
            from .baidu_ocr import BaiduOCRModel

            cls._models['baidu-ocr'] = cls._ocr_tool_entry(BaiduOCRModel, *cls._BAIDU_OCR_LABELS)
        except Exception as e:
            print(f"无法加载百度OCR工具: {str(e)}")

    @staticmethod
    def _build_model_kwargs(model_name: str, provider_id: Optional[str], api_key: str,
                            temperature: float, system_prompt: Optional[str],
                            language: Optional[str], api_base_url: Optional[str]) -> Dict[str, Any]:
        """Return the constructor keyword arguments for ``model_name``.

        Every model receives ``api_key``, ``temperature`` and
        ``system_prompt``. The remaining arguments depend on the model family,
        which is decided by ``provider_id`` first and otherwise by substrings
        of the lower-cased model name. The branch order is significant: the
        first matching family wins.
        """
        kwargs: Dict[str, Any] = {
            'api_key': api_key,
            'temperature': temperature,
            'system_prompt': system_prompt,
        }
        lowered = model_name.lower()

        def mentions(*tags: str) -> bool:
            return any(tag in lowered for tag in tags)

        if provider_id == 'openai':
            kwargs.update(language=language, api_base_url=api_base_url, model_identifier=model_name)
        elif mentions('deepseek'):
            kwargs.update(language=language, model_name=model_name, api_base_url=api_base_url)
        elif mentions('qwen', 'qvq', 'alibaba'):
            # NOTE(review): the Alibaba family is not given api_base_url;
            # kept as-is - confirm whether that is intentional.
            kwargs.update(language=language, model_name=model_name)
        elif mentions('gemini', 'google', 'doubao'):
            kwargs.update(language=language, model_name=model_name, api_base_url=api_base_url)
        elif model_name in ('mathpix', 'baidu-ocr'):
            # OCR tools take only the common arguments (no language / base URL).
            pass
        elif mentions('claude', 'anthropic'):
            kwargs.update(language=language, api_base_url=api_base_url, model_identifier=model_name)
        else:
            # Any other model gets just the standard arguments.
            kwargs.update(language=language, api_base_url=api_base_url)
        return kwargs

    @classmethod
    def create_model(cls, model_name: str, api_key: str, temperature: float = 0.7,
                     system_prompt: Optional[str] = None, language: Optional[str] = None,
                     api_base_url: Optional[str] = None) -> BaseModel:
        """
        Create a model instance based on the model name.

        Args:
            model_name: The identifier for the model
            api_key: The API key for the model service
            temperature: The temperature to use for generation
            system_prompt: The system prompt to use
            language: The preferred language for responses
            api_base_url: The base URL for API requests

        Returns:
            A model instance

        Raises:
            ValueError: If ``model_name`` is not registered.
        """
        if model_name not in cls._models:
            raise ValueError(f"Unknown model: {model_name}")

        model_info = cls._models[model_name]
        kwargs = cls._build_model_kwargs(
            model_name, model_info.get('provider_id'), api_key,
            temperature, system_prompt, language, api_base_url,
        )
        return model_info['class'](**kwargs)

    @classmethod
    def get_available_models(cls) -> list[Dict[str, Any]]:
        """Return a list of available models with their information.

        OCR-only tools (entries flagged ``is_ocr_only``) are excluded.
        """
        return [
            {
                'id': model_id,
                'display_name': info.get('display_name', model_id),
                'description': info.get('description', ''),
                'is_multimodal': info.get('is_multimodal', False),
                'is_reasoning': info.get('is_reasoning', False)
            }
            for model_id, info in cls._models.items()
            if not info.get('is_ocr_only', False)
        ]

    @classmethod
    def get_model_ids(cls) -> list[str]:
        """Return a list of available model identifiers (OCR-only tools excluded)."""
        return [model_id for model_id, info in cls._models.items()
                if not info.get('is_ocr_only', False)]

    @classmethod
    def is_multimodal(cls, model_name: str) -> bool:
        """Return whether the model accepts multimodal input (False if unknown)."""
        return cls._models.get(model_name, {}).get('is_multimodal', False)

    @classmethod
    def is_reasoning(cls, model_name: str) -> bool:
        """Return whether the model is a reasoning model (False if unknown)."""
        return cls._models.get(model_name, {}).get('is_reasoning', False)

    @classmethod
    def get_model_display_name(cls, model_name: str) -> str:
        """Return the model's display name, or ``model_name`` itself if unknown."""
        return cls._models.get(model_name, {}).get('display_name', model_name)

    @classmethod
    def register_model(cls, model_name: str, model_class: Type[BaseModel],
                       is_multimodal: bool = False, is_reasoning: bool = False,
                       display_name: Optional[str] = None, description: Optional[str] = None) -> None:
        """
        Register a new model type with the factory.

        An existing entry with the same ``model_name`` is replaced.

        Args:
            model_name: The identifier for the model
            model_class: The model class to register
            is_multimodal: Whether the model supports image input
            is_reasoning: Whether the model provides reasoning process
            display_name: Human-readable name for the model
            description: Description of the model
        """
        cls._models[model_name] = {
            'class': model_class,
            'is_multimodal': is_multimodal,
            'is_reasoning': is_reasoning,
            'display_name': display_name or model_name,
            'description': description or ''
        }
|