import json
import os
from typing import Generator, Optional

import requests

from .base import BaseModel


class DoubaoModel(BaseModel):
    """
    Doubao API model implementation.

    Supports ByteDance's Doubao AI models and can handle both text and
    image input.
    """

    def __init__(self, api_key: str, temperature: float = 0.7,
                 system_prompt: Optional[str] = None, language: Optional[str] = None,
                 model_name: Optional[str] = None, api_base_url: Optional[str] = None):
        """
        Initialize the Doubao model.

        Args:
            api_key: Doubao API key
            temperature: Sampling temperature
            system_prompt: System prompt
            language: Preferred response language
            model_name: Specific model name; falls back to the default if omitted
            api_base_url: API base URL, used to point at a custom endpoint
        """
        super().__init__(api_key, temperature, system_prompt, language)
        self.model_name = model_name or self.get_model_identifier()
        self.base_url = api_base_url or "https://ark.cn-beijing.volces.com/api/v3"
        self.max_tokens = 4096  # Default maximum number of output tokens

    def get_default_system_prompt(self) -> str:
        # This prompt is sent to the model verbatim and is intentionally kept
        # in Chinese.
        return """你是一个专业的问题分析专家。当看到问题图片时:
1. 仔细阅读并理解问题
2. 分解问题的关键组成部分
3. 提供清晰的分步解决方案
4. 如果相关,解释涉及的概念或理论
5. 如果有多种方法,优先解释最有效的方法"""

    def get_model_identifier(self) -> str:
        """Return the default model identifier."""
        return "doubao-seed-1-6-250615"  # Doubao-Seed-1.6

    def get_actual_model_name(self) -> str:
        """Map the configured model name to the identifier used in API calls."""
        model_mapping = {
            "doubao-seed-1-6-250615": "doubao-seed-1-6-250615"
        }
        return model_mapping.get(self.model_name, "doubao-seed-1-6-250615")

    def analyze_text(self, text: str, proxies: dict = None) -> Generator[dict, None, None]:
        """Stream a text response."""
        try:
            yield {"status": "started"}

            # Set proxy environment variables if proxies were provided
            original_proxies = None
            if proxies:
                original_proxies = {
                    'http_proxy': os.environ.get('http_proxy'),
                    'https_proxy': os.environ.get('https_proxy')
                }
                if 'http' in proxies:
                    os.environ['http_proxy'] = proxies['http']
                if 'https' in proxies:
                    os.environ['https_proxy'] = proxies['https']

            try:
                # Build request headers
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                # Build messages; per the official API docs the system prompt
                # is not used here for now
                messages = []

                # Add the user query (prompt text is sent to the model in Chinese)
                user_content = text
                if self.language and self.language != 'auto':
                    user_content = f"请使用{self.language}回答以下问题: {text}"

                messages.append({
                    "role": "user",
                    "content": user_content
                })

                # Build the request payload
                data = {
                    "model": self.get_actual_model_name(),
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "stream": True
                }

                # Send the streaming request
                response = requests.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=data,
                    stream=True,
                    proxies=proxies if proxies else None,
                    timeout=60
                )

                if response.status_code != 200:
                    error_text = response.text
                    raise Exception(f"HTTP {response.status_code}: {error_text}")

                # Accumulate the streamed content
                response_buffer = ""

                # Process the SSE stream
                for line in response.iter_lines():
                    if not line:
                        continue

                    line = line.decode('utf-8')
                    if not line.startswith('data: '):
                        continue

                    line = line[6:]  # Strip the 'data: ' prefix
                    if line == '[DONE]':
                        break

                    try:
                        chunk_data = json.loads(line)
                        choices = chunk_data.get('choices', [])
                        if choices:
                            delta = choices[0].get('delta', {})
                            content = delta.get('content', '')
                            if content:
                                response_buffer += content
                                # Report streaming progress
                                yield {
                                    "status": "streaming",
                                    "content": response_buffer
                                }
                    except json.JSONDecodeError:
                        continue

                # Make sure the complete final content is emitted
                yield {
                    "status": "completed",
                    "content": response_buffer
                }
            finally:
                # Restore the original proxy settings
                if original_proxies:
                    for key, value in original_proxies.items():
                        if value is None:
                            if key in os.environ:
                                del os.environ[key]
                        else:
                            os.environ[key] = value

        except Exception as e:
            yield {
                "status": "error",
                "error": f"Doubao API error: {str(e)}"
            }

    def analyze_image(self, image_data: str, proxies: dict = None) -> Generator[dict, None, None]:
        """Analyze an image and stream the response."""
        try:
            yield {"status": "started"}

            # Set proxy environment variables if proxies were provided
            original_proxies = None
            if proxies:
                original_proxies = {
                    'http_proxy': os.environ.get('http_proxy'),
                    'https_proxy': os.environ.get('https_proxy')
                }
                if 'http' in proxies:
                    os.environ['http_proxy'] = proxies['http']
                if 'https' in proxies:
                    os.environ['https_proxy'] = proxies['https']

            try:
                # Build request headers
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                # Normalize the image data: if it is a data URI, keep only
                # the base64 payload
                if image_data.startswith('data:image'):
                    image_data = image_data.split(',', 1)[1]

                # Build the user message following the official Doubao API
                # examples. Sniff the image format from the base64-encoded
                # magic numbers so the data URI is valid.
                image_format = "jpeg"  # default to JPEG
                if image_data.startswith('/9j/'):
                    # JPEG magic number in base64
                    image_format = "jpeg"
                elif image_data.startswith('iVBORw0KGgo'):
                    # PNG magic number in base64
                    image_format = "png"

                user_content = [
                    {
                        "type": "text",
                        "text": (
                            f"请使用{self.language}分析这张图片并提供详细解答。"
                            if self.language and self.language != 'auto'
                            else "请分析这张图片并提供详细解答。"
                        )
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{image_format};base64,{image_data}"
                        }
                    }
                ]

                messages = [
                    {
                        "role": "user",
                        "content": user_content
                    }
                ]

                # Build the request payload
                data = {
                    "model": self.get_actual_model_name(),
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "stream": True
                }

                # Send the streaming request
                response = requests.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=data,
                    stream=True,
                    proxies=proxies if proxies else None,
                    timeout=60
                )

                if response.status_code != 200:
                    error_text = response.text
                    raise Exception(f"HTTP {response.status_code}: {error_text}")

                # Accumulate the streamed content
                response_buffer = ""

                # Process the SSE stream
                for line in response.iter_lines():
                    if not line:
                        continue

                    line = line.decode('utf-8')
                    if not line.startswith('data: '):
                        continue

                    line = line[6:]  # Strip the 'data: ' prefix
                    if line == '[DONE]':
                        break

                    try:
                        chunk_data = json.loads(line)
                        choices = chunk_data.get('choices', [])
                        if choices:
                            delta = choices[0].get('delta', {})
                            content = delta.get('content', '')
                            if content:
                                response_buffer += content
                                # Report streaming progress
                                yield {
                                    "status": "streaming",
                                    "content": response_buffer
                                }
                    except json.JSONDecodeError:
                        continue

                # Make sure the complete final content is emitted
                yield {
                    "status": "completed",
                    "content": response_buffer
                }
            finally:
                # Restore the original proxy settings
                if original_proxies:
                    for key, value in original_proxies.items():
                        if value is None:
                            if key in os.environ:
                                del os.environ[key]
                        else:
                            os.environ[key] = value

        except Exception as e:
            yield {
                "status": "error",
                "error": f"Doubao image analysis error: {str(e)}"
            }
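

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes BaseModel accepts (api_key, temperature, system_prompt, language)
# as __init__ above suggests, and that a DOUBAO_API_KEY environment variable
# holds the key (hypothetical name). Because of the relative import, run it as
# a module, e.g. `python -m <your_package>.doubao`.
if __name__ == "__main__":
    demo = DoubaoModel(api_key=os.environ.get("DOUBAO_API_KEY", ""))
    for event in demo.analyze_text("Explain Newton's second law."):
        if event["status"] == "streaming":
            # "content" holds the full text accumulated so far, not a delta
            print(f"\r{len(event['content'])} chars received", end="")
        elif event["status"] == "completed":
            print("\n" + event["content"])
        elif event["status"] == "error":
            print("\n" + event["error"])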