mirror of
https://github.com/Zippland/Snap-Solver.git
synced 2026-02-20 01:09:46 +08:00
添加百度OCR支持,更新OCR源选择和设置界面
This commit is contained in:
177
models/baidu_ocr.py
Normal file
177
models/baidu_ocr.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
from typing import Generator, Dict, Any
|
||||
from .base import BaseModel
|
||||
|
||||
class BaiduOCRModel(BaseModel):
    """
    Baidu OCR model used to recognise text in images.

    The model exchanges an API key / secret key pair for an OAuth access
    token, caches that token on the instance, and posts Base64 image data to
    Baidu's ``accurate_basic`` OCR endpoint.
    """

    # Seconds to wait on Baidu's HTTP endpoints. Without a timeout a stalled
    # connection would block the caller forever.
    REQUEST_TIMEOUT = 30

    # OCR error codes that Baidu documents as "access token invalid/expired"
    # (see the Baidu AI open platform error-code table). When one of them is
    # returned the cached token is dropped so the next call fetches a new one.
    _TOKEN_ERROR_CODES = (110, 111)

    # Prefix used for user-facing OCR failure messages.
    _OCR_ERROR_PREFIX = 'OCR识别失败'

    def __init__(self, api_key: str, secret_key: str = None, temperature: float = 0.7, system_prompt: str = None):
        """
        Initialise the Baidu OCR model.

        Args:
            api_key: Baidu API key, or ``"API_KEY:SECRET_KEY"`` when
                ``secret_key`` is not given.
            secret_key: Baidu secret key. Optional when it is embedded in
                ``api_key`` after a colon.
            temperature: Unused by OCR; accepted for BaseModel compatibility.
            system_prompt: Unused by OCR; accepted for BaseModel compatibility.

        Raises:
            ValueError: If no ``secret_key`` is given and ``api_key`` is not
                exactly two non-empty parts separated by a single colon.
        """
        super().__init__(api_key, temperature, system_prompt)

        # Two supported formats: separate arguments, or one colon-joined string.
        if secret_key:
            self.api_key = api_key
            self.secret_key = secret_key
        else:
            parts = api_key.split(':') if isinstance(api_key, str) else []
            parts = [part.strip() for part in parts]
            # Reject non-strings, a missing colon, extra colons and empty
            # halves with the same ValueError instead of an AttributeError or
            # a later, harder to understand, authentication failure.
            if len(parts) != 2 or not all(parts):
                raise ValueError("百度OCR API密钥必须是 'API_KEY:SECRET_KEY' 格式或单独传递secret_key参数")
            self.api_key, self.secret_key = parts

        # Baidu API URLs
        self.token_url = "https://aip.baidubce.com/oauth/2.0/token"
        self.ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"

        # Cached access token and the epoch time at which it expires.
        self._access_token = None
        self._token_expires = 0

    def _baidu_post_form(self, url: str, params: Dict[str, Any]) -> Any:
        """
        POST ``params`` as a URL-encoded form and return the decoded JSON body.

        Args:
            url: Target URL.
            params: Form fields to send.

        Returns:
            The parsed JSON response.

        Raises:
            Exception: Whatever ``urllib`` or ``json`` raise on transport or
                decoding failure; callers wrap these with their own message.
        """
        data = urllib.parse.urlencode(params).encode('utf-8')
        request = urllib.request.Request(url, data=data)
        request.add_header('Content-Type', 'application/x-www-form-urlencoded')

        with urllib.request.urlopen(request, timeout=self.REQUEST_TIMEOUT) as response:
            return json.loads(response.read().decode('utf-8'))

    def get_access_token(self) -> str:
        """
        Return a Baidu API access token, requesting a new one when needed.

        A cached token is reused until five minutes before its expiry.

        Returns:
            str: The access token.

        Raises:
            Exception: If the token request fails or the response carries no
                ``access_token``.
        """
        # Refresh five minutes early so a token never expires mid-request.
        if self._access_token and time.time() < self._token_expires - 300:
            return self._access_token

        params = {
            'grant_type': 'client_credentials',
            'client_id': self.api_key,
            'client_secret': self.secret_key
        }

        # Only the network/decoding step is inside the try, so the API-level
        # error raised below is not re-wrapped with a second prefix.
        try:
            result = self._baidu_post_form(self.token_url, params)
        except Exception as e:
            raise Exception(f"请求access_token失败: {str(e)}") from e

        if not isinstance(result, dict) or 'access_token' not in result:
            detail = result.get('error_description', '未知错误') if isinstance(result, dict) else '未知错误'
            raise Exception(f"获取access_token失败: {detail}")

        self._access_token = result['access_token']
        # Falls back to 30 days (in seconds) when the response omits expires_in.
        self._token_expires = time.time() + result.get('expires_in', 2592000)
        return self._access_token

    def ocr_image(self, image_data: str) -> str:
        """
        Run OCR on an image.

        Args:
            image_data: Base64-encoded image data.

        Returns:
            str: The recognised text, one recognised line per output line.

        Raises:
            Exception: If the access token cannot be obtained, the request
                fails, or the API reports an error.
        """
        access_token = self.get_access_token()

        params = {
            'image': image_data,
            'language_type': 'auto_detect',  # auto-detect the language
            'detect_direction': 'true',      # detect image orientation
            'probability': 'false'           # omit confidences to shrink the response
        }
        url = f"{self.ocr_url}?access_token={access_token}"

        try:
            result = self._baidu_post_form(url, params)
        except Exception as e:
            raise Exception(f"{self._OCR_ERROR_PREFIX}: {str(e)}") from e

        if not isinstance(result, dict):
            raise Exception("百度OCR API错误: 未知错误")

        if 'error_code' in result:
            if result['error_code'] in self._TOKEN_ERROR_CODES:
                # The server rejected the cached token; forget it so the next
                # call requests a fresh one instead of failing until expiry.
                self._access_token = None
                self._token_expires = 0
            raise Exception(f"百度OCR API错误: {result.get('error_msg', '未知错误')}")

        # Collect the text of every recognised line, tolerating a missing or
        # null words_result and entries without a 'words' field.
        words_result = result.get('words_result') or []
        text_lines = [item['words'] for item in words_result if 'words' in item]

        return '\n'.join(text_lines)

    def extract_full_text(self, image_data: str) -> str:
        """
        Extract the full text of an image (Mathpix-compatible interface).

        Args:
            image_data: Base64-encoded image data.

        Returns:
            str: The extracted text.
        """
        return self.ocr_image(image_data)

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Run OCR on an image and yield the result as a single response dict.

        Implemented as a generator to match the streaming interface of the
        other models.

        Args:
            image_data: Base64-encoded image data.
            proxies: Proxy configuration (unused).

        Yields:
            dict: ``status`` is ``'completed'`` with the text in ``content``,
            or ``'error'`` with a failure message in ``content``.
        """
        try:
            text = self.ocr_image(image_data)
        except Exception as e:
            message = str(e)
            # ocr_image already prefixes transport failures; avoid emitting
            # the prefix twice in the user-facing message.
            if not message.startswith(self._OCR_ERROR_PREFIX):
                message = f'OCR识别失败: {message}'
            yield {
                'status': 'error',
                'content': message,
                'model': 'baidu-ocr'
            }
            return

        yield {
            'status': 'completed',
            'content': text,
            'model': 'baidu-ocr'
        }

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """
        Reject text analysis, which an OCR model does not support.

        Args:
            text: Input text (ignored).
            proxies: Proxy configuration (unused).

        Yields:
            dict: A single error response.
        """
        yield {
            'status': 'error',
            'content': 'OCR模型不支持文本分析功能',
            'model': 'baidu-ocr'
        }

    def get_model_identifier(self) -> str:
        """Return the model identifier."""
        return "baidu-ocr"
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
import os
|
||||
import importlib
|
||||
from .base import BaseModel
|
||||
from .mathpix import MathpixModel # MathpixModel仍然需要直接导入,因为它是特殊工具
|
||||
from .mathpix import MathpixModel # MathpixModel需要直接导入,因为它是特殊OCR工具
|
||||
from .baidu_ocr import BaiduOCRModel # 百度OCR也是特殊OCR工具,直接导入
|
||||
|
||||
class ModelFactory:
|
||||
# 模型基本信息,包含类型和特性
|
||||
@@ -39,13 +40,25 @@ class ModelFactory:
|
||||
'description': model_info.get('description', '')
|
||||
}
|
||||
|
||||
# 添加Mathpix模型(特殊工具模型)
|
||||
# 添加特殊OCR工具模型(不在配置文件中定义)
|
||||
|
||||
# 添加Mathpix OCR工具
|
||||
cls._models['mathpix'] = {
|
||||
'class': MathpixModel,
|
||||
'is_multimodal': True,
|
||||
'is_reasoning': False,
|
||||
'display_name': 'Mathpix OCR',
|
||||
'description': '文本提取工具,适用于数学公式和文本',
|
||||
'description': '数学公式识别工具,适用于复杂数学内容',
|
||||
'is_ocr_only': True
|
||||
}
|
||||
|
||||
# 添加百度OCR工具
|
||||
cls._models['baidu-ocr'] = {
|
||||
'class': BaiduOCRModel,
|
||||
'is_multimodal': True,
|
||||
'is_reasoning': False,
|
||||
'display_name': '百度OCR',
|
||||
'description': '通用文字识别工具,支持中文识别',
|
||||
'is_ocr_only': True
|
||||
}
|
||||
|
||||
@@ -62,22 +75,36 @@ class ModelFactory:
|
||||
# 不再硬编码模型定义,而是使用空字典
|
||||
cls._models = {}
|
||||
|
||||
# 只保留Mathpix作为基础工具
|
||||
# 添加特殊OCR工具(当配置加载失败时的备用)
|
||||
try:
|
||||
# 导入MathpixModel类
|
||||
# 导入并添加Mathpix OCR工具
|
||||
from .mathpix import MathpixModel
|
||||
|
||||
# 添加Mathpix作为基础工具
|
||||
cls._models['mathpix'] = {
|
||||
'class': MathpixModel,
|
||||
'is_multimodal': True,
|
||||
'is_reasoning': False,
|
||||
'display_name': 'Mathpix OCR',
|
||||
'description': '文本提取工具,适用于数学公式和文本',
|
||||
'description': '数学公式识别工具,适用于复杂数学内容',
|
||||
'is_ocr_only': True
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"无法加载基础Mathpix工具: {str(e)}")
|
||||
print(f"无法加载Mathpix OCR工具: {str(e)}")
|
||||
|
||||
# 添加百度OCR工具
|
||||
try:
|
||||
from .baidu_ocr import BaiduOCRModel
|
||||
|
||||
cls._models['baidu-ocr'] = {
|
||||
'class': BaiduOCRModel,
|
||||
'is_multimodal': True,
|
||||
'is_reasoning': False,
|
||||
'display_name': '百度OCR',
|
||||
'description': '通用文字识别工具,支持中文识别',
|
||||
'is_ocr_only': True
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"无法加载百度OCR工具: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def create_model(cls, model_name: str, api_key: str, temperature: float = 0.7,
|
||||
@@ -148,6 +175,13 @@ class ModelFactory:
|
||||
temperature=temperature,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
# 对于百度OCR模型,传递api_key(支持API_KEY:SECRET_KEY格式)
|
||||
elif model_name == 'baidu-ocr':
|
||||
return model_class(
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
# 对于Anthropic模型,需要传递model_identifier参数
|
||||
elif 'claude' in model_name.lower() or 'anthropic' in model_name.lower():
|
||||
return model_class(
|
||||
|
||||
Reference in New Issue
Block a user