# Files
# Snap-Solver/models/baidu_ocr.py
#
# 178 lines
# 6.1 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import base64
import json
import time
import urllib.request
import urllib.parse
from typing import Generator, Dict, Any
from .base import BaseModel
class BaiduOCRModel(BaseModel):
    """Baidu OCR model used to recognize text in images.

    Talks to the Baidu AI Platform over HTTPS using only the standard
    library: an OAuth ``client_credentials`` request obtains an access token
    (cached on the instance), and the ``accurate_basic`` OCR endpoint turns a
    Base64-encoded image into text.
    """

    # Seconds to wait on any single HTTP call before giving up; without a
    # timeout a stalled connection would block the caller indefinitely.
    REQUEST_TIMEOUT_SECONDS = 30
    # Refresh the cached token this many seconds before it actually expires.
    TOKEN_REFRESH_MARGIN_SECONDS = 300
    # Lifetime assumed when the token response carries no ``expires_in``
    # (2592000 s = 30 days).
    DEFAULT_TOKEN_TTL_SECONDS = 2592000

    def __init__(self, api_key: str, secret_key: str = None, temperature: float = 0.7, system_prompt: str = None):
        """Initialize the Baidu OCR model.

        Args:
            api_key: Baidu API key, or ``'API_KEY:SECRET_KEY'`` when
                ``secret_key`` is not given.
            secret_key: Baidu secret key. Optional; it may instead be passed
                inside ``api_key``, separated by a colon.
            temperature: Not used for OCR; kept for BaseModel compatibility.
            system_prompt: Not used for OCR; kept for BaseModel compatibility.

        Raises:
            ValueError: If the credentials are not in a valid format.
        """
        super().__init__(api_key, temperature, system_prompt)

        # Two accepted forms: both keys passed separately, or a single
        # colon-separated string in ``api_key``.
        if secret_key:
            self.api_key = api_key
            self.secret_key = secret_key
        else:
            format_error = "百度OCR API密钥必须是 'API_KEY:SECRET_KEY' 格式或单独传递secret_key参数"
            try:
                key_part, secret_part = api_key.split(':')
            except (ValueError, AttributeError):
                # ValueError: wrong number of colon-separated parts.
                # AttributeError: api_key is not a string (e.g. None).
                raise ValueError(format_error) from None
            # "key:" or ":secret" would only fail later at the token
            # endpoint; reject them here with the documented error instead.
            if not key_part or not secret_part:
                raise ValueError(format_error)
            self.api_key = key_part
            self.secret_key = secret_part

        # Baidu API URLs
        self.token_url = "https://aip.baidubce.com/oauth/2.0/token"
        self.ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic"

        # Cached access token and the epoch time at which it expires.
        self._access_token = None
        self._token_expires = 0

    def _post_form(self, url: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``params`` as a URL-encoded form to ``url`` and return the decoded JSON body."""
        payload = urllib.parse.urlencode(params).encode('utf-8')
        request = urllib.request.Request(url, data=payload)
        request.add_header('Content-Type', 'application/x-www-form-urlencoded')
        with urllib.request.urlopen(request, timeout=self.REQUEST_TIMEOUT_SECONDS) as response:
            return json.loads(response.read().decode('utf-8'))

    def get_access_token(self) -> str:
        """Return a Baidu API access token, requesting a new one when needed.

        A cached token is reused until shortly before its expiry time.

        Returns:
            str: The access token.

        Raises:
            Exception: If the request fails or the response contains no token.
        """
        # Reuse the cached token unless it is within the refresh margin.
        still_valid = time.time() < self._token_expires - self.TOKEN_REFRESH_MARGIN_SECONDS
        if self._access_token and still_valid:
            return self._access_token

        # Request a new access token.
        params = {
            'grant_type': 'client_credentials',
            'client_id': self.api_key,
            'client_secret': self.secret_key
        }
        try:
            result = self._post_form(self.token_url, params)
            if 'access_token' not in result:
                raise Exception(f"获取access_token失败: {result.get('error_description', '未知错误')}")
            self._access_token = result['access_token']
            # Record the expiry time; the margin above makes us refresh early.
            self._token_expires = time.time() + result.get('expires_in', self.DEFAULT_TOKEN_TTL_SECONDS)
            return self._access_token
        except Exception as e:
            raise Exception(f"请求access_token失败: {str(e)}") from e

    def ocr_image(self, image_data: str) -> str:
        """Run OCR on an image.

        Args:
            image_data: Base64-encoded image data. A leading
                ``data:...;base64,`` header, if present, is removed before
                the data is sent.

        Returns:
            str: The recognized text, one recognized line per output line.

        Raises:
            Exception: If the token cannot be obtained, the request fails, or
                the API reports an error.
        """
        access_token = self.get_access_token()

        # The API expects the bare Base64 payload. Base64 itself never
        # contains ':' or ',', so stripping a data-URI header cannot damage
        # valid input.
        if isinstance(image_data, str) and image_data.startswith('data:'):
            image_data = image_data.partition(',')[2]

        # Prepare the request data.
        params = {
            'image': image_data,
            'language_type': 'auto_detect',  # detect the language automatically
            'detect_direction': 'true',  # detect the image orientation
            'probability': 'false'  # omit confidence values (smaller response)
        }
        # Encode the token rather than interpolating it raw into the URL.
        query = urllib.parse.urlencode({'access_token': access_token})
        url = f"{self.ocr_url}?{query}"
        try:
            result = self._post_form(url, params)
            if 'error_code' in result:
                raise Exception(f"百度OCR API错误: {result.get('error_msg', '未知错误')}")
            # Collect the recognized text, one entry per recognized line.
            words_result = result.get('words_result', [])
            return '\n'.join(item['words'] for item in words_result)
        except Exception as e:
            raise Exception(f"OCR识别失败: {str(e)}") from e

    def extract_full_text(self, image_data: str) -> str:
        """Extract the full text of an image (Mathpix-compatible interface).

        Args:
            image_data: Base64-encoded image data.

        Returns:
            str: The extracted text.
        """
        return self.ocr_image(image_data)

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """Run OCR on an image and yield the result as a single response.

        Implemented as a generator to keep the interface consistent with
        streaming models.

        Args:
            image_data: Base64-encoded image data.
            proxies: Proxy configuration (unused).

        Yields:
            dict: A response with ``status`` (``'completed'`` or ``'error'``),
            ``content`` and ``model``.
        """
        try:
            text = self.ocr_image(image_data)
        except Exception as e:
            message = str(e)
            # ocr_image already prefixes request/API failures; only add the
            # prefix for errors that lack it so it is not shown twice.
            if not message.startswith('OCR识别失败'):
                message = f'OCR识别失败: {message}'
            yield {
                'status': 'error',
                'content': message,
                'model': 'baidu-ocr'
            }
            return

        yield {
            'status': 'completed',
            'content': text,
            'model': 'baidu-ocr'
        }

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[Dict[str, Any], None, None]:
        """Reject text analysis, which the OCR model does not support.

        Args:
            text: Input text (ignored).
            proxies: Proxy configuration (unused).

        Yields:
            dict: A single error response.
        """
        yield {
            'status': 'error',
            'content': 'OCR模型不支持文本分析功能',
            'model': 'baidu-ocr'
        }

    def get_model_identifier(self) -> str:
        """Return the model identifier."""
        return "baidu-ocr"