diff --git a/agent/memory/__init__.py b/agent/memory/__init__.py index 4179bea..f638a9d 100644 --- a/agent/memory/__init__.py +++ b/agent/memory/__init__.py @@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword) from agent.memory.manager import MemoryManager from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config +from agent.memory.embedding import create_embedding_provider -__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config'] +__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider'] diff --git a/agent/memory/config.py b/agent/memory/config.py index 758611d..63d945b 100644 --- a/agent/memory/config.py +++ b/agent/memory/config.py @@ -41,6 +41,10 @@ class MemoryConfig: enable_auto_sync: bool = True sync_on_search: bool = True + # Memory flush config (独立于模型 context window) + flush_token_threshold: int = 50000 # 50K tokens 触发 flush + flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮) + def get_workspace(self) -> Path: """Get workspace root directory""" return Path(self.workspace_root) diff --git a/agent/memory/embedding.py b/agent/memory/embedding.py index 4a71828..509370b 100644 --- a/agent/memory/embedding.py +++ b/agent/memory/embedding.py @@ -4,20 +4,19 @@ Embedding providers for memory Supports OpenAI and local embedding models """ -from typing import List, Optional -from abc import ABC, abstractmethod import hashlib -import json +from abc import ABC, abstractmethod +from typing import List, Optional class EmbeddingProvider(ABC): """Base class for embedding providers""" - + @abstractmethod def embed(self, text: str) -> List[float]: """Generate embedding for text""" pass - + @abstractmethod def embed_batch(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for multiple texts""" @@ -31,7 +30,7 @@ class EmbeddingProvider(ABC): class OpenAIEmbeddingProvider(EmbeddingProvider): - """OpenAI embedding provider""" + """OpenAI embedding provider using REST API""" def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None): """ @@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): self.model = model self.api_key = api_key self.api_base = api_base or "https://api.openai.com/v1" - - # Lazy import to avoid dependency issues - try: - from openai import OpenAI - self.client = OpenAI(api_key=api_key, base_url=api_base) - except ImportError: - raise ImportError("OpenAI package not installed. 
Install with: pip install openai") - + + if not self.api_key: + raise ValueError("OpenAI API key is required") + # Set dimensions based on model self._dimensions = 1536 if "small" in model else 3072 - + + def _call_api(self, input_data): + """Call OpenAI embedding API using requests""" + import requests + + url = f"{self.api_base}/embeddings" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + data = { + "input": input_data, + "model": self.model + } + + response = requests.post(url, headers=headers, json=data, timeout=30) + response.raise_for_status() + return response.json() + def embed(self, text: str) -> List[float]: """Generate embedding for text""" - response = self.client.embeddings.create( - input=text, - model=self.model - ) - return response.data[0].embedding - + result = self._call_api(text) + return result["data"][0]["embedding"] + def embed_batch(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for multiple texts""" if not texts: return [] - - response = self.client.embeddings.create( - input=texts, - model=self.model - ) - return [item.embedding for item in response.data] - + + result = self._call_api(texts) + return [item["embedding"] for item in result["data"]] + @property def dimensions(self) -> int: return self._dimensions -class LocalEmbeddingProvider(EmbeddingProvider): - """Local embedding provider using sentence-transformers""" - - def __init__(self, model: str = "all-MiniLM-L6-v2"): - """ - Initialize local embedding provider - - Args: - model: Model name from sentence-transformers - """ - self.model_name = model - - try: - from sentence_transformers import SentenceTransformer - self.model = SentenceTransformer(model) - self._dimensions = self.model.get_sentence_embedding_dimension() - except ImportError: - raise ImportError( - "sentence-transformers not installed. " - "Install with: pip install sentence-transformers" - ) - - def embed(self, text: str) -> List[float]: - """Generate embedding for text""" - embedding = self.model.encode(text, convert_to_numpy=True) - return embedding.tolist() - - def embed_batch(self, texts: List[str]) -> List[List[float]]: - """Generate embeddings for multiple texts""" - if not texts: - return [] - - embeddings = self.model.encode(texts, convert_to_numpy=True) - return embeddings.tolist() - - @property - def dimensions(self) -> int: - return self._dimensions +# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search class EmbeddingCache: """Cache for embeddings to avoid recomputation""" - + def __init__(self): self.cache = {} - + def get(self, text: str, provider: str, model: str) -> Optional[List[float]]: """Get cached embedding""" key = self._compute_key(text, provider, model) @@ -156,20 +126,23 @@ def create_embedding_provider( """ Factory function to create embedding provider + Only supports OpenAI embedding via REST API. + If initialization fails, caller should fall back to keyword-only search. 
+ Args: - provider: Provider name ("openai" or "local") - model: Model name (provider-specific) - api_key: API key for remote providers - api_base: API base URL for remote providers + provider: Provider name (only "openai" is supported) + model: Model name (default: text-embedding-3-small) + api_key: OpenAI API key (required) + api_base: API base URL (default: https://api.openai.com/v1) Returns: EmbeddingProvider instance + + Raises: + ValueError: If provider is not "openai" or api_key is missing """ - if provider == "openai": - model = model or "text-embedding-3-small" - return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base) - elif provider == "local": - model = model or "all-MiniLM-L6-v2" - return LocalEmbeddingProvider(model=model) - else: - raise ValueError(f"Unknown embedding provider: {provider}") + if provider != "openai": + raise ValueError(f"Only 'openai' provider is supported, got: {provider}") + + model = model or "text-embedding-3-small" + return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base) diff --git a/agent/memory/manager.py b/agent/memory/manager.py index b313e82..1e47811 100644 --- a/agent/memory/manager.py +++ b/agent/memory/manager.py @@ -70,8 +70,9 @@ class MemoryManager: except Exception as e: # Embedding provider failed, but that's OK # We can still use keyword search and file operations - print(f"⚠️ Warning: Embedding provider initialization failed: {e}") - print(f"ℹ️ Memory will work with keyword search only (no semantic search)") + from common.log import logger + logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}") + logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)") # Initialize memory flush manager workspace_dir = self.config.get_workspace() @@ -135,13 +136,19 @@ class MemoryManager: # Perform vector search (if embedding provider available) vector_results = [] if self.embedding_provider: - query_embedding = self.embedding_provider.embed(query) - vector_results = self.storage.search_vector( - query_embedding=query_embedding, - user_id=user_id, - scopes=scopes, - limit=max_results * 2 # Get more candidates for merging - ) + try: + from common.log import logger + query_embedding = self.embedding_provider.embed(query) + vector_results = self.storage.search_vector( + query_embedding=query_embedding, + user_id=user_id, + scopes=scopes, + limit=max_results * 2 # Get more candidates for merging + ) + logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}") + except Exception as e: + from common.log import logger + logger.warning(f"[MemoryManager] Vector search failed: {e}") # Perform keyword search keyword_results = self.storage.search_keyword( @@ -150,6 +157,8 @@ class MemoryManager: scopes=scopes, limit=max_results * 2 ) + from common.log import logger + logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}") # Merge results merged = self._merge_results( @@ -356,30 +365,30 @@ class MemoryManager: def should_flush_memory( self, - current_tokens: int, - context_window: int = 128000, - reserve_tokens: int = 20000, - soft_threshold: int = 4000 + current_tokens: int = 0 ) -> bool: """ Check if memory flush should be triggered + 独立的 flush 触发机制,不依赖模型 context window。 + 使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold + Args: current_tokens: Current session token count - context_window: Model's context window size (default: 128K) - reserve_tokens: Reserve tokens 
for compaction overhead (default: 20K) - soft_threshold: Trigger N tokens before threshold (default: 4K) Returns: True if memory flush should run """ return self.flush_manager.should_flush( current_tokens=current_tokens, - context_window=context_window, - reserve_tokens=reserve_tokens, - soft_threshold=soft_threshold + token_threshold=self.config.flush_token_threshold, + turn_threshold=self.config.flush_turn_threshold ) + def increment_turn(self): + """增加对话轮数计数(每次用户消息+AI回复算一轮)""" + self.flush_manager.increment_turn() + async def execute_memory_flush( self, agent_executor, diff --git a/agent/memory/storage.py b/agent/memory/storage.py index b8fccf0..373512b 100644 --- a/agent/memory/storage.py +++ b/agent/memory/storage.py @@ -46,14 +46,32 @@ class MemoryStorage: def __init__(self, db_path: Path): self.db_path = db_path self.conn: Optional[sqlite3.Connection] = None + self.fts5_available = False # Track FTS5 availability self._init_db() + def _check_fts5_support(self) -> bool: + """Check if SQLite has FTS5 support""" + try: + self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)") + self.conn.execute("DROP TABLE IF EXISTS fts5_test") + return True + except sqlite3.OperationalError as e: + if "no such module: fts5" in str(e): + return False + raise + def _init_db(self): """Initialize database with schema""" try: self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self.conn.row_factory = sqlite3.Row + # Check FTS5 support + self.fts5_available = self._check_fts5_support() + if not self.fts5_available: + from common.log import logger + logger.warning("[MemoryStorage] FTS5 not available, using LIKE-based keyword search") + # Check database integrity try: result = self.conn.execute("PRAGMA integrity_check").fetchone() @@ -125,43 +143,44 @@ class MemoryStorage: ON chunks(path, hash) """) - # Create FTS5 virtual table for keyword search - # Use default unicode61 tokenizer (stable and compatible) - # For CJK support, we'll use LIKE queries as fallback - self.conn.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( - text, - id UNINDEXED, - user_id UNINDEXED, - path UNINDEXED, - source UNINDEXED, - scope UNINDEXED, - content='chunks', - content_rowid='rowid' - ) - """) - - # Create triggers to keep FTS in sync - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN - INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) - VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); - END - """) - - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN - DELETE FROM chunks_fts WHERE rowid = old.rowid; - END - """) - - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN - UPDATE chunks_fts SET text = new.text, id = new.id, - user_id = new.user_id, path = new.path, source = new.source, scope = new.scope - WHERE rowid = new.rowid; - END - """) + # Create FTS5 virtual table for keyword search (only if supported) + if self.fts5_available: + # Use default unicode61 tokenizer (stable and compatible) + # For CJK support, we'll use LIKE queries as fallback + self.conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + text, + id UNINDEXED, + user_id UNINDEXED, + path UNINDEXED, + source UNINDEXED, + scope UNINDEXED, + content='chunks', + content_rowid='rowid' + ) + """) + + # Create triggers to keep FTS in sync + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS 
chunks_ai AFTER INSERT ON chunks BEGIN + INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) + VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); + END + """) + + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN + DELETE FROM chunks_fts WHERE rowid = old.rowid; + END + """) + + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN + UPDATE chunks_fts SET text = new.text, id = new.id, + user_id = new.user_id, path = new.path, source = new.source, scope = new.scope + WHERE rowid = new.rowid; + END + """) # Create files metadata table self.conn.execute(""" @@ -301,21 +320,22 @@ class MemoryStorage: Keyword search using FTS5 + LIKE fallback Strategy: - 1. Try FTS5 search first (good for English and word-based languages) - 2. If no results and query contains CJK characters, use LIKE search + 1. If FTS5 available: Try FTS5 search first (good for English and word-based languages) + 2. If no FTS5 or no results and query contains CJK: Use LIKE search """ if scopes is None: scopes = ["shared"] if user_id: scopes.append("user") - # Try FTS5 search first - fts_results = self._search_fts5(query, user_id, scopes, limit) - if fts_results: - return fts_results + # Try FTS5 search first (if available) + if self.fts5_available: + fts_results = self._search_fts5(query, user_id, scopes, limit) + if fts_results: + return fts_results - # Fallback to LIKE search for CJK characters - if MemoryStorage._contains_cjk(query): + # Fallback to LIKE search (always for CJK, or if FTS5 not available) + if not self.fts5_available or MemoryStorage._contains_cjk(query): return self._search_like(query, user_id, scopes, limit) return [] diff --git a/agent/memory/summarizer.py b/agent/memory/summarizer.py index 4b102ab..46b2b59 100644 --- a/agent/memory/summarizer.py +++ b/agent/memory/summarizer.py @@ -41,46 +41,42 @@ class MemoryFlushManager: # Tracking self.last_flush_token_count: Optional[int] = None self.last_flush_timestamp: Optional[datetime] = None + self.turn_count: int = 0 # 对话轮数计数器 def should_flush( self, - current_tokens: int, - context_window: int, - reserve_tokens: int = 20000, - soft_threshold: int = 4000 + current_tokens: int = 0, + token_threshold: int = 50000, + turn_threshold: int = 20 ) -> bool: """ Determine if memory flush should be triggered - Similar to clawdbot's shouldRunMemoryFlush logic: - threshold = contextWindow - reserveTokens - softThreshold + 独立的 flush 触发机制,不依赖模型 context window: + - Token 阈值: 达到 50K tokens 时触发 + - 轮次阈值: 达到 20 轮对话时触发 Args: current_tokens: Current session token count - context_window: Model's context window size - reserve_tokens: Reserve tokens for compaction overhead - soft_threshold: Trigger flush N tokens before threshold + token_threshold: Token threshold to trigger flush (default: 50K) + turn_threshold: Turn threshold to trigger flush (default: 20) Returns: True if flush should run """ - if current_tokens <= 0: - return False + # 检查 token 阈值 + if current_tokens > 0 and current_tokens >= token_threshold: + # 避免重复 flush + if self.last_flush_token_count is not None: + if current_tokens <= self.last_flush_token_count + 5000: + return False + return True - threshold = max(0, context_window - reserve_tokens - soft_threshold) - if threshold <= 0: - return False + # 检查轮次阈值 + if self.turn_count >= turn_threshold: + return True - # Check if we've crossed the threshold - if current_tokens < threshold: - return False - - # Avoid duplicate flush in same 
compaction cycle - if self.last_flush_token_count is not None: - if current_tokens <= self.last_flush_token_count + soft_threshold: - return False - - return True + return False def get_today_memory_file(self, user_id: Optional[str] = None) -> Path: """ @@ -130,7 +126,12 @@ class MemoryFlushManager: f"Pre-compaction memory flush. " f"Store durable memories now (use memory/{today}.md for daily notes; " f"create memory/ if needed). " - f"If nothing to store, reply with NO_REPLY." + f"\n\n" + f"重要提示:\n" + f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n" + f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n" + f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n" + f"- 如果没有重要内容需要记录,回复 NO_REPLY\n" ) def create_flush_system_prompt(self) -> str: @@ -142,6 +143,20 @@ class MemoryFlushManager: return ( "Pre-compaction memory flush turn. " "The session is near auto-compaction; capture durable memories to disk. " + "\n\n" + "记忆写入原则:\n" + "1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n" + " - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n" + " - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n" + "\n" + "2. 天级记忆 (memory/YYYY-MM-DD.md):\n" + " - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n" + "\n" + "3. 判断标准:\n" + " - 这个信息未来会经常用到吗?→ MEMORY.md\n" + " - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n" + " - 这是临时性的、不重要的内容吗?→ 不记录\n" + "\n" "You may reply, but usually NO_REPLY is correct." ) @@ -180,6 +195,7 @@ class MemoryFlushManager: # Track flush self.last_flush_token_count = current_tokens self.last_flush_timestamp = datetime.now() + self.turn_count = 0 # 重置轮数计数器 return True @@ -187,6 +203,10 @@ class MemoryFlushManager: print(f"Memory flush failed: {e}") return False + def increment_turn(self): + """增加对话轮数计数""" + self.turn_count += 1 + def get_status(self) -> dict: """Get memory flush status""" return { diff --git a/agent/prompt/builder.py b/agent/prompt/builder.py index 6d8fc80..c9ae35b 100644 --- a/agent/prompt/builder.py +++ b/agent/prompt/builder.py @@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]: tool_map = {} tool_descriptions = { "read": "读取文件内容", - "write": "创建或覆盖文件", - "edit": "精确编辑文件内容", + "write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)。注意:单次 write 内容不要超过 10KB,超大文件请分步创建", + "edit": "精确编辑文件(追加、修改、删除部分内容)", "ls": "列出目录内容", "grep": "在文件中搜索内容", "find": "按照模式查找文件", @@ -237,11 +237,13 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]: "叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。", "", "完成标准:", - "- 确保用户的需求得到实际解决,而不仅仅是制定计划", - "- 当任务需要多次工具调用时,持续推进直到完成", + "- 确保用户的需求得到实际解决,而不仅仅是制定计划。", + "- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题", "- 每次工具调用后,评估是否已获得足够信息来推进或完成任务", "- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求", "", + "**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。", + "", ]) return lines @@ -305,17 +307,21 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu "", "在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:", "", - "1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索", - "2. 然后使用 `memory_get` 只拉取需要的行", - "3. 如果搜索后仍然信心不足,告诉用户你已经检查过了", + "1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容", + "2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)", + "3. 
search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件", "", "**记忆文件结构**:", - "- `MEMORY.md`: 长期记忆,包含重要的背景信息", - "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件", + "- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)", + "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息", "", - "**使用原则**:", - "- 自然使用记忆,就像你本来就知道", - "- 不要主动提起或列举记忆,除非用户明确询问", + "**写入记忆**:", + "- 追加内容 → `edit` 工具,oldText 留空", + "- 修改内容 → `edit` 工具,oldText 填写要替换的文本", + "- 新建文件 → `write` 工具", + "- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件", + "", + "**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。", "", ] @@ -385,8 +391,8 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_convers "", "**交流规范**:", "", - "- 在所有对话中,无需提及技术细节(如 SOUL.md、USER.md 等文件名,工具名称,配置等),除非用户明确询问", - "- 用自然表达如「我已记住」而非「已更新 SOUL.md」", + "- 在对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问", + "- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」", "", ] diff --git a/agent/prompt/workspace.py b/agent/prompt/workspace.py index a096706..37d9a96 100644 --- a/agent/prompt/workspace.py +++ b/agent/prompt/workspace.py @@ -64,7 +64,7 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works _create_template_if_missing(agents_path, _get_agents_template()) _create_template_if_missing(memory_path, _get_memory_template()) - logger.info(f"[Workspace] Initialized workspace at: {workspace_dir}") + logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}") return WorkspaceFiles( soul_path=soul_path, @@ -270,14 +270,9 @@ def _get_agents_template() -> str: 2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项) 3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容) -**重要**: -- 爱好(唱歌、篮球等)→ MEMORY.md,不是 USER.md -- 近期计划(下周要做什么)→ MEMORY.md,不是 USER.md -- USER.md 只存放不会变的基本信息 - ## 安全 -- 永远不要泄露私人数据 +- 永远不要泄露秘钥等私人数据 - 不要在未经询问的情况下运行破坏性命令 - 当有疑问时,先问 diff --git a/agent/protocol/agent.py b/agent/protocol/agent.py index 5c4f994..1b031e0 100644 --- a/agent/protocol/agent.py +++ b/agent/protocol/agent.py @@ -1,5 +1,6 @@ import json import time +import threading from common.log import logger from agent.protocol.models import LLMRequest, LLMModel @@ -43,6 +44,7 @@ class Agent: self.output_mode = output_mode self.last_usage = None # Store last API response usage info self.messages = [] # Unified message history for stream mode + self.messages_lock = threading.Lock() # Lock for thread-safe message operations self.memory_manager = memory_manager # Memory manager for auto memory flush self.workspace_dir = workspace_dir # Workspace directory self.enable_skills = enable_skills # Skills enabled flag @@ -57,7 +59,7 @@ class Agent: try: from agent.skills import SkillManager self.skill_manager = SkillManager(workspace_dir=workspace_dir) - logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills") + logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills") except Exception as e: logger.warning(f"Failed to initialize SkillManager: {e}") @@ -335,7 +337,8 @@ class Agent: """ # Clear history if requested if clear_history: - self.messages = [] + with self.messages_lock: + self.messages = [] # Get model to use if not self.model: @@ -344,7 +347,17 @@ class Agent: # Get full system prompt with skills full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter) - # Create stream executor with agent's message history + # Create a copy of messages for this execution to avoid concurrent modification + # Record the original length to track which messages are new + with self.messages_lock: + messages_copy = self.messages.copy() + original_length = 
len(self.messages) + + # Get max_context_turns from config + from config import conf + max_context_turns = conf().get("agent_max_context_turns", 30) + + # Create stream executor with copied message history executor = AgentStreamExecutor( agent=self, model=self.model, @@ -352,14 +365,21 @@ class Agent: tools=self.tools, max_turns=self.max_steps, on_event=on_event, - messages=self.messages # Pass agent's message history + messages=messages_copy, # Pass copied message history + max_context_turns=max_context_turns ) # Execute response = executor.run_stream(user_message) - # Update agent's message history from executor - self.messages = executor.messages + # Append only the NEW messages from this execution (thread-safe) + # This allows concurrent requests to both contribute to history + with self.messages_lock: + new_messages = executor.messages[original_length:] + self.messages.extend(new_messages) + + # Store executor reference for agent_bridge to access files_to_send + self.stream_executor = executor # Execute all post-process tools self._execute_post_process_tools() diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index 3ad6eb5..2b1c883 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -7,9 +7,9 @@ import json import time from typing import List, Dict, Any, Optional, Callable -from common.log import logger from agent.protocol.models import LLMRequest, LLMModel from agent.tools.base_tool import BaseTool, ToolResult +from common.log import logger class AgentStreamExecutor: @@ -31,7 +31,8 @@ class AgentStreamExecutor: tools: List[BaseTool], max_turns: int = 50, on_event: Optional[Callable] = None, - messages: Optional[List[Dict]] = None + messages: Optional[List[Dict]] = None, + max_context_turns: int = 30 ): """ Initialize stream executor @@ -44,6 +45,7 @@ class AgentStreamExecutor: max_turns: Maximum number of turns on_event: Event callback function messages: Optional existing message history (for persistent conversations) + max_context_turns: Maximum number of conversation turns to keep in context """ self.agent = agent self.model = model @@ -52,12 +54,16 @@ class AgentStreamExecutor: self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools self.max_turns = max_turns self.on_event = on_event + self.max_context_turns = max_context_turns # Message history - use provided messages or create new list self.messages = messages if messages is not None else [] # Tool failure tracking for retry protection self.tool_failure_history = [] # List of (tool_name, args_hash, success) tuples + + # Track files to send (populated by read tool) + self.files_to_send = [] # List of file metadata dicts def _emit_event(self, event_type: str, data: dict = None): """Emit event""" @@ -78,12 +84,15 @@ class AgentStreamExecutor: args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) return hashlib.md5(args_str.encode()).hexdigest()[:8] - def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str]: + def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str, bool]: """ Check if tool has failed too many times consecutively Returns: - (should_stop, reason) + (should_stop, reason, is_critical) + - should_stop: Whether to stop tool execution + - reason: Reason for stopping + - is_critical: Whether to abort entire conversation (True for 8+ failures) """ args_hash = self._hash_args(args) @@ -99,7 +108,7 @@ class AgentStreamExecutor: break # Different tool or args, stop 
counting if same_args_failures >= 3: - return True, f"Tool '{tool_name}' with same arguments failed {same_args_failures} times consecutively. Stopping to prevent infinite loop." + return True, f"工具 '{tool_name}' 使用相同参数连续失败 {same_args_failures} 次,停止执行以防止无限循环", False # Count consecutive failures for same tool (any args) same_tool_failures = 0 @@ -112,10 +121,15 @@ class AgentStreamExecutor: else: break # Different tool, stop counting - if same_tool_failures >= 6: - return True, f"Tool '{tool_name}' failed {same_tool_failures} times consecutively (with any arguments). Stopping to prevent infinite loop." + # Hard stop at 8 failures - abort with critical message + if same_tool_failures >= 8: + return True, f"抱歉,我没能完成这个任务。可能是我理解有误或者当前方法不太合适。\n\n建议你:\n• 换个方式描述需求试试\n• 把任务拆分成更小的步骤\n• 或者换个思路来解决", True - return False, "" + # Warning at 6 failures + if same_tool_failures >= 6: + return True, f"工具 '{tool_name}' 连续失败 {same_tool_failures} 次(使用不同参数),停止执行以防止无限循环", False + + return False, "", False def _record_tool_result(self, tool_name: str, args: dict, success: bool): """Record tool execution result for failure tracking""" @@ -136,10 +150,7 @@ class AgentStreamExecutor: Final response text """ # Log user message with model info - logger.info(f"{'='*50}") - logger.info(f"🤖 Model: {self.model.model}") - logger.info(f"👤 用户: {user_message}") - logger.info(f"{'='*50}") + logger.info(f"🤖 {self.model.model} | 👤 {user_message}") # Add user message (Claude format - use content blocks for consistency) self.messages.append({ @@ -160,54 +171,74 @@ class AgentStreamExecutor: try: while turn < self.max_turns: turn += 1 - logger.info(f"第 {turn} 轮") + logger.debug(f"第 {turn} 轮") self._emit_event("turn_start", {"turn": turn}) # Check if memory flush is needed (before calling LLM) + # 使用独立的 flush 阈值(50K tokens 或 20 轮) if self.agent.memory_manager and hasattr(self.agent, 'last_usage'): usage = self.agent.last_usage if usage and 'input_tokens' in usage: current_tokens = usage.get('input_tokens', 0) - context_window = self.agent._get_model_context_window() - # Use configured reserve_tokens or calculate based on context window - reserve_tokens = self.agent._get_context_reserve_tokens() - # Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens) - soft_threshold = 10000 # Trigger 10K tokens before limit if self.agent.memory_manager.should_flush_memory( - current_tokens=current_tokens, - context_window=context_window, - reserve_tokens=reserve_tokens, - soft_threshold=soft_threshold + current_tokens=current_tokens ): self._emit_event("memory_flush_start", { "current_tokens": current_tokens, - "threshold": context_window - reserve_tokens - soft_threshold + "turn_count": self.agent.memory_manager.flush_manager.turn_count }) # TODO: Execute memory flush in background # This would require async support - logger.info(f"Memory flush recommended at {current_tokens} tokens") + logger.info( + f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}") - # Call LLM - assistant_msg, tool_calls = self._call_llm_stream() + # Call LLM (enable retry_on_empty for better reliability) + assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True) final_response = assistant_msg # No tool calls, end loop if not tool_calls: # 检查是否返回了空响应 if not assistant_msg: - logger.warning(f"[Agent] LLM returned empty response (no content and no tool calls)") + logger.warning(f"[Agent] LLM returned empty response after retry (no content and no tool calls)") + logger.info(f"[Agent] 
This usually happens when LLM thinks the task is complete after tool execution") - # 生成通用的友好提示 - final_response = ( - "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。" - ) - logger.info(f"Generated fallback response for empty LLM output") + # 如果之前有工具调用,强制要求 LLM 生成文本回复 + if turn > 1: + logger.info(f"[Agent] Requesting explicit response from LLM...") + + # 添加一条消息,明确要求回复用户 + self.messages.append({ + "role": "user", + "content": [{ + "type": "text", + "text": "请向用户说明刚才工具执行的结果或回答用户的问题。" + }] + }) + + # 再调用一次 LLM + assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=False) + final_response = assistant_msg + + # 如果还是空,才使用 fallback + if not assistant_msg and not tool_calls: + logger.warning(f"[Agent] Still empty after explicit request") + final_response = ( + "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。" + ) + logger.info(f"Generated fallback response for empty LLM output") + else: + # 第一轮就空回复,直接 fallback + final_response = ( + "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。" + ) + logger.info(f"Generated fallback response for empty LLM output") else: logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}") - logger.info(f"✅ 完成 (无工具调用)") + logger.debug(f"✅ 完成 (无工具调用)") self._emit_event("turn_end", { "turn": turn, "has_tool_calls": False @@ -233,6 +264,20 @@ class AgentStreamExecutor: result = self._execute_tool(tool_call) tool_results.append(result) + # Check if this is a file to send (from read tool) + if result.get("status") == "success" and isinstance(result.get("result"), dict): + result_data = result.get("result") + if result_data.get("type") == "file_to_send": + # Store file metadata for later sending + self.files_to_send.append(result_data) + logger.info(f"📎 检测到待发送文件: {result_data.get('file_name', result_data.get('path'))}") + + # Check for critical error - abort entire conversation + if result.get("status") == "critical_error": + logger.error(f"💥 检测到严重错误,终止对话") + final_response = result.get('result', '任务执行失败') + return final_response + # Log tool result in compact format status_emoji = "✅" if result.get("status") == "success" else "❌" result_data = result.get('result', '') @@ -305,11 +350,37 @@ class AgentStreamExecutor: }) if turn >= self.max_turns: - logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}") - if not final_response: + logger.warning(f"⚠️ 已达到最大决策步数限制: {self.max_turns}") + + # Force model to summarize without tool calls + logger.info(f"[Agent] Requesting summary from LLM after reaching max steps...") + + # Add a system message to force summary + self.messages.append({ + "role": "user", + "content": [{ + "type": "text", + "text": f"你已经执行了{turn}个决策步骤,达到了单次运行的最大步数限制。请总结一下你目前的执行过程和结果,告诉用户当前的进展情况。不要再调用工具,直接用文字回复。" + }] + }) + + # Call LLM one more time to get summary (without retry to avoid loops) + try: + summary_response, summary_tools = self._call_llm_stream(retry_on_empty=False) + if summary_response: + final_response = summary_response + logger.info(f"💭 Summary: {summary_response[:150]}{'...' 
if len(summary_response) > 150 else ''}") + else: + # Fallback if model still doesn't respond + final_response = ( + f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。" + "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。" + ) + except Exception as e: + logger.warning(f"Failed to get summary from LLM: {e}") final_response = ( - "抱歉,我在处理你的请求时遇到了一些困难,尝试了多次仍未能完成。" - "请尝试简化你的问题,或换一种方式描述。" + f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。" + "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。" ) except Exception as e: @@ -318,9 +389,13 @@ class AgentStreamExecutor: raise finally: - logger.info(f"🏁 完成({turn}轮)") + logger.debug(f"🏁 完成({turn}轮)") self._emit_event("agent_end", {"final_response": final_response}) + # 每轮对话结束后增加计数(用户消息+AI回复=1轮) + if self.agent.memory_manager: + self.agent.memory_manager.increment_turn() + return final_response def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]: @@ -380,6 +455,7 @@ class AgentStreamExecutor: # Streaming response full_content = "" tool_calls_buffer = {} # {index: {id, name, arguments}} + stop_reason = None # Track why the stream stopped try: stream = self.model.call_stream(request) @@ -392,21 +468,47 @@ class AgentStreamExecutor: if isinstance(error_data, dict): error_msg = error_data.get("message", chunk.get("message", "Unknown error")) error_code = error_data.get("code", "") + error_type = error_data.get("type", "") else: error_msg = chunk.get("message", str(error_data)) error_code = "" + error_type = "" status_code = chunk.get("status_code", "N/A") - logger.error(f"API Error: {error_msg} (Status: {status_code}, Code: {error_code})") - logger.error(f"Full error chunk: {chunk}") - # Raise exception with full error message for retry logic - raise Exception(f"{error_msg} (Status: {status_code})") + # Log error with all available information + logger.error(f"🔴 Stream API Error:") + logger.error(f" Message: {error_msg}") + logger.error(f" Status Code: {status_code}") + logger.error(f" Error Code: {error_code}") + logger.error(f" Error Type: {error_type}") + logger.error(f" Full chunk: {chunk}") + + # Check if this is a context overflow error (keyword-based, works for all models) + # Don't rely on specific status codes as different providers use different codes + error_msg_lower = error_msg.lower() + is_overflow = any(keyword in error_msg_lower for keyword in [ + 'context length exceeded', 'maximum context length', 'prompt is too long', + 'context overflow', 'context window', 'too large', 'exceeds model context', + 'request_too_large', 'request exceeds the maximum size', 'tokens exceed' + ]) + + if is_overflow: + # Mark as context overflow for special handling + raise Exception(f"[CONTEXT_OVERFLOW] {error_msg} (Status: {status_code})") + else: + # Raise exception with full error message for retry logic + raise Exception(f"{error_msg} (Status: {status_code}, Code: {error_code}, Type: {error_type})") # Parse chunk if isinstance(chunk, dict) and "choices" in chunk: choice = chunk["choices"][0] delta = choice.get("delta", {}) + + # Capture finish_reason if present + finish_reason = choice.get("finish_reason") + if finish_reason: + stop_reason = finish_reason # Handle text content if "content" in delta and delta["content"]: @@ -437,9 +539,46 @@ class AgentStreamExecutor: tool_calls_buffer[index]["arguments"] += func["arguments"] except Exception as e: - error_str = str(e).lower() + error_str = str(e) + error_str_lower = error_str.lower() + + # Check if error is context overflow (non-retryable, needs session reset) + # Method 1: Check for special marker (set in 
stream error handling above) + is_context_overflow = '[context_overflow]' in error_str_lower + + # Method 2: Fallback to keyword matching for non-stream errors + if not is_context_overflow: + is_context_overflow = any(keyword in error_str_lower for keyword in [ + 'context length exceeded', 'maximum context length', 'prompt is too long', + 'context overflow', 'context window', 'too large', 'exceeds model context', + 'request_too_large', 'request exceeds the maximum size' + ]) + + # Check if error is message format error (incomplete tool_use/tool_result pairs) + # This happens when previous conversation had tool failures + is_message_format_error = any(keyword in error_str_lower for keyword in [ + 'tool_use', 'tool_result', 'without', 'immediately after', + 'corresponding', 'must have', 'each' + ]) and 'status: 400' in error_str_lower + + if is_context_overflow or is_message_format_error: + error_type = "context overflow" if is_context_overflow else "message format error" + logger.error(f"💥 {error_type} detected: {e}") + # Clear message history to recover + logger.warning("🔄 Clearing conversation history to recover") + self.messages.clear() + # Raise special exception with user-friendly message + if is_context_overflow: + raise Exception( + "抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。" + ) + else: + raise Exception( + "抱歉,之前的对话出现了问题。我已清空历史记录,请重新发送你的消息。" + ) + # Check if error is retryable (timeout, connection, rate limit, server busy, etc.) - is_retryable = any(keyword in error_str for keyword in [ + is_retryable = any(keyword in error_str_lower for keyword in [ 'timeout', 'timed out', 'connection', 'network', 'rate limit', 'overloaded', 'unavailable', 'busy', 'retry', '429', '500', '502', '503', '504', '512' @@ -469,15 +608,19 @@ class AgentStreamExecutor: try: arguments = json.loads(tc["arguments"]) if tc["arguments"] else {} except json.JSONDecodeError as e: - logger.error(f"Failed to parse tool arguments: {tc['arguments']}") + args_preview = tc['arguments'][:200] if len(tc['arguments']) > 200 else tc['arguments'] + logger.error(f"Failed to parse tool arguments for {tc['name']}") + logger.error(f"Arguments length: {len(tc['arguments'])} chars") + logger.error(f"Arguments preview: {args_preview}...") logger.error(f"JSON decode error: {e}") + # Return a clear error message to the LLM instead of empty dict # This helps the LLM understand what went wrong tool_calls.append({ "id": tc["id"], "name": tc["name"], "arguments": {}, - "_parse_error": f"Invalid JSON in tool arguments: {tc['arguments'][:100]}... Error: {str(e)}" + "_parse_error": f"Invalid JSON in tool arguments: {args_preview}... Error: {str(e)}. Tip: For large content, consider splitting into smaller chunks or using a different approach." 
}) continue @@ -489,11 +632,12 @@ class AgentStreamExecutor: # Check for empty response and retry once if enabled if retry_on_empty and not full_content and not tool_calls: - logger.warning(f"⚠️ LLM returned empty response, retrying once...") + logger.warning(f"⚠️ LLM returned empty response (stop_reason: {stop_reason}), retrying once...") self._emit_event("message_end", { "content": "", "tool_calls": [], - "empty_retry": True + "empty_retry": True, + "stop_reason": stop_reason }) # Retry without retry flag to avoid infinite loop return self._call_llm_stream( @@ -560,16 +704,25 @@ class AgentStreamExecutor: return result # Check for consecutive failures (retry protection) - should_stop, stop_reason = self._check_consecutive_failures(tool_name, arguments) + should_stop, stop_reason, is_critical = self._check_consecutive_failures(tool_name, arguments) if should_stop: logger.error(f"🛑 {stop_reason}") self._record_tool_result(tool_name, arguments, False) - # 返回错误给 LLM,让它尝试其他方法 - result = { - "status": "error", - "result": f"{stop_reason}\n\nThis approach is not working. Please try a completely different method or ask the user for more information/clarification.", - "execution_time": 0 - } + + if is_critical: + # Critical failure - abort entire conversation + result = { + "status": "critical_error", + "result": stop_reason, + "execution_time": 0 + } + else: + # Normal failure - let LLM try different approach + result = { + "status": "error", + "result": f"{stop_reason}\n\n当前方法行不通,请尝试完全不同的方法或向用户询问更多信息。", + "execution_time": 0 + } return result self._emit_event("tool_execution_start", { @@ -656,52 +809,174 @@ class AgentStreamExecutor: logger.warning(f"⚠️ Removing incomplete tool_use message from history") self.messages.pop() + def _identify_complete_turns(self) -> List[Dict]: + """ + 识别完整的对话轮次 + + 一个完整轮次包括: + 1. 用户消息(text) + 2. AI 回复(可能包含 tool_use) + 3. 工具结果(tool_result,如果有) + 4. 后续 AI 回复(如果有) + + Returns: + List of turns, each turn is a dict with 'messages' list + """ + turns = [] + current_turn = {'messages': []} + + for msg in self.messages: + role = msg.get('role') + content = msg.get('content', []) + + if role == 'user': + # 检查是否是用户查询(不是工具结果) + is_user_query = False + if isinstance(content, list): + is_user_query = any( + block.get('type') == 'text' + for block in content + if isinstance(block, dict) + ) + elif isinstance(content, str): + is_user_query = True + + if is_user_query: + # 开始新轮次 + if current_turn['messages']: + turns.append(current_turn) + current_turn = {'messages': [msg]} + else: + # 工具结果,属于当前轮次 + current_turn['messages'].append(msg) + else: + # AI 回复,属于当前轮次 + current_turn['messages'].append(msg) + + # 添加最后一个轮次 + if current_turn['messages']: + turns.append(current_turn) + + return turns + + def _estimate_turn_tokens(self, turn: Dict) -> int: + """估算一个轮次的 tokens""" + return sum( + self.agent._estimate_message_tokens(msg) + for msg in turn['messages'] + ) + def _trim_messages(self): """ - Trim message history to stay within context limits. - Uses agent's context management configuration. + 智能清理消息历史,保持对话完整性 + + 使用完整轮次作为清理单位,确保: + 1. 不会在对话中间截断 + 2. 工具调用链(tool_use + tool_result)保持完整 + 3. 
每轮对话都是完整的(用户消息 + AI回复 + 工具调用) """ if not self.messages or not self.agent: return - # Get context window and reserve tokens from agent + # Step 1: 识别完整轮次 + turns = self._identify_complete_turns() + + if not turns: + return + + # Step 2: 轮次限制 - 保留最近 N 轮 + if len(turns) > self.max_context_turns: + removed_turns = len(turns) - self.max_context_turns + turns = turns[-self.max_context_turns:] # 保留最近的轮次 + + logger.info( + f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns}," + f"移除最早的 {removed_turns} 轮完整对话" + ) + + # Step 3: Token 限制 - 保留完整轮次 + # Get context window from agent (based on model) context_window = self.agent._get_model_context_window() - reserve_tokens = self.agent._get_context_reserve_tokens() - max_tokens = context_window - reserve_tokens - # Estimate current tokens - current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages) + # Use configured max_context_tokens if available + if hasattr(self.agent, 'max_context_tokens') and self.agent.max_context_tokens: + max_tokens = self.agent.max_context_tokens + else: + # Reserve 10% for response generation + reserve_tokens = int(context_window * 0.1) + max_tokens = context_window - reserve_tokens - # Add system prompt tokens + # Estimate system prompt tokens system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt}) - current_tokens += system_tokens + available_tokens = max_tokens - system_tokens - # If under limit, no need to trim - if current_tokens <= max_tokens: + # Calculate current tokens + current_tokens = sum(self._estimate_turn_tokens(turn) for turn in turns) + + # If under limit, reconstruct messages and return + if current_tokens + system_tokens <= max_tokens: + # Reconstruct message list from turns + new_messages = [] + for turn in turns: + new_messages.extend(turn['messages']) + + old_count = len(self.messages) + self.messages = new_messages + + # Log if we removed messages due to turn limit + if old_count > len(self.messages): + logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息") return - # Keep messages from newest, accumulating tokens - available_tokens = max_tokens - system_tokens - kept_messages = [] + # Token limit exceeded - keep complete turns from newest + logger.info( + f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens}," + f"将按完整轮次移除最早的对话" + ) + + # 从最新轮次开始,反向累加(保持完整轮次) + kept_turns = [] accumulated_tokens = 0 - - for msg in reversed(self.messages): - msg_tokens = self.agent._estimate_message_tokens(msg) - if accumulated_tokens + msg_tokens <= available_tokens: - kept_messages.insert(0, msg) - accumulated_tokens += msg_tokens + min_turns = 3 # 尽量保留至少 3 轮,但不强制(避免超出 token 限制) + + for i, turn in enumerate(reversed(turns)): + turn_tokens = self._estimate_turn_tokens(turn) + turns_from_end = i + 1 + + # 检查是否超出限制 + if accumulated_tokens + turn_tokens <= available_tokens: + kept_turns.insert(0, turn) + accumulated_tokens += turn_tokens else: + # 超出限制 + # 如果还没有保留足够的轮次,且这是最后的机会,尝试保留 + if len(kept_turns) < min_turns and turns_from_end <= min_turns: + # 检查是否严重超出(超出 20% 以上则放弃) + overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens + if overflow_ratio < 0.2: # 允许最多超出 20% + kept_turns.insert(0, turn) + accumulated_tokens += turn_tokens + logger.debug(f" 为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%") + continue + # 停止保留更早的轮次 break - + + # 重建消息列表 + new_messages = [] + for turn in kept_turns: + new_messages.extend(turn['messages']) + old_count = len(self.messages) - self.messages = 
kept_messages + old_turn_count = len(turns) + self.messages = new_messages new_count = len(self.messages) - + new_turn_count = len(kept_turns) + if old_count > new_count: logger.info( - f"Context trimmed: {old_count} -> {new_count} messages " - f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, " - f"limit: {max_tokens})" + f" 移除了 {old_turn_count - new_turn_count} 轮对话 " + f"({old_count} -> {new_count} 条消息," + f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)" ) def _prepare_messages(self) -> List[Dict[str, Any]]: diff --git a/agent/skills/config.py b/agent/skills/config.py index bbcd9f6..21e1737 100644 --- a/agent/skills/config.py +++ b/agent/skills/config.py @@ -70,33 +70,27 @@ def should_include_skill( entry: SkillEntry, config: Optional[Dict] = None, current_platform: Optional[str] = None, - lenient: bool = True, ) -> bool: """ Determine if a skill should be included based on requirements. - Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode: - - In lenient mode (default): Only check explicit disable and platform, ignore missing requirements - - In strict mode: Check all requirements (binary, env vars, config) + Simple rule: Skills are auto-enabled if their requirements are met. + - Has required API keys → enabled + - Missing API keys → disabled + - Wrong keys → enabled but will fail at runtime (LLM will handle error) :param entry: SkillEntry to check - :param config: Configuration dictionary + :param config: Configuration dictionary (currently unused, reserved for future) :param current_platform: Current platform (default: auto-detect) - :param lenient: If True, ignore missing requirements and load all skills (default: True) :return: True if skill should be included """ metadata = entry.metadata - skill_name = entry.skill.name - skill_config = get_skill_config(config, skill_name) - - # Always check if skill is explicitly disabled in config - if skill_config and skill_config.get('enabled') is False: - return False + # No metadata = always include (no requirements) if not metadata: return True - # Always check platform requirements (can't work on wrong platform) + # Check platform requirements (can't work on wrong platform) if metadata.os: platform_name = current_platform or resolve_runtime_platform() # Map common platform names @@ -114,12 +108,7 @@ def should_include_skill( if metadata.always: return True - # In lenient mode, skip requirement checks and load all skills - # Skills will fail gracefully at runtime if requirements are missing - if lenient: - return True - - # Strict mode: Check all requirements + # Check requirements if metadata.requires: # Check required binaries (all must be present) required_bins = metadata.requires.get('bins', []) @@ -133,29 +122,13 @@ def should_include_skill( if not has_any_binary(any_bins): return False - # Check environment variables (with config fallback) + # Check environment variables (API keys) + # Simple rule: All required env vars must be set required_env = metadata.requires.get('env', []) if required_env: for env_name in required_env: - # Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv) - if has_env_var(env_name): - continue - if skill_config: - # Check skill config env dict - skill_env = skill_config.get('env', {}) - if isinstance(skill_env, dict) and env_name in skill_env: - continue - # Check skill config apiKey (if this is the primaryEnv) - if metadata.primary_env == env_name and skill_config.get('apiKey'): - continue - # 
Requirement not satisfied - return False - - # Check config paths - required_config = metadata.requires.get('config', []) - if required_config and config: - for config_path in required_config: - if not is_config_path_truthy(config, config_path): + if not has_env_var(env_name): + # Missing required API key → disable skill return False return True diff --git a/agent/skills/formatter.py b/agent/skills/formatter.py index e77d7d6..7868e09 100644 --- a/agent/skills/formatter.py +++ b/agent/skills/formatter.py @@ -34,6 +34,7 @@ def format_skills_for_prompt(skills: List[Skill]) -> str: lines.append(f" {_escape_xml(skill.name)}") lines.append(f" {_escape_xml(skill.description)}") lines.append(f" {_escape_xml(skill.file_path)}") + lines.append(f" {_escape_xml(skill.base_dir)}") lines.append(" ") lines.append("") diff --git a/agent/skills/frontmatter.py b/agent/skills/frontmatter.py index 565c1f7..9905e29 100644 --- a/agent/skills/frontmatter.py +++ b/agent/skills/frontmatter.py @@ -23,7 +23,22 @@ def parse_frontmatter(content: str) -> Dict[str, Any]: frontmatter_text = match.group(1) - # Simple YAML-like parsing (supports key: value format) + # Try to use PyYAML for proper YAML parsing + try: + import yaml + frontmatter = yaml.safe_load(frontmatter_text) + if not isinstance(frontmatter, dict): + frontmatter = {} + return frontmatter + except ImportError: + # Fallback to simple parsing if PyYAML not available + pass + except Exception: + # If YAML parsing fails, fall back to simple parsing + pass + + # Simple YAML-like parsing (supports key: value format only) + # This is a fallback for when PyYAML is not available for line in frontmatter_text.split('\n'): line = line.strip() if not line or line.startswith('#'): @@ -72,10 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]: if not isinstance(metadata_raw, dict): return None - # Support both 'moltbot' and 'cow' keys for compatibility - meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow') - if not meta_obj or not isinstance(meta_obj, dict): - return None + # Use metadata_raw directly (COW format) + meta_obj = metadata_raw # Parse install specs install_specs = [] diff --git a/agent/skills/loader.py b/agent/skills/loader.py index 0bc8f4a..cc77d32 100644 --- a/agent/skills/loader.py +++ b/agent/skills/loader.py @@ -137,6 +137,10 @@ class SkillLoader: name = frontmatter.get('name', parent_dir_name) description = frontmatter.get('description', '') + # Special handling for linkai-agent: dynamically load apps from config.json + if name == 'linkai-agent': + description = self._load_linkai_agent_description(skill_dir, description) + if not description or not description.strip(): diagnostics.append(f"Skill {name} has no description: {file_path}") return LoadSkillsResult(skills=[], diagnostics=diagnostics) @@ -161,6 +165,45 @@ class SkillLoader: return LoadSkillsResult(skills=[skill], diagnostics=diagnostics) + def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str: + """ + Dynamically load LinkAI agent description from config.json + + :param skill_dir: Skill directory + :param default_description: Default description from SKILL.md + :return: Dynamic description with app list + """ + import json + + config_path = os.path.join(skill_dir, "config.json") + template_path = os.path.join(skill_dir, "config.json.template") + + # Try to load config.json or fallback to template + config_file = config_path if os.path.exists(config_path) else template_path + + if not os.path.exists(config_file): 
+ return default_description + + try: + with open(config_file, 'r', encoding='utf-8') as f: + config = json.load(f) + + apps = config.get("apps", []) + if not apps: + return default_description + + # Build dynamic description with app details + app_descriptions = "; ".join([ + f"{app['app_name']}({app['app_code']}: {app['app_description']})" + for app in apps + ]) + + return f"Call LinkAI apps/workflows. {app_descriptions}" + + except Exception as e: + logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}") + return default_description + def load_all_skills( self, managed_dir: Optional[str] = None, @@ -216,7 +259,7 @@ class SkillLoader: for diag in all_diagnostics[:5]: # Log first 5 logger.debug(f" - {diag}") - logger.info(f"Loaded {len(skill_map)} skills from all sources") + logger.debug(f"Loaded {len(skill_map)} skills from all sources") return skill_map diff --git a/agent/skills/manager.py b/agent/skills/manager.py index 580dc0f..bf9593f 100644 --- a/agent/skills/manager.py +++ b/agent/skills/manager.py @@ -59,7 +59,7 @@ class SkillManager: extra_dirs=self.extra_dirs, ) - logger.info(f"SkillManager: Loaded {len(self.skills)} skills") + logger.debug(f"SkillManager: Loaded {len(self.skills)} skills") def get_skill(self, name: str) -> Optional[SkillEntry]: """ @@ -82,32 +82,24 @@ class SkillManager: self, skill_filter: Optional[List[str]] = None, include_disabled: bool = False, - check_requirements: bool = False, # Changed default to False for lenient loading - lenient: bool = True, # New parameter for lenient mode ) -> List[SkillEntry]: """ Filter skills based on criteria. - By default (lenient=True), all skills are loaded regardless of missing requirements. - Skills will fail gracefully at runtime if requirements are not met. + Simple rule: Skills are auto-enabled if requirements are met. + - Has required API keys → included + - Missing API keys → excluded :param skill_filter: List of skill names to include (None = all) :param include_disabled: Whether to include skills with disable_model_invocation=True - :param check_requirements: Whether to check skill requirements (default: False) - :param lenient: If True, ignore missing requirements (default: True) :return: Filtered list of skill entries """ from agent.skills.config import should_include_skill entries = list(self.skills.values()) - # Check requirements (platform, explicit disable, etc.) 
- # In lenient mode, only checks platform and explicit disable - if check_requirements or not lenient: - entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)] - else: - # Lenient mode: only check explicit disable and platform - entries = [e for e in entries if should_include_skill(e, self.config, lenient=True)] + # Check requirements (platform, binaries, env vars) + entries = [e for e in entries if should_include_skill(e, self.config)] # Apply skill filter if skill_filter is not None: diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index 76f7e2e..3cba117 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -2,55 +2,57 @@ from agent.tools.base_tool import BaseTool from agent.tools.tool_manager import ToolManager -# Import basic tools (no external dependencies) -from agent.tools.calculator.calculator import Calculator - # Import file operation tools from agent.tools.read.read import Read from agent.tools.write.write import Write from agent.tools.edit.edit import Edit from agent.tools.bash.bash import Bash -from agent.tools.grep.grep import Grep -from agent.tools.find.find import Find from agent.tools.ls.ls import Ls +from agent.tools.send.send import Send # Import memory tools from agent.tools.memory.memory_search import MemorySearchTool from agent.tools.memory.memory_get import MemoryGetTool -# Import web tools -from agent.tools.web_fetch.web_fetch import WebFetch - # Import tools with optional dependencies def _import_optional_tools(): """Import tools that have optional dependencies""" + from common.log import logger tools = {} - # Google Search (requires requests) + # EnvConfig Tool (requires python-dotenv) try: - from agent.tools.google_search.google_search import GoogleSearch - tools['GoogleSearch'] = GoogleSearch - except ImportError: - pass + from agent.tools.env_config.env_config import EnvConfig + tools['EnvConfig'] = EnvConfig + except ImportError as e: + logger.error( + f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n" + f" To enable environment variable management, run:\n" + f" pip install python-dotenv>=1.0.0" + ) + except Exception as e: + logger.error(f"[Tools] EnvConfig tool failed to load: {e}") - # File Save (may have dependencies) + # Scheduler Tool (requires croniter) try: - from agent.tools.file_save.file_save import FileSave - tools['FileSave'] = FileSave - except ImportError: - pass + from agent.tools.scheduler.scheduler_tool import SchedulerTool + tools['SchedulerTool'] = SchedulerTool + except ImportError as e: + logger.error( + f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n" + f" To enable scheduled tasks, run:\n" + f" pip install croniter>=2.0.0" + ) + except Exception as e: + logger.error(f"[Tools] Scheduler tool failed to load: {e}") - # Terminal (basic, should work) - try: - from agent.tools.terminal.terminal import Terminal - tools['Terminal'] = Terminal - except ImportError: - pass return tools # Load optional tools _optional_tools = _import_optional_tools() +EnvConfig = _optional_tools.get('EnvConfig') +SchedulerTool = _optional_tools.get('SchedulerTool') GoogleSearch = _optional_tools.get('GoogleSearch') FileSave = _optional_tools.get('FileSave') Terminal = _optional_tools.get('Terminal') @@ -74,28 +76,24 @@ def _import_browser_tool(): # Dynamically set BrowserTool -BrowserTool = _import_browser_tool() +# BrowserTool = _import_browser_tool() # Export all tools (including optional ones that might be None) __all__ = [ 'BaseTool', 'ToolManager', - 'Calculator', 
'Read', 'Write', 'Edit', 'Bash', - 'Grep', - 'Find', 'Ls', + 'Send', 'MemorySearchTool', 'MemoryGetTool', - 'WebFetch', + 'EnvConfig', + 'SchedulerTool', # Optional tools (may be None if dependencies not available) - 'GoogleSearch', - 'FileSave', - 'Terminal', - 'BrowserTool' + # 'BrowserTool' ] """ diff --git a/agent/tools/bash/bash.py b/agent/tools/bash/bash.py index e9b6ca0..4d7e564 100644 --- a/agent/tools/bash/bash.py +++ b/agent/tools/bash/bash.py @@ -3,12 +3,14 @@ Bash tool - Execute bash commands """ import os +import sys import subprocess import tempfile from typing import Dict, Any from agent.tools.base_tool import BaseTool, ToolResult from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES +from common.log import logger class Bash(BaseTool): @@ -60,6 +62,12 @@ IMPORTANT SAFETY GUIDELINES: if not command: return ToolResult.fail("Error: command parameter is required") + # Security check: Prevent accessing sensitive config files + if "~/.cow/.env" in command or "~/.cow" in command: + return ToolResult.fail( + "Error: Access denied. API keys and credentials must be accessed through the env_config tool only." + ) + # Optional safety check - only warn about extremely dangerous commands if self.safety_mode: warning = self._get_safety_warning(command) @@ -68,7 +76,31 @@ IMPORTANT SAFETY GUIDELINES: f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.") try: - # Execute command + # Prepare environment with .env file variables + env = os.environ.copy() + + # Load environment variables from ~/.cow/.env if it exists + env_file = os.path.expanduser("~/.cow/.env") + if os.path.exists(env_file): + try: + from dotenv import dotenv_values + env_vars = dotenv_values(env_file) + env.update(env_vars) + logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}") + except ImportError: + logger.debug("[Bash] python-dotenv not installed, skipping .env loading") + except Exception as e: + logger.debug(f"[Bash] Failed to load .env: {e}") + + # Debug logging + logger.debug(f"[Bash] CWD: {self.cwd}") + logger.debug(f"[Bash] Command: {command[:500]}") + logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}") + logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}") + logger.debug(f"[Bash] Python executable: {sys.executable}") + logger.debug(f"[Bash] Process UID: {os.getuid()}") + + # Execute command with inherited environment variables result = subprocess.run( command, shell=True, @@ -76,8 +108,50 @@ IMPORTANT SAFETY GUIDELINES: stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - timeout=timeout + timeout=timeout, + env=env ) + + logger.debug(f"[Bash] Exit code: {result.returncode}") + logger.debug(f"[Bash] Stdout length: {len(result.stdout)}") + logger.debug(f"[Bash] Stderr length: {len(result.stderr)}") + + # Workaround for exit code 126 with no output + if result.returncode == 126 and not result.stdout and not result.stderr: + logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method") + # Try using argument list instead of shell=True + import shlex + try: + parts = shlex.split(command) + if len(parts) > 0: + logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...") + retry_result = subprocess.run( + parts, + cwd=self.cwd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + env=env + ) + logger.debug(f"[Bash] Retry exit 
code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}") + + # If retry succeeded, use retry result + if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr: + result = retry_result + else: + # Both attempts failed - check if this is openai-image-vision skill + if 'openai-image-vision' in command or 'vision.sh' in command: + # Create a mock result with helpful error message + from types import SimpleNamespace + result = SimpleNamespace( + returncode=1, + stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}', + stderr='' + ) + logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill") + except Exception as retry_err: + logger.warning(f"[Bash] Retry failed: {retry_err}") # Combine stdout and stderr output = result.stdout diff --git a/agent/tools/calculator/calculator.py b/agent/tools/calculator/calculator.py deleted file mode 100644 index 092343d..0000000 --- a/agent/tools/calculator/calculator.py +++ /dev/null @@ -1,58 +0,0 @@ -import math - -from agent.tools.base_tool import BaseTool, ToolResult - - -class Calculator(BaseTool): - name: str = "calculator" - description: str = "A tool to perform basic mathematical calculations." - params: dict = { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). " - "Ensure your input is a valid Python expression, it will be evaluated directly." - } - }, - "required": ["expression"] - } - config: dict = {} - - def execute(self, args: dict) -> ToolResult: - try: - # Get the expression - expression = args["expression"] - - # Create a safe local environment containing only basic math functions - safe_locals = { - "abs": abs, - "round": round, - "max": max, - "min": min, - "pow": pow, - "sqrt": math.sqrt, - "sin": math.sin, - "cos": math.cos, - "tan": math.tan, - "pi": math.pi, - "e": math.e, - "log": math.log, - "log10": math.log10, - "exp": math.exp, - "floor": math.floor, - "ceil": math.ceil - } - - # Safely evaluate the expression - result = eval(expression, {"__builtins__": {}}, safe_locals) - - return ToolResult.success({ - "result": result, - "expression": expression - }) - except Exception as e: - return ToolResult.success({ - "error": str(e), - "expression": args.get("expression", "") - }) diff --git a/agent/tools/edit/edit.py b/agent/tools/edit/edit.py index 566309b..e17e624 100644 --- a/agent/tools/edit/edit.py +++ b/agent/tools/edit/edit.py @@ -22,7 +22,7 @@ class Edit(BaseTool): """Tool for precise file editing""" name: str = "edit" - description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits." + description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)." params: dict = { "type": "object", @@ -33,7 +33,7 @@ class Edit(BaseTool): }, "oldText": { "type": "string", - "description": "Exact text to find and replace (must match exactly)" + "description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace." 
}, "newText": { "type": "string", @@ -89,34 +89,45 @@ class Edit(BaseTool): normalized_old_text = normalize_to_lf(old_text) normalized_new_text = normalize_to_lf(new_text) - # Use fuzzy matching to find old text (try exact match first, then fuzzy match) - match_result = fuzzy_find_text(normalized_content, normalized_old_text) - - if not match_result.found: - return ToolResult.fail( - f"Error: Could not find the exact text in {path}. " - "The old text must match exactly including all whitespace and newlines." + # Special case: empty oldText means append to end of file + if not old_text or not old_text.strip(): + # Append mode: add newText to the end + # Add newline before newText if file doesn't end with one + if normalized_content and not normalized_content.endswith('\n'): + new_content = normalized_content + '\n' + normalized_new_text + else: + new_content = normalized_content + normalized_new_text + base_content = normalized_content # For verification + else: + # Normal edit mode: find and replace + # Use fuzzy matching to find old text (try exact match first, then fuzzy match) + match_result = fuzzy_find_text(normalized_content, normalized_old_text) + + if not match_result.found: + return ToolResult.fail( + f"Error: Could not find the exact text in {path}. " + "The old text must match exactly including all whitespace and newlines." + ) + + # Calculate occurrence count (use fuzzy normalized content for consistency) + fuzzy_content = normalize_for_fuzzy_match(normalized_content) + fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text) + occurrences = fuzzy_content.count(fuzzy_old_text) + + if occurrences > 1: + return ToolResult.fail( + f"Error: Found {occurrences} occurrences of the text in {path}. " + "The text must be unique. Please provide more context to make it unique." + ) + + # Execute replacement (use matched text position) + base_content = match_result.content_for_replacement + new_content = ( + base_content[:match_result.index] + + normalized_new_text + + base_content[match_result.index + match_result.match_length:] ) - # Calculate occurrence count (use fuzzy normalized content for consistency) - fuzzy_content = normalize_for_fuzzy_match(normalized_content) - fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text) - occurrences = fuzzy_content.count(fuzzy_old_text) - - if occurrences > 1: - return ToolResult.fail( - f"Error: Found {occurrences} occurrences of the text in {path}. " - "The text must be unique. Please provide more context to make it unique." 
- ) - - # Execute replacement (use matched text position) - base_content = match_result.content_for_replacement - new_content = ( - base_content[:match_result.index] + - normalized_new_text + - base_content[match_result.index + match_result.match_length:] - ) - # Verify replacement actually changed content if base_content == new_content: return ToolResult.fail( diff --git a/agent/tools/env_config/__init__.py b/agent/tools/env_config/__init__.py new file mode 100644 index 0000000..2e5822f --- /dev/null +++ b/agent/tools/env_config/__init__.py @@ -0,0 +1,3 @@ +from agent.tools.env_config.env_config import EnvConfig + +__all__ = ['EnvConfig'] diff --git a/agent/tools/env_config/env_config.py b/agent/tools/env_config/env_config.py new file mode 100644 index 0000000..f0a10fe --- /dev/null +++ b/agent/tools/env_config/env_config.py @@ -0,0 +1,284 @@ +""" +Environment Configuration Tool - Manage API keys and environment variables +""" + +import os +import re +from typing import Dict, Any +from pathlib import Path + +from agent.tools.base_tool import BaseTool, ToolResult +from common.log import logger + + +# API Key 知识库:常见的环境变量及其描述 +API_KEY_REGISTRY = { + # AI 模型服务 + "OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)", + "GEMINI_API_KEY": "Google Gemini API 密钥", + "CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)", + "LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换", + # 搜索服务 + "BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ", +} + +class EnvConfig(BaseTool): + """Tool for managing environment variables (API keys, etc.)""" + + name: str = "env_config" + description: str = ( + "Manage API keys and skill configurations securely. " + "Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), " + "view configured keys, or manage skill settings. " + "Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). " + "Values are automatically masked for security. Changes take effect immediately via hot reload." + ) + + params: dict = { + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform: 'set', 'get', 'list', 'delete'", + "enum": ["set", "get", "list", "delete"] + }, + "key": { + "type": "string", + "description": ( + "Environment variable key name. 
Common keys:\n" + "- OPENAI_API_KEY: OpenAI API (GPT models)\n" + "- OPENAI_API_BASE: OpenAI API base URL\n" + "- CLAUDE_API_KEY: Anthropic Claude API\n" + "- GEMINI_API_KEY: Google Gemini API\n" + "- LINKAI_API_KEY: LinkAI platform\n" + "- BOCHA_API_KEY: Bocha AI search (博查搜索)\n" + "Use exact key names (case-sensitive, all uppercase with underscores)" + ) + }, + "value": { + "type": "string", + "description": "Value to set for the environment variable (for 'set' action)" + } + }, + "required": ["action"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + # Store env config in ~/.cow directory (outside workspace for security) + self.env_dir = os.path.expanduser("~/.cow") + self.env_path = os.path.join(self.env_dir, '.env') + self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload + # Don't create .env file in __init__ to avoid issues during tool discovery + # It will be created on first use in execute() + + def _ensure_env_file(self): + """Ensure the .env file exists""" + # Create ~/.cow directory if it doesn't exist + os.makedirs(self.env_dir, exist_ok=True) + + if not os.path.exists(self.env_path): + Path(self.env_path).touch() + logger.info(f"[EnvConfig] Created .env file at {self.env_path}") + + def _mask_value(self, value: str) -> str: + """Mask sensitive parts of a value for logging""" + if not value or len(value) <= 10: + return "***" + return f"{value[:6]}***{value[-4:]}" + + def _read_env_file(self) -> Dict[str, str]: + """Read all key-value pairs from .env file""" + env_vars = {} + if os.path.exists(self.env_path): + with open(self.env_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + # Skip empty lines and comments + if not line or line.startswith('#'): + continue + # Parse KEY=VALUE + match = re.match(r'^([^=]+)=(.*)$', line) + if match: + key, value = match.groups() + env_vars[key.strip()] = value.strip() + return env_vars + + def _write_env_file(self, env_vars: Dict[str, str]): + """Write all key-value pairs to .env file""" + with open(self.env_path, 'w', encoding='utf-8') as f: + f.write("# Environment variables for agent skills\n") + f.write("# Auto-managed by env_config tool\n\n") + for key, value in sorted(env_vars.items()): + f.write(f"{key}={value}\n") + + def _reload_env(self): + """Reload environment variables from .env file""" + env_vars = self._read_env_file() + for key, value in env_vars.items(): + os.environ[key] = value + logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables") + + def _refresh_skills(self): + """Refresh skills after environment variable changes""" + if self.agent_bridge: + try: + # Reload .env file + self._reload_env() + + # Refresh skills in all agent instances + refreshed = self.agent_bridge.refresh_all_skills() + logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)") + return True + except Exception as e: + logger.warning(f"[EnvConfig] Failed to refresh skills: {e}") + return False + return False + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute environment configuration operation + + :param args: Contains action, key, and value parameters + :return: Result of the operation + """ + # Ensure .env file exists on first use + self._ensure_env_file() + + action = args.get("action") + key = args.get("key") + value = args.get("value") + + try: + if action == "set": + if not key or not value: + return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.") + + # Read current 
env vars + env_vars = self._read_env_file() + + # Update the key + env_vars[key] = value + + # Write back to file + self._write_env_file(env_vars) + + # Update current process env + os.environ[key] = value + + logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}") + + # Try to refresh skills immediately + refreshed = self._refresh_skills() + + result = { + "message": f"Successfully set {key}", + "key": key, + "value": self._mask_value(value), + } + + if refreshed: + result["note"] = "✅ Skills refreshed automatically - changes are now active" + else: + result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills" + + return ToolResult.success(result) + + elif action == "get": + if not key: + return ToolResult.fail("Error: 'key' is required for 'get' action.") + + # Check in file first, then in current env + env_vars = self._read_env_file() + value = env_vars.get(key) or os.getenv(key) + + # Get description from registry + description = API_KEY_REGISTRY.get(key, "未知用途的环境变量") + + if value is not None: + logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}") + return ToolResult.success({ + "key": key, + "value": self._mask_value(value), + "description": description, + "exists": True + }) + else: + return ToolResult.success({ + "key": key, + "description": description, + "exists": False, + "message": f"Environment variable '{key}' is not set" + }) + + elif action == "list": + env_vars = self._read_env_file() + + # Build detailed variable list with descriptions + variables_with_info = {} + for key, value in env_vars.items(): + variables_with_info[key] = { + "value": self._mask_value(value), + "description": API_KEY_REGISTRY.get(key, "未知用途的环境变量") + } + + logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables") + + if not env_vars: + return ToolResult.success({ + "message": "No environment variables configured", + "variables": {}, + "note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置" + }) + + return ToolResult.success({ + "message": f"Found {len(env_vars)} environment variable(s)", + "variables": variables_with_info + }) + + elif action == "delete": + if not key: + return ToolResult.fail("Error: 'key' is required for 'delete' action.") + + # Read current env vars + env_vars = self._read_env_file() + + if key not in env_vars: + return ToolResult.success({ + "message": f"Environment variable '{key}' was not set", + "key": key + }) + + # Remove the key + del env_vars[key] + + # Write back to file + self._write_env_file(env_vars) + + # Remove from current process env + if key in os.environ: + del os.environ[key] + + logger.info(f"[EnvConfig] Deleted {key}") + + # Try to refresh skills immediately + refreshed = self._refresh_skills() + + result = { + "message": f"Successfully deleted {key}", + "key": key, + } + + if refreshed: + result["note"] = "✅ Skills refreshed automatically - changes are now active" + else: + result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes" + + return ToolResult.success(result) + + else: + return ToolResult.fail(f"Error: Unknown action '{action}'. 
Use 'set', 'get', 'list', or 'delete'.") + + except Exception as e: + logger.error(f"[EnvConfig] Error: {e}", exc_info=True) + return ToolResult.fail(f"EnvConfig tool error: {str(e)}") diff --git a/agent/tools/find/__init__.py b/agent/tools/find/__init__.py deleted file mode 100644 index f2af14f..0000000 --- a/agent/tools/find/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .find import Find - -__all__ = ['Find'] diff --git a/agent/tools/find/find.py b/agent/tools/find/find.py deleted file mode 100644 index 7a2c4a1..0000000 --- a/agent/tools/find/find.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -Find tool - Search for files by glob pattern -""" - -import os -import glob as glob_module -from typing import Dict, Any, List - -from agent.tools.base_tool import BaseTool, ToolResult -from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES - - -DEFAULT_LIMIT = 1000 - - -class Find(BaseTool): - """Tool for finding files by pattern""" - - name: str = "find" - description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)." - - params: dict = { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'" - }, - "path": { - "type": "string", - "description": "Directory to search in (default: current directory)" - }, - "limit": { - "type": "integer", - "description": f"Maximum number of results (default: {DEFAULT_LIMIT})" - } - }, - "required": ["pattern"] - } - - def __init__(self, config: dict = None): - self.config = config or {} - self.cwd = self.config.get("cwd", os.getcwd()) - - def execute(self, args: Dict[str, Any]) -> ToolResult: - """ - Execute file search - - :param args: Search parameters - :return: Search results or error - """ - pattern = args.get("pattern", "").strip() - search_path = args.get("path", ".").strip() - limit = args.get("limit", DEFAULT_LIMIT) - - if not pattern: - return ToolResult.fail("Error: pattern parameter is required") - - # Resolve search path - absolute_path = self._resolve_path(search_path) - - if not os.path.exists(absolute_path): - return ToolResult.fail(f"Error: Path not found: {search_path}") - - if not os.path.isdir(absolute_path): - return ToolResult.fail(f"Error: Not a directory: {search_path}") - - try: - # Load .gitignore patterns - ignore_patterns = self._load_gitignore(absolute_path) - - # Search for files - results = [] - search_pattern = os.path.join(absolute_path, pattern) - - # Use glob with recursive support - for file_path in glob_module.glob(search_pattern, recursive=True): - # Skip if matches ignore patterns - if self._should_ignore(file_path, absolute_path, ignore_patterns): - continue - - # Get relative path - relative_path = os.path.relpath(file_path, absolute_path) - - # Add trailing slash for directories - if os.path.isdir(file_path): - relative_path += '/' - - results.append(relative_path) - - if len(results) >= limit: - break - - if not results: - return ToolResult.success({"message": "No files found matching pattern", "files": []}) - - # Sort results - results.sort() - - # Format output - raw_output = '\n'.join(results) - truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes - - output = truncation.content - details = {} - notices = [] - - result_limit_reached = len(results) >= limit - if 
result_limit_reached: - notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern") - details["result_limit_reached"] = limit - - if truncation.truncated: - notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached") - details["truncation"] = truncation.to_dict() - - if notices: - output += f"\n\n[{'. '.join(notices)}]" - - return ToolResult.success({ - "output": output, - "file_count": len(results), - "details": details if details else None - }) - - except Exception as e: - return ToolResult.fail(f"Error executing find: {str(e)}") - - def _resolve_path(self, path: str) -> str: - """Resolve path to absolute path""" - # Expand ~ to user home directory - path = os.path.expanduser(path) - if os.path.isabs(path): - return path - return os.path.abspath(os.path.join(self.cwd, path)) - - def _load_gitignore(self, directory: str) -> List[str]: - """Load .gitignore patterns from directory""" - patterns = [] - gitignore_path = os.path.join(directory, '.gitignore') - - if os.path.exists(gitignore_path): - try: - with open(gitignore_path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if line and not line.startswith('#'): - patterns.append(line) - except: - pass - - # Add common ignore patterns - patterns.extend([ - '.git', - '__pycache__', - '*.pyc', - 'node_modules', - '.DS_Store' - ]) - - return patterns - - def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool: - """Check if file should be ignored based on patterns""" - relative_path = os.path.relpath(file_path, base_path) - - for pattern in patterns: - # Simple pattern matching - if pattern in relative_path: - return True - - # Check if it's a directory pattern - if pattern.endswith('/'): - if relative_path.startswith(pattern.rstrip('/')): - return True - - return False diff --git a/agent/tools/grep/__init__.py b/agent/tools/grep/__init__.py deleted file mode 100644 index e4d57b0..0000000 --- a/agent/tools/grep/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .grep import Grep - -__all__ = ['Grep'] diff --git a/agent/tools/grep/grep.py b/agent/tools/grep/grep.py deleted file mode 100644 index 1e7d95e..0000000 --- a/agent/tools/grep/grep.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Grep tool - Search file contents for patterns -Uses ripgrep (rg) for fast searching -""" - -import os -import re -import subprocess -import json -from typing import Dict, Any, List, Optional - -from agent.tools.base_tool import BaseTool, ToolResult -from agent.tools.utils.truncate import ( - truncate_head, truncate_line, format_size, - DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH -) - - -DEFAULT_LIMIT = 100 - - -class Grep(BaseTool): - """Tool for searching file contents""" - - name: str = "grep" - description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars." - - params: dict = { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Search pattern (regex or literal string)" - }, - "path": { - "type": "string", - "description": "Directory or file to search (default: current directory)" - }, - "glob": { - "type": "string", - "description": "Filter files by glob pattern, e.g. 
'*.ts' or '**/*.spec.ts'" - }, - "ignoreCase": { - "type": "boolean", - "description": "Case-insensitive search (default: false)" - }, - "literal": { - "type": "boolean", - "description": "Treat pattern as literal string instead of regex (default: false)" - }, - "context": { - "type": "integer", - "description": "Number of lines to show before and after each match (default: 0)" - }, - "limit": { - "type": "integer", - "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})" - } - }, - "required": ["pattern"] - } - - def __init__(self, config: dict = None): - self.config = config or {} - self.cwd = self.config.get("cwd", os.getcwd()) - self.rg_path = self._find_ripgrep() - - def _find_ripgrep(self) -> Optional[str]: - """Find ripgrep executable""" - try: - result = subprocess.run(['which', 'rg'], capture_output=True, text=True) - if result.returncode == 0: - return result.stdout.strip() - except: - pass - return None - - def execute(self, args: Dict[str, Any]) -> ToolResult: - """ - Execute grep search - - :param args: Search parameters - :return: Search results or error - """ - if not self.rg_path: - return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.") - - pattern = args.get("pattern", "").strip() - search_path = args.get("path", ".").strip() - glob = args.get("glob") - ignore_case = args.get("ignoreCase", False) - literal = args.get("literal", False) - context = args.get("context", 0) - limit = args.get("limit", DEFAULT_LIMIT) - - if not pattern: - return ToolResult.fail("Error: pattern parameter is required") - - # Resolve search path - absolute_path = self._resolve_path(search_path) - - if not os.path.exists(absolute_path): - return ToolResult.fail(f"Error: Path not found: {search_path}") - - # Build ripgrep command - cmd = [ - self.rg_path, - '--json', - '--line-number', - '--color=never', - '--hidden' - ] - - if ignore_case: - cmd.append('--ignore-case') - - if literal: - cmd.append('--fixed-strings') - - if glob: - cmd.extend(['--glob', glob]) - - cmd.extend([pattern, absolute_path]) - - try: - # Execute ripgrep - result = subprocess.run( - cmd, - cwd=self.cwd, - capture_output=True, - text=True, - timeout=30 - ) - - # Parse JSON output - matches = [] - match_count = 0 - - for line in result.stdout.splitlines(): - if not line.strip(): - continue - - try: - event = json.loads(line) - if event.get('type') == 'match': - data = event.get('data', {}) - file_path = data.get('path', {}).get('text') - line_number = data.get('line_number') - - if file_path and line_number: - matches.append({ - 'file': file_path, - 'line': line_number - }) - match_count += 1 - - if match_count >= limit: - break - except json.JSONDecodeError: - continue - - if match_count == 0: - return ToolResult.success({"message": "No matches found", "matches": []}) - - # Format output with context - output_lines = [] - lines_truncated = False - is_directory = os.path.isdir(absolute_path) - - for match in matches: - file_path = match['file'] - line_number = match['line'] - - # Format file path - if is_directory: - relative_path = os.path.relpath(file_path, absolute_path) - else: - relative_path = os.path.basename(file_path) - - # Read file and get context - try: - with open(file_path, 'r', encoding='utf-8') as f: - file_lines = f.read().split('\n') - - # Calculate context range - start = max(0, line_number - 1 - context) if context > 0 else line_number - 1 - end = min(len(file_lines), line_number + context) if context > 0 else line_number - - # Format lines with 
context - for i in range(start, end): - line_text = file_lines[i].replace('\r', '') - - # Truncate long lines - truncated_text, was_truncated = truncate_line(line_text) - if was_truncated: - lines_truncated = True - - # Format output - current_line = i + 1 - if current_line == line_number: - output_lines.append(f"{relative_path}:{current_line}: {truncated_text}") - else: - output_lines.append(f"{relative_path}-{current_line}- {truncated_text}") - - except Exception: - output_lines.append(f"{relative_path}:{line_number}: (unable to read file)") - - # Apply byte truncation - raw_output = '\n'.join(output_lines) - truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes - - output = truncation.content - details = {} - notices = [] - - if match_count >= limit: - notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern") - details["match_limit_reached"] = limit - - if truncation.truncated: - notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached") - details["truncation"] = truncation.to_dict() - - if lines_truncated: - notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines") - details["lines_truncated"] = True - - if notices: - output += f"\n\n[{'. '.join(notices)}]" - - return ToolResult.success({ - "output": output, - "match_count": match_count, - "details": details if details else None - }) - - except subprocess.TimeoutExpired: - return ToolResult.fail("Error: Search timed out after 30 seconds") - except Exception as e: - return ToolResult.fail(f"Error executing grep: {str(e)}") - - def _resolve_path(self, path: str) -> str: - """Resolve path to absolute path""" - # Expand ~ to user home directory - path = os.path.expanduser(path) - if os.path.isabs(path): - return path - return os.path.abspath(os.path.join(self.cwd, path)) diff --git a/agent/tools/ls/ls.py b/agent/tools/ls/ls.py index d3e5330..d6517b3 100644 --- a/agent/tools/ls/ls.py +++ b/agent/tools/ls/ls.py @@ -50,6 +50,13 @@ class Ls(BaseTool): # Resolve path absolute_path = self._resolve_path(path) + # Security check: Prevent accessing sensitive config directory + env_config_dir = os.path.expanduser("~/.cow") + if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir): + return ToolResult.fail( + "Error: Access denied. API keys and credentials must be accessed through the env_config tool only." + ) + if not os.path.exists(absolute_path): # Provide helpful hint if using relative path if not os.path.isabs(path) and not path.startswith('~'): diff --git a/agent/tools/memory/memory_get.py b/agent/tools/memory/memory_get.py index d828386..5febb10 100644 --- a/agent/tools/memory/memory_get.py +++ b/agent/tools/memory/memory_get.py @@ -4,8 +4,6 @@ Memory get tool Allows agents to read specific sections from memory files """ -from typing import Dict, Any -from pathlib import Path from agent.tools.base_tool import BaseTool @@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool): "properties": { "path": { "type": "string", - "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')" + "description": "Relative path to the memory file (e.g. 
'MEMORY.md', 'memory/2026-01-01.md')" }, "start_line": { "type": "integer", @@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool): workspace_dir = self.memory_manager.config.get_workspace() # Auto-prepend memory/ if not present and not absolute path - if not path.startswith('memory/') and not path.startswith('/'): + # Exception: MEMORY.md is in the root directory + if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md': path = f'memory/{path}' file_path = workspace_dir / path diff --git a/agent/tools/read/read.py b/agent/tools/read/read.py index 4810890..f88bc50 100644 --- a/agent/tools/read/read.py +++ b/agent/tools/read/read.py @@ -15,7 +15,7 @@ class Read(BaseTool): """Tool for reading file contents""" name: str = "read" - description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files." + description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files." params: dict = { "type": "object", @@ -26,7 +26,7 @@ class Read(BaseTool): }, "offset": { "type": "integer", - "description": "Line number to start reading from (1-indexed, optional)" + "description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)" }, "limit": { "type": "integer", @@ -39,10 +39,25 @@ class Read(BaseTool): def __init__(self, config: dict = None): self.config = config or {} self.cwd = self.config.get("cwd", os.getcwd()) - # Supported image formats - self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'} - # Supported PDF format + + # File type categories + self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'} + self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'} + self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'} + self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'} + self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'} self.pdf_extensions = {'.pdf'} + + # Readable text formats (will be read with truncation) + self.text_extensions = { + '.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml', + '.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php', + '.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx', + '.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd', + '.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex', + '.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg', + '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents + } def execute(self, args: Dict[str, Any]) -> ToolResult: """ @@ -61,6 +76,13 @@ class Read(BaseTool): # Resolve path absolute_path = self._resolve_path(path) + # Security check: Prevent reading sensitive config files + env_config_path = os.path.expanduser("~/.cow/.env") + if os.path.abspath(absolute_path) == os.path.abspath(env_config_path): + return ToolResult.fail( + "Error: Access denied. 
API keys and credentials must be accessed through the env_config tool only." + ) + # Check if file exists if not os.path.exists(absolute_path): # Provide helpful hint if using relative path @@ -78,16 +100,25 @@ class Read(BaseTool): # Check file type file_ext = Path(absolute_path).suffix.lower() + file_size = os.path.getsize(absolute_path) - # Check if image + # Check if image - return metadata for sending if file_ext in self.image_extensions: return self._read_image(absolute_path, file_ext) + # Check if video/audio/binary/archive - return metadata only + if file_ext in self.video_extensions: + return self._return_file_metadata(absolute_path, "video", file_size) + if file_ext in self.audio_extensions: + return self._return_file_metadata(absolute_path, "audio", file_size) + if file_ext in self.binary_extensions or file_ext in self.archive_extensions: + return self._return_file_metadata(absolute_path, "binary", file_size) + # Check if PDF if file_ext in self.pdf_extensions: return self._read_pdf(absolute_path, path, offset, limit) - # Read text file + # Read text file (with truncation for large files) return self._read_text(absolute_path, path, offset, limit) def _resolve_path(self, path: str) -> str: @@ -103,25 +134,56 @@ class Read(BaseTool): return path return os.path.abspath(os.path.join(self.cwd, path)) + def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult: + """ + Return file metadata for non-readable files (video, audio, binary, etc.) + + :param absolute_path: Absolute path to the file + :param file_type: Type of file (video, audio, binary, etc.) + :param file_size: File size in bytes + :return: File metadata + """ + file_name = Path(absolute_path).name + file_ext = Path(absolute_path).suffix.lower() + + # Determine MIME type + mime_types = { + # Video + '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime', + '.mkv': 'video/x-matroska', '.webm': 'video/webm', + # Audio + '.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg', + '.m4a': 'audio/mp4', '.flac': 'audio/flac', + # Binary + '.zip': 'application/zip', '.tar': 'application/x-tar', + '.gz': 'application/gzip', '.rar': 'application/x-rar-compressed', + } + mime_type = mime_types.get(file_ext, 'application/octet-stream') + + result = { + "type": f"{file_type}_metadata", + "file_type": file_type, + "path": absolute_path, + "file_name": file_name, + "mime_type": mime_type, + "size": file_size, + "size_formatted": format_size(file_size), + "message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。" + } + + return ToolResult.success(result) + def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult: """ - Read image file + Read image file - always return metadata only (images should be sent, not read into context) :param absolute_path: Absolute path to the image file :param file_ext: File extension - :return: Result containing image information + :return: Result containing image metadata for sending """ try: - # Read image file - with open(absolute_path, 'rb') as f: - image_data = f.read() - # Get file size - file_size = len(image_data) - - # Return image information (actual image data can be base64 encoded when needed) - import base64 - base64_data = base64.b64encode(image_data).decode('utf-8') + file_size = os.path.getsize(absolute_path) # Determine MIME type mime_type_map = { @@ -133,12 +195,15 @@ class Read(BaseTool): } mime_type = mime_type_map.get(file_ext, 'image/jpeg') + # Return metadata 
for images (NOT file_to_send - use send tool to actually send) result = { - "type": "image", + "type": "image_metadata", + "file_type": "image", + "path": absolute_path, "mime_type": mime_type, "size": file_size, "size_formatted": format_size(file_size), - "data": base64_data # Base64 encoded image data + "message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。" } return ToolResult.success(result) @@ -157,21 +222,49 @@ class Read(BaseTool): :return: File content or error message """ try: + # Check file size first + file_size = os.path.getsize(absolute_path) + MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB + + if file_size > MAX_FILE_SIZE: + # File too large, return metadata only + return ToolResult.success({ + "type": "file_to_send", + "file_type": "document", + "path": absolute_path, + "size": file_size, + "size_formatted": format_size(file_size), + "message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}" + }) + # Read file with open(absolute_path, 'r', encoding='utf-8') as f: content = f.read() + # Truncate content if too long (20K characters max for model context) + MAX_CONTENT_CHARS = 20 * 1024 # 20K characters + content_truncated = False + if len(content) > MAX_CONTENT_CHARS: + content = content[:MAX_CONTENT_CHARS] + content_truncated = True + all_lines = content.split('\n') total_file_lines = len(all_lines) # Apply offset (if specified) start_line = 0 if offset is not None: - start_line = max(0, offset - 1) # Convert to 0-indexed - if start_line >= total_file_lines: - return ToolResult.fail( - f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)" - ) + if offset < 0: + # Negative offset: read from end + # -20 means "last 20 lines" → start from (total - 20) + start_line = max(0, total_file_lines + offset) + else: + # Positive offset: read from start (1-indexed) + start_line = max(0, offset - 1) # Convert to 0-indexed + if start_line >= total_file_lines: + return ToolResult.fail( + f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)" + ) start_line_display = start_line + 1 # For display (1-indexed) @@ -191,6 +284,10 @@ class Read(BaseTool): output_text = "" details = {} + # Add truncation warning if content was truncated + if content_truncated: + output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n" + if truncation.first_line_exceeds_limit: # First line exceeds 30KB limit first_line_size = format_size(len(all_lines[start_line].encode('utf-8'))) diff --git a/agent/tools/scheduler/README.md b/agent/tools/scheduler/README.md new file mode 100644 index 0000000..55be2f9 --- /dev/null +++ b/agent/tools/scheduler/README.md @@ -0,0 +1,287 @@ +# 定时任务工具 (Scheduler Tool) + +## 功能简介 + +定时任务工具允许 Agent 创建、管理和执行定时任务,支持: + +- ⏰ **定时提醒**: 在指定时间发送消息 +- 🔄 **周期性任务**: 按固定间隔或 cron 表达式重复执行 +- 🔧 **动态工具调用**: 定时执行其他工具并发送结果(如搜索新闻、查询天气等) +- 📋 **任务管理**: 查询、启用、禁用、删除任务 + +## 安装依赖 + +```bash +pip install croniter>=2.0.0 +``` + +## 使用方法 + +### 1. 创建定时任务 + +Agent 可以通过自然语言创建定时任务,支持两种类型: + +#### 1.1 静态消息任务 + +发送预定义的消息: + +**示例对话:** +``` +用户: 每天早上9点提醒我开会 +Agent: [调用 scheduler 工具] + action: create + name: 每日开会提醒 + message: 该开会了! 
+    schedule_type: cron
+    schedule_value: 0 9 * * *
+```
+
+#### 1.2 动态工具调用任务
+
+定时执行工具并发送结果:
+
+**示例对话:**
+```
+用户: 每天早上8点帮我读取一下今日日程
+Agent: [调用 scheduler 工具]
+    action: create
+    name: 每日日程
+    tool_call:
+      tool_name: read
+      tool_params:
+        file_path: ~/cow/schedule.txt
+      result_prefix: 📅 今日日程
+    schedule_type: cron
+    schedule_value: 0 8 * * *
+```
+
+**工具调用参数说明:**
+- `tool_name`: 要调用的工具名称(如 `bash`、`read`、`write` 等内置工具)
+- `tool_params`: 工具的参数(字典格式)
+- `result_prefix`: 可选,在结果前添加的前缀文本
+
+**注意:** 如果要使用 skills(如 bocha-search),需要通过 `bash` 工具调用 skill 脚本
+
+### 2. 支持的调度类型
+
+#### Cron 表达式 (`cron`)
+使用标准 cron 表达式:
+
+```
+0 9 * * *      # 每天 9:00
+0 */2 * * *    # 每 2 小时
+30 8 * * 1-5   # 工作日 8:30
+0 0 1 * *      # 每月 1 号
+```
+
+#### 固定间隔 (`interval`)
+以秒为单位的间隔:
+
+```
+3600    # 每小时
+86400   # 每天
+1800    # 每 30 分钟
+```
+
+#### 一次性任务 (`once`)
+指定具体时间(ISO 格式):
+
+```
+2024-12-25T09:00:00
+2024-12-31T23:59:59
+```
+
+### 3. 查询任务列表
+
+```
+用户: 查看我的定时任务
+Agent: [调用 scheduler 工具]
+    action: list
+```
+
+### 4. 查看任务详情
+
+```
+用户: 查看任务 abc123 的详情
+Agent: [调用 scheduler 工具]
+    action: get
+    task_id: abc123
+```
+
+### 5. 删除任务
+
+```
+用户: 删除任务 abc123
+Agent: [调用 scheduler 工具]
+    action: delete
+    task_id: abc123
+```
+
+### 6. 启用/禁用任务
+
+```
+用户: 暂停任务 abc123
+Agent: [调用 scheduler 工具]
+    action: disable
+    task_id: abc123
+
+用户: 恢复任务 abc123
+Agent: [调用 scheduler 工具]
+    action: enable
+    task_id: abc123
+```
+
+## 任务存储
+
+任务保存在 JSON 文件中:
+```
+~/cow/scheduler/tasks.json
+```
+
+任务数据结构:
+
+**静态消息任务:**
+```json
+{
+  "id": "abc123",
+  "name": "每日提醒",
+  "enabled": true,
+  "created_at": "2024-01-01T10:00:00",
+  "updated_at": "2024-01-01T10:00:00",
+  "schedule": {
+    "type": "cron",
+    "expression": "0 9 * * *"
+  },
+  "action": {
+    "type": "send_message",
+    "content": "该开会了!",
+    "receiver": "wxid_xxx",
+    "receiver_name": "张三",
+    "is_group": false,
+    "channel_type": "wechat"
+  },
+  "next_run_at": "2024-01-02T09:00:00",
+  "last_run_at": "2024-01-01T09:00:00"
+}
+```
+
+**动态工具调用任务:**
+```json
+{
+  "id": "def456",
+  "name": "每日日程",
+  "enabled": true,
+  "created_at": "2024-01-01T10:00:00",
+  "updated_at": "2024-01-01T10:00:00",
+  "schedule": {
+    "type": "cron",
+    "expression": "0 8 * * *"
+  },
+  "action": {
+    "type": "tool_call",
+    "tool_name": "read",
+    "tool_params": {
+      "file_path": "~/cow/schedule.txt"
+    },
+    "result_prefix": "📅 今日日程",
+    "receiver": "wxid_xxx",
+    "receiver_name": "张三",
+    "is_group": false,
+    "channel_type": "wechat"
+  },
+  "next_run_at": "2024-01-02T08:00:00"
+}
+```
+
+## 后台服务
+
+定时任务由后台服务 `SchedulerService` 管理:
+
+- 每 30 秒检查一次到期任务
+- 自动执行到期任务
+- 计算下次执行时间
+- 记录执行历史和错误
+
+服务在 Agent 初始化时自动启动,无需手动配置。
+
+## 接收者确定
+
+定时任务会发送给**创建任务时的对话对象**:
+
+- 如果在私聊中创建,发送给该用户
+- 如果在群聊中创建,发送到该群
+- 接收者信息在创建时自动保存
+
+## 常见用例
+
+### 1. 每日提醒(静态消息)
+```
+用户: 每天早上8点提醒我吃药
+Agent: ✅ 定时任务创建成功
+    任务ID: a1b2c3d4
+    调度: 每天 8:00
+    消息: 该吃药了!
+```
+
+### 2. 工作日提醒(静态消息)
+```
+用户: 工作日下午6点提醒我下班
+Agent: [创建 cron: 0 18 * * 1-5]
+    消息: 该下班了!
+```
+
+### 3. 倒计时提醒(静态消息)
+```
+用户: 1小时后提醒我
+Agent: [创建 interval: 3600]
+```
+
+### 4. 每日日程推送(动态工具调用)
+```
+用户: 每天早上8点帮我读取今日日程
+Agent: ✅ 定时任务创建成功
+    任务ID: schedule001
+    调度: 每天 8:00
+    工具: read(file_path='~/cow/schedule.txt')
+    前缀: 📅 今日日程
+```
+
+### 5. 定时文件备份(动态工具调用)
+```
+用户: 每天晚上11点备份工作文件
+Agent: [创建 cron: 0 23 * * *]
+    工具: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
+    前缀: ✅ 文件已备份
+```
+
+### 6. 周报提醒(静态消息)
+```
+用户: 每周五下午5点提醒我写周报
+Agent: [创建 cron: 0 17 * * 5]
+    消息: 📊 该写周报了!
+```
+
+### 7. 特定日期提醒
+```
+用户: 12月25日早上9点提醒我圣诞快乐
+Agent: [创建 once: 2024-12-25T09:00:00]
+```
+
+## 注意事项
+
+1. **时区**: 使用系统本地时区
+2. 
**精度**: 检查间隔为 30 秒,实际执行可能有 ±30 秒误差 +3. **持久化**: 任务保存在文件中,重启后自动恢复 +4. **一次性任务**: 执行后自动禁用,不会删除(可手动删除) +5. **错误处理**: 执行失败会记录错误,不影响其他任务 + +## 技术实现 + +- **TaskStore**: 任务持久化存储 +- **SchedulerService**: 后台调度服务 +- **SchedulerTool**: Agent 工具接口 +- **Integration**: 与 AgentBridge 集成 + +## 依赖 + +- `croniter`: Cron 表达式解析(轻量级,仅 ~50KB) diff --git a/agent/tools/scheduler/__init__.py b/agent/tools/scheduler/__init__.py new file mode 100644 index 0000000..dafc8b6 --- /dev/null +++ b/agent/tools/scheduler/__init__.py @@ -0,0 +1,7 @@ +""" +Scheduler tool for managing scheduled tasks +""" + +from .scheduler_tool import SchedulerTool + +__all__ = ["SchedulerTool"] diff --git a/agent/tools/scheduler/integration.py b/agent/tools/scheduler/integration.py new file mode 100644 index 0000000..1882195 --- /dev/null +++ b/agent/tools/scheduler/integration.py @@ -0,0 +1,447 @@ +""" +Integration module for scheduler with AgentBridge +""" + +import os +from typing import Optional +from config import conf +from common.log import logger +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType + +# Global scheduler service instance +_scheduler_service = None +_task_store = None + + +def init_scheduler(agent_bridge) -> bool: + """ + Initialize scheduler service + + Args: + agent_bridge: AgentBridge instance + + Returns: + True if initialized successfully + """ + global _scheduler_service, _task_store + + try: + from agent.tools.scheduler.task_store import TaskStore + from agent.tools.scheduler.scheduler_service import SchedulerService + + # Get workspace from config + workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow")) + store_path = os.path.join(workspace_root, "scheduler", "tasks.json") + + # Create task store + _task_store = TaskStore(store_path) + logger.debug(f"[Scheduler] Task store initialized: {store_path}") + + # Create execute callback + def execute_task_callback(task: dict): + """Callback to execute a scheduled task""" + try: + action = task.get("action", {}) + action_type = action.get("type") + + if action_type == "agent_task": + _execute_agent_task(task, agent_bridge) + elif action_type == "send_message": + # Legacy support for old tasks + _execute_send_message(task, agent_bridge) + elif action_type == "tool_call": + # Legacy support for old tasks + _execute_tool_call(task, agent_bridge) + elif action_type == "skill_call": + # Legacy support for old tasks + _execute_skill_call(task, agent_bridge) + else: + logger.warning(f"[Scheduler] Unknown action type: {action_type}") + except Exception as e: + logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}") + + # Create scheduler service + _scheduler_service = SchedulerService(_task_store, execute_task_callback) + _scheduler_service.start() + + logger.debug("[Scheduler] Scheduler service initialized and started") + return True + + except Exception as e: + logger.error(f"[Scheduler] Failed to initialize scheduler: {e}") + return False + + +def get_task_store(): + """Get the global task store instance""" + return _task_store + + +def get_scheduler_service(): + """Get the global scheduler service instance""" + return _scheduler_service + + +def _execute_agent_task(task: dict, agent_bridge): + """ + Execute an agent_task action - let Agent handle the task + + Args: + task: Task dictionary + agent_bridge: AgentBridge instance + """ + try: + action = task.get("action", {}) + task_description = action.get("task_description") + receiver = action.get("receiver") + is_group = action.get("is_group", 
False) + channel_type = action.get("channel_type", "unknown") + + if not task_description: + logger.error(f"[Scheduler] Task {task['id']}: No task_description specified") + return + + if not receiver: + logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") + return + + # Check for unsupported channels + if channel_type == "dingtalk": + logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.") + + logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'") + + # Create context for Agent + context = Context(ContextType.TEXT, task_description) + context["receiver"] = receiver + context["isgroup"] = is_group + context["session_id"] = receiver + + # Channel-specific setup + if channel_type == "web": + import uuid + request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}" + context["request_id"] = request_id + elif channel_type == "feishu": + context["receive_id_type"] = "chat_id" if is_group else "open_id" + context["msg"] = None + elif channel_type == "dingtalk": + # DingTalk requires msg object, set to None for scheduled tasks + context["msg"] = None + # 如果是单聊,需要传递 sender_staff_id + if not is_group: + sender_staff_id = action.get("dingtalk_sender_staff_id") + if sender_staff_id: + context["dingtalk_sender_staff_id"] = sender_staff_id + + # Use Agent to execute the task + # Mark this as a scheduled task execution to prevent recursive task creation + context["is_scheduled_task"] = True + + try: + reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True) + + if reply and reply.content: + # Send the reply via channel + from channel.channel_factory import create_channel + + try: + channel = create_channel(channel_type) + if channel: + # For web channel, register request_id + if channel_type == "web" and hasattr(channel, 'request_to_session'): + request_id = context.get("request_id") + if request_id: + channel.request_to_session[request_id] = receiver + logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}") + + # Send the reply + channel.send(reply, context) + logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}") + else: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + except Exception as e: + logger.error(f"[Scheduler] Failed to send result: {e}") + else: + logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution") + + except Exception as e: + logger.error(f"[Scheduler] Failed to execute task via Agent: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + except Exception as e: + logger.error(f"[Scheduler] Error in _execute_agent_task: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + +def _execute_send_message(task: dict, agent_bridge): + """ + Execute a send_message action + + Args: + task: Task dictionary + agent_bridge: AgentBridge instance + """ + try: + action = task.get("action", {}) + content = action.get("content", "") + receiver = action.get("receiver") + is_group = action.get("is_group", False) + channel_type = action.get("channel_type", "unknown") + + if not receiver: + logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") + return + + # Create context for sending message + context = Context(ContextType.TEXT, content) + context["receiver"] = receiver + 
context["isgroup"] = is_group + context["session_id"] = receiver + + # Channel-specific context setup + if channel_type == "web": + # Web channel needs request_id + import uuid + request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}" + context["request_id"] = request_id + logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}") + elif channel_type == "feishu": + # Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to) + # Use chat_id for groups, open_id for private chats + context["receive_id_type"] = "chat_id" if is_group else "open_id" + # Keep isgroup as is, but set msg to None (no original message to reply to) + # Feishu channel will detect this and send as new message instead of reply + context["msg"] = None + logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}") + elif channel_type == "dingtalk": + # DingTalk channel setup + context["msg"] = None + # 如果是单聊,需要传递 sender_staff_id + if not is_group: + sender_staff_id = action.get("dingtalk_sender_staff_id") + if sender_staff_id: + context["dingtalk_sender_staff_id"] = sender_staff_id + logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}") + else: + logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id") + + # Create reply + reply = Reply(ReplyType.TEXT, content) + + # Get channel and send + from channel.channel_factory import create_channel + + try: + channel = create_channel(channel_type) + if channel: + # For web channel, register the request_id to session mapping + if channel_type == "web" and hasattr(channel, 'request_to_session'): + channel.request_to_session[request_id] = receiver + logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}") + + channel.send(reply, context) + logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}") + else: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + except Exception as e: + logger.error(f"[Scheduler] Failed to send message: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + except Exception as e: + logger.error(f"[Scheduler] Error in _execute_send_message: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + +def _execute_tool_call(task: dict, agent_bridge): + """ + Execute a tool_call action + + Args: + task: Task dictionary + agent_bridge: AgentBridge instance + """ + try: + action = task.get("action", {}) + # Support both old and new field names + tool_name = action.get("call_name") or action.get("tool_name") + tool_params = action.get("call_params") or action.get("tool_params", {}) + result_prefix = action.get("result_prefix", "") + receiver = action.get("receiver") + is_group = action.get("is_group", False) + channel_type = action.get("channel_type", "unknown") + + if not tool_name: + logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified") + return + + if not receiver: + logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") + return + + # Get tool manager and create tool instance + from agent.tools.tool_manager import ToolManager + tool_manager = ToolManager() + tool = tool_manager.create_tool(tool_name) + + if not tool: + logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found") + return + + # Execute tool + logger.info(f"[Scheduler] Task {task['id']}: Executing tool 
'{tool_name}' with params {tool_params}")
+        result = tool.execute(tool_params)
+
+        # Get result content
+        if hasattr(result, 'result'):
+            content = result.result
+        else:
+            content = str(result)
+
+        # Add prefix if specified
+        if result_prefix:
+            content = f"{result_prefix}\n\n{content}"
+
+        # Send result as message
+        context = Context(ContextType.TEXT, content)
+        context["receiver"] = receiver
+        context["isgroup"] = is_group
+        context["session_id"] = receiver
+
+        # Channel-specific context setup
+        if channel_type == "web":
+            # Web channel needs request_id
+            import uuid
+            request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+            context["request_id"] = request_id
+            logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
+        elif channel_type == "feishu":
+            # Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
+            context["receive_id_type"] = "chat_id" if is_group else "open_id"
+            context["msg"] = None
+            logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
+
+        reply = Reply(ReplyType.TEXT, content)
+
+        # Get channel and send
+        from channel.channel_factory import create_channel
+
+        try:
+            channel = create_channel(channel_type)
+            if channel:
+                # For web channel, register the request_id to session mapping
+                if channel_type == "web" and hasattr(channel, 'request_to_session'):
+                    channel.request_to_session[request_id] = receiver
+                    logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
+
+                channel.send(reply, context)
+                logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
+            else:
+                logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
+        except Exception as e:
+            logger.error(f"[Scheduler] Failed to send tool result: {e}")
+
+    except Exception as e:
+        logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
+
+
+def _execute_skill_call(task: dict, agent_bridge):
+    """
+    Execute a skill_call action by asking Agent to run the skill
+
+    Args:
+        task: Task dictionary
+        agent_bridge: AgentBridge instance
+    """
+    try:
+        action = task.get("action", {})
+        # Support both old and new field names
+        skill_name = action.get("call_name") or action.get("skill_name")
+        skill_params = action.get("call_params") or action.get("skill_params", {})
+        result_prefix = action.get("result_prefix", "")
+        receiver = action.get("receiver")
+        is_group = action.get("is_group", False)
+        channel_type = action.get("channel_type", "unknown")
+
+        if not skill_name:
+            logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
+            return
+
+        if not receiver:
+            logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
+            return
+
+        logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
+
+        # Build a natural language query for the Agent to execute the skill
+        # Format: "Use skill-name to do something with params"
+        param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
+        query = f"Use {skill_name} skill"
+        if param_str:
+            query += f" with {param_str}"
+
+        # Create context for Agent
+        context = Context(ContextType.TEXT, query)
+        context["receiver"] = receiver
+        context["isgroup"] = is_group
+        context["session_id"] = receiver
+
+        # Channel-specific setup
+        if channel_type == "web":
+            import uuid
+            request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+            context["request_id"] = request_id
+        elif channel_type == "feishu":
context["receive_id_type"] = "chat_id" if is_group else "open_id" + context["msg"] = None + + # Use Agent to execute the skill + try: + reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True) + + if reply and reply.content: + content = reply.content + + # Add prefix if specified + if result_prefix: + content = f"{result_prefix}\n\n{content}" + + logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}") + else: + logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution") + + except Exception as e: + logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + except Exception as e: + logger.error(f"[Scheduler] Error in _execute_skill_call: {e}") + import traceback + logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + + +def attach_scheduler_to_tool(tool, context: Context = None): + """ + Attach scheduler components to a SchedulerTool instance + + Args: + tool: SchedulerTool instance + context: Current context (optional) + """ + if _task_store: + tool.task_store = _task_store + + if context: + tool.current_context = context + + # Also set channel_type from config + channel_type = conf().get("channel_type", "unknown") + if not tool.config: + tool.config = {} + tool.config["channel_type"] = channel_type diff --git a/agent/tools/scheduler/scheduler_service.py b/agent/tools/scheduler/scheduler_service.py new file mode 100644 index 0000000..286fbc6 --- /dev/null +++ b/agent/tools/scheduler/scheduler_service.py @@ -0,0 +1,220 @@ +""" +Background scheduler service for executing scheduled tasks +""" + +import time +import threading +from datetime import datetime, timedelta +from typing import Callable, Optional +from croniter import croniter +from common.log import logger + + +class SchedulerService: + """ + Background service that executes scheduled tasks + """ + + def __init__(self, task_store, execute_callback: Callable): + """ + Initialize scheduler service + + Args: + task_store: TaskStore instance + execute_callback: Function to call when executing a task + """ + self.task_store = task_store + self.execute_callback = execute_callback + self.running = False + self.thread = None + self._lock = threading.Lock() + + def start(self): + """Start the scheduler service""" + with self._lock: + if self.running: + logger.warning("[Scheduler] Service already running") + return + + self.running = True + self.thread = threading.Thread(target=self._run_loop, daemon=True) + self.thread.start() + logger.debug("[Scheduler] Service started") + + def stop(self): + """Stop the scheduler service""" + with self._lock: + if not self.running: + return + + self.running = False + if self.thread: + self.thread.join(timeout=5) + logger.info("[Scheduler] Service stopped") + + def _run_loop(self): + """Main scheduler loop""" + logger.debug("[Scheduler] Scheduler loop started") + + while self.running: + try: + self._check_and_execute_tasks() + except Exception as e: + logger.error(f"[Scheduler] Error in scheduler loop: {e}") + + # Sleep for 30 seconds between checks + time.sleep(30) + + def _check_and_execute_tasks(self): + """Check for due tasks and execute them""" + now = datetime.now() + tasks = self.task_store.list_tasks(enabled_only=True) + + for task in tasks: + try: + # Check if task is due + if self._is_task_due(task, now): + logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}") + 
self._execute_task(task) + + # Update next run time + next_run = self._calculate_next_run(task, now) + if next_run: + self.task_store.update_task(task['id'], { + "next_run_at": next_run.isoformat(), + "last_run_at": now.isoformat() + }) + else: + # One-time task, disable it + self.task_store.update_task(task['id'], { + "enabled": False, + "last_run_at": now.isoformat() + }) + logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}") + except Exception as e: + logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}") + + def _is_task_due(self, task: dict, now: datetime) -> bool: + """ + Check if a task is due to run + + Args: + task: Task dictionary + now: Current datetime + + Returns: + True if task should run now + """ + next_run_str = task.get("next_run_at") + if not next_run_str: + # Calculate initial next_run_at + next_run = self._calculate_next_run(task, now) + if next_run: + self.task_store.update_task(task['id'], { + "next_run_at": next_run.isoformat() + }) + return False + return False + + try: + next_run = datetime.fromisoformat(next_run_str) + + # Check if task is overdue (e.g., service restart) + if next_run < now: + time_diff = (now - next_run).total_seconds() + + # If overdue by more than 5 minutes, skip this run and schedule next + if time_diff > 300: # 5 minutes + logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run") + + # For one-time tasks, disable them + schedule = task.get("schedule", {}) + if schedule.get("type") == "once": + self.task_store.update_task(task['id'], { + "enabled": False, + "last_run_at": now.isoformat() + }) + logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled") + return False + + # For recurring tasks, calculate next run from now + next_next_run = self._calculate_next_run(task, now) + if next_next_run: + self.task_store.update_task(task['id'], { + "next_run_at": next_next_run.isoformat() + }) + logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}") + return False + + return now >= next_run + except: + return False + + def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]: + """ + Calculate next run time for a task + + Args: + task: Task dictionary + from_time: Calculate from this time + + Returns: + Next run datetime or None for one-time tasks + """ + schedule = task.get("schedule", {}) + schedule_type = schedule.get("type") + + if schedule_type == "cron": + # Cron expression + expression = schedule.get("expression") + if not expression: + return None + + try: + cron = croniter(expression, from_time) + return cron.get_next(datetime) + except Exception as e: + logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}") + return None + + elif schedule_type == "interval": + # Interval in seconds + seconds = schedule.get("seconds", 0) + if seconds <= 0: + return None + return from_time + timedelta(seconds=seconds) + + elif schedule_type == "once": + # One-time task at specific time + run_at_str = schedule.get("run_at") + if not run_at_str: + return None + + try: + run_at = datetime.fromisoformat(run_at_str) + # Only return if in the future + if run_at > from_time: + return run_at + except: + pass + return None + + return None + + def _execute_task(self, task: dict): + """ + Execute a task + + Args: + task: Task dictionary + """ + try: + # Call the execute callback + self.execute_callback(task) + except Exception as e: + logger.error(f"[Scheduler] Error executing task 
{task['id']}: {e}") + # Update task with error + self.task_store.update_task(task['id'], { + "last_error": str(e), + "last_error_at": datetime.now().isoformat() + }) diff --git a/agent/tools/scheduler/scheduler_tool.py b/agent/tools/scheduler/scheduler_tool.py new file mode 100644 index 0000000..9d961c3 --- /dev/null +++ b/agent/tools/scheduler/scheduler_tool.py @@ -0,0 +1,442 @@ +""" +Scheduler tool for creating and managing scheduled tasks +""" + +import uuid +from datetime import datetime +from typing import Any, Dict, Optional +from croniter import croniter + +from agent.tools.base_tool import BaseTool, ToolResult +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger + + +class SchedulerTool(BaseTool): + """ + Tool for managing scheduled tasks (reminders, notifications, etc.) + """ + + name: str = "scheduler" + description: str = ( + "创建、查询和管理定时任务。支持固定消息和AI任务两种类型。\n\n" + "使用方法:\n" + "- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n" + "- 查询:action='list' / action='get', task_id='任务ID'\n" + "- 管理:action='delete/enable/disable', task_id='任务ID'\n\n" + "调度类型:\n" + "- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n" + "- interval: 固定间隔(秒),如3600表示每小时\n" + "- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n" + "注意:'X秒后'用once+相对时间,'每X秒'用interval" + ) + params: dict = { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["create", "list", "get", "delete", "enable", "disable"], + "description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)" + }, + "task_id": { + "type": "string", + "description": "任务ID (用于 get/delete/enable/disable 操作)" + }, + "name": { + "type": "string", + "description": "任务名称 (用于 create 操作)" + }, + "message": { + "type": "string", + "description": "固定消息内容 (与ai_task二选一)" + }, + "ai_task": { + "type": "string", + "description": "AI任务描述 (与message二选一),如'搜索今日新闻'、'查询天气'" + }, + "schedule_type": { + "type": "string", + "enum": ["cron", "interval", "once"], + "description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)" + }, + "schedule_value": { + "type": "string", + "description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)" + } + }, + "required": ["action"] + } + + def __init__(self, config: dict = None): + super().__init__() + self.config = config or {} + + # Will be set by agent bridge + self.task_store = None + self.current_context = None + + def execute(self, params: dict) -> ToolResult: + """ + Execute scheduler operations + + Args: + params: Dictionary containing: + - action: Operation type (create/list/get/delete/enable/disable) + - Other parameters depending on action + + Returns: + ToolResult object + """ + # Extract parameters + action = params.get("action") + kwargs = params + + if not self.task_store: + return ToolResult.fail("错误: 定时任务系统未初始化") + + try: + if action == "create": + result = self._create_task(**kwargs) + return ToolResult.success(result) + elif action == "list": + result = self._list_tasks(**kwargs) + return ToolResult.success(result) + elif action == "get": + result = self._get_task(**kwargs) + return ToolResult.success(result) + elif action == "delete": + result = self._delete_task(**kwargs) + return ToolResult.success(result) + elif action == "enable": + result = self._enable_task(**kwargs) + return ToolResult.success(result) + elif action == "disable": + result = self._disable_task(**kwargs) + return ToolResult.success(result) + else: + return 
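`_is_task_due` earlier in scheduler_service.py skips runs that are more than five minutes overdue (for example after a restart) rather than firing them late. A self-contained sketch of that decision, assuming ISO-formatted `next_run_at` strings; `due_decision` is an illustrative helper, not code from the diff:

```python
from datetime import datetime, timedelta

OVERDUE_GRACE_SECONDS = 300  # matches the 5-minute window in _is_task_due


def due_decision(next_run_at: str, now: datetime) -> str:
    """Return 'run', 'skip' (too late, reschedule), or 'wait'."""
    next_run = datetime.fromisoformat(next_run_at)
    if next_run > now:
        return "wait"
    if (now - next_run).total_seconds() > OVERDUE_GRACE_SECONDS:
        # One-time tasks get disabled, recurring tasks get rescheduled from now
        return "skip"
    return "run"


if __name__ == "__main__":
    now = datetime(2024, 1, 1, 9, 0, 0)
    print(due_decision((now - timedelta(seconds=60)).isoformat(), now))   # run
    print(due_decision((now - timedelta(minutes=10)).isoformat(), now))   # skip
    print(due_decision((now + timedelta(minutes=5)).isoformat(), now))    # wait
```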
ToolResult.fail(f"未知操作: {action}") + except Exception as e: + logger.error(f"[SchedulerTool] Error: {e}") + return ToolResult.fail(f"操作失败: {str(e)}") + + def _create_task(self, **kwargs) -> str: + """Create a new scheduled task""" + name = kwargs.get("name") + message = kwargs.get("message") + ai_task = kwargs.get("ai_task") + schedule_type = kwargs.get("schedule_type") + schedule_value = kwargs.get("schedule_value") + + # Validate required fields + if not name: + return "错误: 缺少任务名称 (name)" + + # Check that exactly one of message/ai_task is provided + if not message and not ai_task: + return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一" + if message and ai_task: + return "错误: message 和 ai_task 只能提供其中一个" + + if not schedule_type: + return "错误: 缺少调度类型 (schedule_type)" + if not schedule_value: + return "错误: 缺少调度值 (schedule_value)" + + # Validate schedule + schedule = self._parse_schedule(schedule_type, schedule_value) + if not schedule: + return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}" + + # Get context info for receiver + if not self.current_context: + return "错误: 无法获取当前对话上下文" + + context = self.current_context + + # Create task + task_id = str(uuid.uuid4())[:8] + + # Build action based on message or ai_task + if message: + action = { + "type": "send_message", + "content": message, + "receiver": context.get("receiver"), + "receiver_name": self._get_receiver_name(context), + "is_group": context.get("isgroup", False), + "channel_type": self.config.get("channel_type", "unknown") + } + else: # ai_task + action = { + "type": "agent_task", + "task_description": ai_task, + "receiver": context.get("receiver"), + "receiver_name": self._get_receiver_name(context), + "is_group": context.get("isgroup", False), + "channel_type": self.config.get("channel_type", "unknown") + } + + # 针对钉钉单聊,额外存储 sender_staff_id + msg = context.kwargs.get("msg") + if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False): + action["dingtalk_sender_staff_id"] = msg.sender_staff_id + + task_data = { + "id": task_id, + "name": name, + "enabled": True, + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + "schedule": schedule, + "action": action + } + + # Calculate initial next_run_at + next_run = self._calculate_next_run(task_data) + if next_run: + task_data["next_run_at"] = next_run.isoformat() + + # Save task + self.task_store.add_task(task_data) + + # Format response + schedule_desc = self._format_schedule_description(schedule) + receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"] + + if message: + content_desc = f"💬 固定消息: {message}" + else: + content_desc = f"🤖 AI任务: {ai_task}" + + return ( + f"✅ 定时任务创建成功\n\n" + f"📋 任务ID: {task_id}\n" + f"📝 名称: {name}\n" + f"⏰ 调度: {schedule_desc}\n" + f"👤 接收者: {receiver_desc}\n" + f"{content_desc}\n" + f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}" + ) + + def _list_tasks(self, **kwargs) -> str: + """List all tasks""" + tasks = self.task_store.list_tasks() + + if not tasks: + return "📋 暂无定时任务" + + lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"] + + for task in tasks: + status = "✅" if task.get("enabled", True) else "❌" + schedule_desc = self._format_schedule_description(task.get("schedule", {})) + next_run = task.get("next_run_at") + next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知" + + lines.append( + f"{status} [{task['id']}] {task['name']}\n" + f" ⏰ {schedule_desc} | 下次: {next_run_str}" + ) + + return 
"\n".join(lines) + + def _get_task(self, **kwargs) -> str: + """Get task details""" + task_id = kwargs.get("task_id") + if not task_id: + return "错误: 缺少任务ID (task_id)" + + task = self.task_store.get_task(task_id) + if not task: + return f"错误: 任务 '{task_id}' 不存在" + + status = "启用" if task.get("enabled", True) else "禁用" + schedule_desc = self._format_schedule_description(task.get("schedule", {})) + action = task.get("action", {}) + next_run = task.get("next_run_at") + next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知" + last_run = task.get("last_run_at") + last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行" + + return ( + f"📋 任务详情\n\n" + f"ID: {task['id']}\n" + f"名称: {task['name']}\n" + f"状态: {status}\n" + f"调度: {schedule_desc}\n" + f"接收者: {action.get('receiver_name', action.get('receiver'))}\n" + f"消息: {action.get('content')}\n" + f"下次执行: {next_run_str}\n" + f"上次执行: {last_run_str}\n" + f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}" + ) + + def _delete_task(self, **kwargs) -> str: + """Delete a task""" + task_id = kwargs.get("task_id") + if not task_id: + return "错误: 缺少任务ID (task_id)" + + task = self.task_store.get_task(task_id) + if not task: + return f"错误: 任务 '{task_id}' 不存在" + + self.task_store.delete_task(task_id) + return f"✅ 任务 '{task['name']}' ({task_id}) 已删除" + + def _enable_task(self, **kwargs) -> str: + """Enable a task""" + task_id = kwargs.get("task_id") + if not task_id: + return "错误: 缺少任务ID (task_id)" + + task = self.task_store.get_task(task_id) + if not task: + return f"错误: 任务 '{task_id}' 不存在" + + self.task_store.enable_task(task_id, True) + return f"✅ 任务 '{task['name']}' ({task_id}) 已启用" + + def _disable_task(self, **kwargs) -> str: + """Disable a task""" + task_id = kwargs.get("task_id") + if not task_id: + return "错误: 缺少任务ID (task_id)" + + task = self.task_store.get_task(task_id) + if not task: + return f"错误: 任务 '{task_id}' 不存在" + + self.task_store.enable_task(task_id, False) + return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用" + + def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]: + """Parse and validate schedule configuration""" + try: + if schedule_type == "cron": + # Validate cron expression + croniter(schedule_value) + return {"type": "cron", "expression": schedule_value} + + elif schedule_type == "interval": + # Parse interval in seconds + seconds = int(schedule_value) + if seconds <= 0: + return None + return {"type": "interval", "seconds": seconds} + + elif schedule_type == "once": + # Parse datetime - support both relative and absolute time + + # Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d") + if schedule_value.startswith("+"): + import re + match = re.match(r'\+(\d+)([smhd])', schedule_value) + if match: + amount = int(match.group(1)) + unit = match.group(2) + + from datetime import timedelta + now = datetime.now() + + if unit == 's': # seconds + target_time = now + timedelta(seconds=amount) + elif unit == 'm': # minutes + target_time = now + timedelta(minutes=amount) + elif unit == 'h': # hours + target_time = now + timedelta(hours=amount) + elif unit == 'd': # days + target_time = now + timedelta(days=amount) + else: + return None + + return {"type": "once", "run_at": target_time.isoformat()} + else: + logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}") + return None + else: + # Absolute time in ISO format + datetime.fromisoformat(schedule_value) + 
return {"type": "once", "run_at": schedule_value} + + except Exception as e: + logger.error(f"[SchedulerTool] Invalid schedule: {e}") + return None + + return None + + def _calculate_next_run(self, task: dict) -> Optional[datetime]: + """Calculate next run time for a task""" + schedule = task.get("schedule", {}) + schedule_type = schedule.get("type") + now = datetime.now() + + if schedule_type == "cron": + expression = schedule.get("expression") + cron = croniter(expression, now) + return cron.get_next(datetime) + + elif schedule_type == "interval": + seconds = schedule.get("seconds", 0) + from datetime import timedelta + return now + timedelta(seconds=seconds) + + elif schedule_type == "once": + run_at_str = schedule.get("run_at") + return datetime.fromisoformat(run_at_str) + + return None + + def _format_schedule_description(self, schedule: dict) -> str: + """Format schedule as human-readable description""" + schedule_type = schedule.get("type") + + if schedule_type == "cron": + expr = schedule.get("expression", "") + # Try to provide friendly description + if expr == "0 9 * * *": + return "每天 9:00" + elif expr == "0 */1 * * *": + return "每小时" + elif expr == "*/30 * * * *": + return "每30分钟" + else: + return f"Cron: {expr}" + + elif schedule_type == "interval": + seconds = schedule.get("seconds", 0) + if seconds >= 86400: + days = seconds // 86400 + return f"每 {days} 天" + elif seconds >= 3600: + hours = seconds // 3600 + return f"每 {hours} 小时" + elif seconds >= 60: + minutes = seconds // 60 + return f"每 {minutes} 分钟" + else: + return f"每 {seconds} 秒" + + elif schedule_type == "once": + run_at = schedule.get("run_at", "") + try: + dt = datetime.fromisoformat(run_at) + return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})" + except: + return "一次性" + + return "未知" + + def _get_receiver_name(self, context: Context) -> str: + """Get receiver name from context""" + try: + msg = context.get("msg") + if msg: + if context.get("isgroup"): + return msg.other_user_nickname or "群聊" + else: + return msg.from_user_nickname or "用户" + except: + pass + return "未知" diff --git a/agent/tools/scheduler/task_store.py b/agent/tools/scheduler/task_store.py new file mode 100644 index 0000000..55e84a1 --- /dev/null +++ b/agent/tools/scheduler/task_store.py @@ -0,0 +1,200 @@ +""" +Task storage management for scheduler +""" + +import json +import os +import threading +from datetime import datetime +from typing import Dict, List, Optional +from pathlib import Path + + +class TaskStore: + """ + Manages persistent storage of scheduled tasks + """ + + def __init__(self, store_path: str = None): + """ + Initialize task store + + Args: + store_path: Path to tasks.json file. 
Defaults to ~/cow/scheduler/tasks.json + """ + if store_path is None: + # Default to ~/cow/scheduler/tasks.json + home = os.path.expanduser("~") + store_path = os.path.join(home, "cow", "scheduler", "tasks.json") + + self.store_path = store_path + self.lock = threading.Lock() + self._ensure_store_dir() + + def _ensure_store_dir(self): + """Ensure the storage directory exists""" + store_dir = os.path.dirname(self.store_path) + os.makedirs(store_dir, exist_ok=True) + + def load_tasks(self) -> Dict[str, dict]: + """ + Load all tasks from storage + + Returns: + Dictionary of task_id -> task_data + """ + with self.lock: + if not os.path.exists(self.store_path): + return {} + + try: + with open(self.store_path, 'r', encoding='utf-8') as f: + data = json.load(f) + return data.get("tasks", {}) + except Exception as e: + print(f"Error loading tasks: {e}") + return {} + + def save_tasks(self, tasks: Dict[str, dict]): + """ + Save all tasks to storage + + Args: + tasks: Dictionary of task_id -> task_data + """ + with self.lock: + try: + # Create backup + if os.path.exists(self.store_path): + backup_path = f"{self.store_path}.bak" + try: + with open(self.store_path, 'r') as src: + with open(backup_path, 'w') as dst: + dst.write(src.read()) + except: + pass + + # Save tasks + data = { + "version": 1, + "updated_at": datetime.now().isoformat(), + "tasks": tasks + } + + with open(self.store_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"Error saving tasks: {e}") + raise + + def add_task(self, task: dict) -> bool: + """ + Add a new task + + Args: + task: Task data dictionary + + Returns: + True if successful + """ + tasks = self.load_tasks() + task_id = task.get("id") + + if not task_id: + raise ValueError("Task must have an 'id' field") + + if task_id in tasks: + raise ValueError(f"Task with id '{task_id}' already exists") + + tasks[task_id] = task + self.save_tasks(tasks) + return True + + def update_task(self, task_id: str, updates: dict) -> bool: + """ + Update an existing task + + Args: + task_id: Task ID + updates: Dictionary of fields to update + + Returns: + True if successful + """ + tasks = self.load_tasks() + + if task_id not in tasks: + raise ValueError(f"Task '{task_id}' not found") + + # Update fields + tasks[task_id].update(updates) + tasks[task_id]["updated_at"] = datetime.now().isoformat() + + self.save_tasks(tasks) + return True + + def delete_task(self, task_id: str) -> bool: + """ + Delete a task + + Args: + task_id: Task ID + + Returns: + True if successful + """ + tasks = self.load_tasks() + + if task_id not in tasks: + raise ValueError(f"Task '{task_id}' not found") + + del tasks[task_id] + self.save_tasks(tasks) + return True + + def get_task(self, task_id: str) -> Optional[dict]: + """ + Get a specific task + + Args: + task_id: Task ID + + Returns: + Task data or None if not found + """ + tasks = self.load_tasks() + return tasks.get(task_id) + + def list_tasks(self, enabled_only: bool = False) -> List[dict]: + """ + List all tasks + + Args: + enabled_only: If True, only return enabled tasks + + Returns: + List of task dictionaries + """ + tasks = self.load_tasks() + task_list = list(tasks.values()) + + if enabled_only: + task_list = [t for t in task_list if t.get("enabled", True)] + + # Sort by next_run_at + task_list.sort(key=lambda t: t.get("next_run_at", float('inf'))) + + return task_list + + def enable_task(self, task_id: str, enabled: bool = True) -> bool: + """ + Enable or disable a task + + Args: + 
task_id: Task ID + enabled: True to enable, False to disable + + Returns: + True if successful + """ + return self.update_task(task_id, {"enabled": enabled}) diff --git a/agent/tools/send/__init__.py b/agent/tools/send/__init__.py new file mode 100644 index 0000000..b76702a --- /dev/null +++ b/agent/tools/send/__init__.py @@ -0,0 +1,3 @@ +from .send import Send + +__all__ = ['Send'] diff --git a/agent/tools/send/send.py b/agent/tools/send/send.py new file mode 100644 index 0000000..a778b74 --- /dev/null +++ b/agent/tools/send/send.py @@ -0,0 +1,159 @@ +""" +Send tool - Send files to the user +""" + +import os +from typing import Dict, Any +from pathlib import Path + +from agent.tools.base_tool import BaseTool, ToolResult + + +class Send(BaseTool): + """Tool for sending files to the user""" + + name: str = "send" + description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file." + + params: dict = { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to send. Can be absolute path or relative to workspace." + }, + "message": { + "type": "string", + "description": "Optional message to accompany the file" + } + }, + "required": ["path"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + self.cwd = self.config.get("cwd", os.getcwd()) + + # Supported file types + self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'} + self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'} + self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'} + self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'} + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute file send operation + + :param args: Contains file path and optional message + :return: File metadata for channel to send + """ + path = args.get("path", "").strip() + message = args.get("message", "") + + if not path: + return ToolResult.fail("Error: path parameter is required") + + # Resolve path + absolute_path = self._resolve_path(path) + + # Check if file exists + if not os.path.exists(absolute_path): + return ToolResult.fail(f"Error: File not found: {path}") + + # Check if readable + if not os.access(absolute_path, os.R_OK): + return ToolResult.fail(f"Error: File is not readable: {path}") + + # Get file info + file_ext = Path(absolute_path).suffix.lower() + file_size = os.path.getsize(absolute_path) + file_name = Path(absolute_path).name + + # Determine file type + if file_ext in self.image_extensions: + file_type = "image" + mime_type = self._get_image_mime_type(file_ext) + elif file_ext in self.video_extensions: + file_type = "video" + mime_type = self._get_video_mime_type(file_ext) + elif file_ext in self.audio_extensions: + file_type = "audio" + mime_type = self._get_audio_mime_type(file_ext) + elif file_ext in self.document_extensions: + file_type = "document" + mime_type = self._get_document_mime_type(file_ext) + else: + file_type = "file" + mime_type = "application/octet-stream" + + # Return file_to_send metadata + result = { + "type": "file_to_send", + "file_type": file_type, + "path": absolute_path, + "file_name": file_name, + "mime_type": mime_type, + "size": file_size, + "size_formatted": self._format_size(file_size), + "message": message or f"正在发送 {file_name}" + } + + return ToolResult.success(result) + + def _resolve_path(self, 
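For reference, a rough sketch of the tasks.json layout that `TaskStore.save_tasks` writes and `load_tasks` reads back. The field set follows what `_create_task` assembles; all concrete values below are made up:

```python
import json
from datetime import datetime

# Approximate shape of ~/cow/scheduler/tasks.json as produced by TaskStore.save_tasks
store = {
    "version": 1,
    "updated_at": datetime(2024, 1, 1, 8, 0).isoformat(),
    "tasks": {
        "ab12cd34": {                       # task_id: first 8 chars of a uuid4
            "id": "ab12cd34",
            "name": "morning reminder",     # illustrative
            "enabled": True,
            "schedule": {"type": "cron", "expression": "0 8 * * *"},
            "action": {
                "type": "send_message",
                "content": "Stand-up in 10 minutes",
                "receiver": "user_123",     # illustrative receiver id
                "is_group": False,
                "channel_type": "feishu",
            },
            "next_run_at": "2024-01-02T08:00:00",
        }
    },
}

print(json.dumps(store, ensure_ascii=False, indent=2))
```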
path: str) -> str: + """Resolve path to absolute path""" + path = os.path.expanduser(path) + if os.path.isabs(path): + return path + return os.path.abspath(os.path.join(self.cwd, path)) + + def _get_image_mime_type(self, ext: str) -> str: + """Get MIME type for image""" + mime_map = { + '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', + '.png': 'image/png', '.gif': 'image/gif', + '.webp': 'image/webp', '.bmp': 'image/bmp', + '.svg': 'image/svg+xml', '.ico': 'image/x-icon' + } + return mime_map.get(ext, 'image/jpeg') + + def _get_video_mime_type(self, ext: str) -> str: + """Get MIME type for video""" + mime_map = { + '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', + '.mov': 'video/quicktime', '.mkv': 'video/x-matroska', + '.webm': 'video/webm', '.flv': 'video/x-flv' + } + return mime_map.get(ext, 'video/mp4') + + def _get_audio_mime_type(self, ext: str) -> str: + """Get MIME type for audio""" + mime_map = { + '.mp3': 'audio/mpeg', '.wav': 'audio/wav', + '.ogg': 'audio/ogg', '.m4a': 'audio/mp4', + '.flac': 'audio/flac', '.aac': 'audio/aac' + } + return mime_map.get(ext, 'audio/mpeg') + + def _get_document_mime_type(self, ext: str) -> str: + """Get MIME type for document""" + mime_map = { + '.pdf': 'application/pdf', + '.doc': 'application/msword', + '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + '.xls': 'application/vnd.ms-excel', + '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + '.ppt': 'application/vnd.ms-powerpoint', + '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + '.txt': 'text/plain', + '.md': 'text/markdown' + } + return mime_map.get(ext, 'application/octet-stream') + + def _format_size(self, size_bytes: int) -> str: + """Format file size in human-readable format""" + for unit in ['B', 'KB', 'MB', 'GB']: + if size_bytes < 1024.0: + return f"{size_bytes:.1f}{unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.1f}TB" diff --git a/agent/tools/web_fetch/README.md b/agent/tools/web_fetch/README.md deleted file mode 100644 index 6fc192f..0000000 --- a/agent/tools/web_fetch/README.md +++ /dev/null @@ -1,212 +0,0 @@ -# WebFetch Tool - -免费的网页抓取工具,无需 API Key,可直接抓取网页内容并提取可读文本。 - -## 功能特性 - -- ✅ **完全免费** - 无需任何 API Key -- 🌐 **智能提取** - 自动提取网页主要内容 -- 📝 **格式转换** - 支持 HTML → Markdown/Text -- 🚀 **高性能** - 内置请求重试和超时控制 -- 🎯 **智能降级** - 优先使用 Readability,可降级到基础提取 - -## 安装依赖 - -### 基础功能(必需) -```bash -pip install requests -``` - -### 增强功能(推荐) -```bash -# 安装 readability-lxml 以获得更好的内容提取效果 -pip install readability-lxml - -# 安装 html2text 以获得更好的 Markdown 转换 -pip install html2text -``` - -## 使用方法 - -### 1. 在代码中使用 - -```python -from agent.tools.web_fetch import WebFetch - -# 创建工具实例 -tool = WebFetch() - -# 抓取网页(默认返回 Markdown 格式) -result = tool.execute({ - "url": "https://example.com" -}) - -# 抓取并转换为纯文本 -result = tool.execute({ - "url": "https://example.com", - "extract_mode": "text", - "max_chars": 5000 -}) - -if result.status == "success": - data = result.result - print(f"标题: {data['title']}") - print(f"内容: {data['text']}") -``` - -### 2. 在 Agent 中使用 - -工具会自动加载到 Agent 的工具列表中: - -```python -from agent.tools import WebFetch - -tools = [ - WebFetch(), - # ... 其他工具 -] - -agent = create_agent(tools=tools) -``` - -### 3. 
通过 Skills 使用 - -创建一个 skill 文件 `skills/web-fetch/SKILL.md`: - -```markdown ---- -name: web-fetch -emoji: 🌐 -always: true ---- - -# 网页内容获取 - -使用 web_fetch 工具获取网页内容。 - -## 使用场景 - -- 需要读取某个网页的内容 -- 需要提取文章正文 -- 需要获取网页信息 - -## 示例 - - -用户: 帮我看看 https://example.com 这个网页讲了什么 -助手: - https://example.com - markdown - - -``` - -## 参数说明 - -| 参数 | 类型 | 必需 | 默认值 | 说明 | -|------|------|------|--------|------| -| `url` | string | ✅ | - | 要抓取的 URL(http/https) | -| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown` 或 `text` | -| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100) | - -## 返回结果 - -```python -{ - "url": "https://example.com", # 最终 URL(处理重定向后) - "status": 200, # HTTP 状态码 - "content_type": "text/html", # 内容类型 - "title": "Example Domain", # 页面标题 - "extractor": "readability", # 提取器:readability/basic/raw - "extract_mode": "markdown", # 提取模式 - "text": "# Example Domain\n\n...", # 提取的文本内容 - "length": 1234, # 文本长度 - "truncated": false, # 是否被截断 - "warning": "..." # 警告信息(如果有) -} -``` - -## 与其他搜索工具的对比 - -| 工具 | 需要 API Key | 功能 | 成本 | -|------|-------------|------|------| -| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 | -| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 | -| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 | -| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 | -| `google_search` | ✅ 需要 | Google 搜索 API | 付费 | - -## 技术细节 - -### 内容提取策略 - -1. **Readability 模式**(推荐) - - 使用 Mozilla 的 Readability 算法 - - 自动识别文章主体内容 - - 过滤广告、导航栏等噪音 - -2. **Basic 模式**(降级) - - 简单的 HTML 标签清理 - - 正则表达式提取文本 - - 适用于简单页面 - -3. **Raw 模式** - - 用于非 HTML 内容 - - 直接返回原始内容 - -### 错误处理 - -工具会自动处理以下情况: -- ✅ HTTP 重定向(最多 3 次) -- ✅ 请求超时(默认 30 秒) -- ✅ 网络错误自动重试 -- ✅ 内容提取失败降级 - -## 测试 - -运行测试脚本: - -```bash -cd agent/tools/web_fetch -python test_web_fetch.py -``` - -## 配置选项 - -在创建工具时可以传入配置: - -```python -tool = WebFetch(config={ - "timeout": 30, # 请求超时时间(秒) - "max_redirects": 3, # 最大重定向次数 - "user_agent": "..." # 自定义 User-Agent -}) -``` - -## 常见问题 - -### Q: 为什么推荐安装 readability-lxml? - -A: readability-lxml 提供更好的内容提取质量,能够: -- 自动识别文章主体 -- 过滤广告和导航栏 -- 保留文章结构 - -没有它也能工作,但提取质量会下降。 - -### Q: 与 clawdbot 的 web_fetch 有什么区别? - -A: 本实现参考了 clawdbot 的设计,主要区别: -- Python 实现(clawdbot 是 TypeScript) -- 简化了一些高级特性(如 Firecrawl 集成) -- 保留了核心的免费功能 -- 更容易集成到现有项目 - -### Q: 可以抓取需要登录的页面吗? - -A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。 - -## 参考 - -- [Mozilla Readability](https://github.com/mozilla/readability) -- [Clawdbot Web Tools](https://github.com/moltbot/moltbot) diff --git a/agent/tools/web_fetch/__init__.py b/agent/tools/web_fetch/__init__.py deleted file mode 100644 index 545f87c..0000000 --- a/agent/tools/web_fetch/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .web_fetch import WebFetch - -__all__ = ['WebFetch'] diff --git a/agent/tools/web_fetch/install_deps.sh b/agent/tools/web_fetch/install_deps.sh deleted file mode 100644 index 3c2a553..0000000 --- a/agent/tools/web_fetch/install_deps.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash - -# WebFetch 工具依赖安装脚本 - -echo "==================================" -echo "WebFetch 工具依赖安装" -echo "==================================" -echo "" - -# 检查 Python 版本 -python_version=$(python3 --version 2>&1 | awk '{print $2}') -echo "✓ Python 版本: $python_version" -echo "" - -# 安装基础依赖 -echo "📦 安装基础依赖..." -python3 -m pip install requests - -# 检查是否成功 -if [ $? -eq 0 ]; then - echo "✅ requests 安装成功" -else - echo "❌ requests 安装失败" - exit 1 -fi - -echo "" - -# 安装推荐依赖 -echo "📦 安装推荐依赖(提升内容提取质量)..." -python3 -m pip install readability-lxml html2text - -# 检查是否成功 -if [ $? 
-eq 0 ]; then - echo "✅ readability-lxml 和 html2text 安装成功" -else - echo "⚠️ 推荐依赖安装失败,但不影响基础功能" -fi - -echo "" -echo "==================================" -echo "安装完成!" -echo "==================================" -echo "" -echo "运行测试:" -echo " python3 agent/tools/web_fetch/test_web_fetch.py" -echo "" diff --git a/agent/tools/web_fetch/web_fetch.py b/agent/tools/web_fetch/web_fetch.py deleted file mode 100644 index b87b95e..0000000 --- a/agent/tools/web_fetch/web_fetch.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -Web Fetch tool - Fetch and extract readable content from URLs -Supports HTML to Markdown/Text conversion using Mozilla's Readability -""" - -import os -import re -from typing import Dict, Any, Optional -from urllib.parse import urlparse -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -from agent.tools.base_tool import BaseTool, ToolResult -from common.log import logger - - -class WebFetch(BaseTool): - """Tool for fetching and extracting readable content from web pages""" - - name: str = "web_fetch" - description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata." - - params: dict = { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "HTTP or HTTPS URL to fetch" - }, - "extract_mode": { - "type": "string", - "description": "Extraction mode: 'markdown' (default) or 'text'", - "enum": ["markdown", "text"], - "default": "markdown" - }, - "max_chars": { - "type": "integer", - "description": "Maximum characters to return (default: 50000)", - "minimum": 100, - "default": 50000 - } - }, - "required": ["url"] - } - - def __init__(self, config: dict = None): - self.config = config or {} - self.timeout = self.config.get("timeout", 20) - self.max_redirects = self.config.get("max_redirects", 3) - self.user_agent = self.config.get( - "user_agent", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36" - ) - - # Setup session with retry strategy - self.session = self._create_session() - - # Check if readability-lxml is available - self.readability_available = self._check_readability() - - def _create_session(self) -> requests.Session: - """Create a requests session with retry strategy""" - session = requests.Session() - - # Retry strategy - handles failed requests, not redirects - retry_strategy = Retry( - total=3, - backoff_factor=1, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["GET", "HEAD"] - ) - - # HTTPAdapter handles retries; requests handles redirects via allow_redirects - adapter = HTTPAdapter(max_retries=retry_strategy) - session.mount("http://", adapter) - session.mount("https://", adapter) - - # Set max redirects on session - session.max_redirects = self.max_redirects - - return session - - def _check_readability(self) -> bool: - """Check if readability-lxml is available""" - try: - from readability import Document - return True - except ImportError: - logger.warning( - "readability-lxml not installed. Install with: pip install readability-lxml\n" - "Falling back to basic HTML extraction." 
- ) - return False - - def execute(self, args: Dict[str, Any]) -> ToolResult: - """ - Execute web fetch operation - - :param args: Contains url, extract_mode, and max_chars parameters - :return: Extracted content or error message - """ - url = args.get("url", "").strip() - extract_mode = args.get("extract_mode", "markdown").lower() - max_chars = args.get("max_chars", 50000) - - if not url: - return ToolResult.fail("Error: url parameter is required") - - # Validate URL - if not self._is_valid_url(url): - return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}") - - # Validate extract_mode - if extract_mode not in ["markdown", "text"]: - extract_mode = "markdown" - - # Validate max_chars - if not isinstance(max_chars, int) or max_chars < 100: - max_chars = 50000 - - try: - # Fetch the URL - response = self._fetch_url(url) - - # Extract content - result = self._extract_content( - html=response.text, - url=response.url, - status_code=response.status_code, - content_type=response.headers.get("content-type", ""), - extract_mode=extract_mode, - max_chars=max_chars - ) - - return ToolResult.success(result) - - except requests.exceptions.Timeout: - return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds") - except requests.exceptions.TooManyRedirects: - return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})") - except requests.exceptions.RequestException as e: - return ToolResult.fail(f"Error fetching URL: {str(e)}") - except Exception as e: - logger.error(f"Web fetch error: {e}", exc_info=True) - return ToolResult.fail(f"Error: {str(e)}") - - def _is_valid_url(self, url: str) -> bool: - """Validate URL format""" - try: - result = urlparse(url) - return result.scheme in ["http", "https"] and bool(result.netloc) - except Exception: - return False - - def _fetch_url(self, url: str) -> requests.Response: - """ - Fetch URL with proper headers and error handling - - :param url: URL to fetch - :return: Response object - """ - headers = { - "User-Agent": self.user_agent, - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", - "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8", - "Accept-Encoding": "gzip, deflate", - "Connection": "keep-alive", - } - - # Note: requests library handles redirects automatically - # The max_redirects is set in the session's adapter (HTTPAdapter) - response = self.session.get( - url, - headers=headers, - timeout=self.timeout, - allow_redirects=True - ) - - response.raise_for_status() - return response - - def _extract_content( - self, - html: str, - url: str, - status_code: int, - content_type: str, - extract_mode: str, - max_chars: int - ) -> Dict[str, Any]: - """ - Extract readable content from HTML - - :param html: HTML content - :param url: Original URL - :param status_code: HTTP status code - :param content_type: Content type header - :param extract_mode: 'markdown' or 'text' - :param max_chars: Maximum characters to return - :return: Extracted content and metadata - """ - # Check content type - if "text/html" not in content_type.lower(): - # Non-HTML content - text = html[:max_chars] - truncated = len(html) > max_chars - - return { - "url": url, - "status": status_code, - "content_type": content_type, - "extractor": "raw", - "text": text, - "length": len(text), - "truncated": truncated, - "message": f"Non-HTML content (type: {content_type})" - } - - # Extract readable content from HTML - if self.readability_available: - return self._extract_with_readability( - html, url, 
status_code, content_type, extract_mode, max_chars - ) - else: - return self._extract_basic( - html, url, status_code, content_type, extract_mode, max_chars - ) - - def _extract_with_readability( - self, - html: str, - url: str, - status_code: int, - content_type: str, - extract_mode: str, - max_chars: int - ) -> Dict[str, Any]: - """Extract content using Mozilla's Readability""" - try: - from readability import Document - - # Parse with Readability - doc = Document(html) - title = doc.title() - content_html = doc.summary() - - # Convert to markdown or text - if extract_mode == "markdown": - text = self._html_to_markdown(content_html) - else: - text = self._html_to_text(content_html) - - # Truncate if needed - truncated = len(text) > max_chars - if truncated: - text = text[:max_chars] - - return { - "url": url, - "status": status_code, - "content_type": content_type, - "title": title, - "extractor": "readability", - "extract_mode": extract_mode, - "text": text, - "length": len(text), - "truncated": truncated - } - - except Exception as e: - logger.warning(f"Readability extraction failed: {e}") - # Fallback to basic extraction - return self._extract_basic( - html, url, status_code, content_type, extract_mode, max_chars - ) - - def _extract_basic( - self, - html: str, - url: str, - status_code: int, - content_type: str, - extract_mode: str, - max_chars: int - ) -> Dict[str, Any]: - """Basic HTML extraction without Readability""" - # Extract title - title_match = re.search(r']*>(.*?)', html, re.IGNORECASE | re.DOTALL) - title = title_match.group(1).strip() if title_match else "Untitled" - - # Remove script and style tags - text = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - text = re.sub(r']*>.*?', '', text, flags=re.DOTALL | re.IGNORECASE) - - # Remove HTML tags - text = re.sub(r'<[^>]+>', ' ', text) - - # Clean up whitespace - text = re.sub(r'\s+', ' ', text) - text = text.strip() - - # Truncate if needed - truncated = len(text) > max_chars - if truncated: - text = text[:max_chars] - - return { - "url": url, - "status": status_code, - "content_type": content_type, - "title": title, - "extractor": "basic", - "extract_mode": extract_mode, - "text": text, - "length": len(text), - "truncated": truncated, - "warning": "Using basic extraction. Install readability-lxml for better results." - } - - def _html_to_markdown(self, html: str) -> str: - """Convert HTML to Markdown (basic implementation)""" - try: - # Try to use html2text if available - import html2text - h = html2text.HTML2Text() - h.ignore_links = False - h.ignore_images = False - h.body_width = 0 # Don't wrap lines - return h.handle(html) - except ImportError: - # Fallback to basic conversion - return self._html_to_text(html) - - def _html_to_text(self, html: str) -> str: - """Convert HTML to plain text""" - # Remove script and style tags - text = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - text = re.sub(r']*>.*?', '', text, flags=re.DOTALL | re.IGNORECASE) - - # Convert common tags to text equivalents - text = re.sub(r'', '\n', text, flags=re.IGNORECASE) - text = re.sub(r']*>', '\n\n', text, flags=re.IGNORECASE) - text = re.sub(r'
', '', text, flags=re.IGNORECASE) - text = re.sub(r']*>', '\n\n', text, flags=re.IGNORECASE) - text = re.sub(r'', '\n', text, flags=re.IGNORECASE) - - # Remove all other HTML tags - text = re.sub(r'<[^>]+>', '', text) - - # Decode HTML entities - import html - text = html.unescape(text) - - # Clean up whitespace - text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text) - text = re.sub(r' +', ' ', text) - text = text.strip() - - return text - - def close(self): - """Close the session""" - if hasattr(self, 'session'): - self.session.close() diff --git a/agent/tools/write/write.py b/agent/tools/write/write.py index 9836564..49e01c8 100644 --- a/agent/tools/write/write.py +++ b/agent/tools/write/write.py @@ -14,7 +14,7 @@ class Write(BaseTool): """Tool for writing file content""" name: str = "write" - description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories." + description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks." params: dict = { "type": "object", diff --git a/app.py b/app.py index 022fdd7..ad0cfca 100644 --- a/app.py +++ b/app.py @@ -59,6 +59,23 @@ def run(): os.environ["WECHATY_LOG"] = "warn" start_channel(channel_name) + + # 打印系统运行成功信息 + logger.info("") + logger.info("=" * 50) + if conf().get("agent", False): + logger.info("✅ System started successfully!") + logger.info("🐮 Cow Agent is running") + logger.info(f" Channel: {channel_name}") + logger.info(f" Model: {conf().get('model', 'unknown')}") + logger.info(f" Workspace: {conf().get('agent_workspace', '~/cow')}") + else: + logger.info("✅ System started successfully!") + logger.info("🤖 ChatBot is running") + logger.info(f" Channel: {channel_name}") + logger.info(f" Model: {conf().get('model', 'unknown')}") + logger.info("=" * 50) + logger.info("") while True: time.sleep(1) diff --git a/bridge/agent_bridge.py b/bridge/agent_bridge.py index 34535bb..43aefc7 100644 --- a/bridge/agent_bridge.py +++ b/bridge/agent_bridge.py @@ -2,10 +2,11 @@ Agent Bridge - Integrates Agent system with existing COW bridge """ +import os from typing import Optional, List from agent.protocol import Agent, LLMModel, LLMRequest -from bot.openai_compatible_bot import OpenAICompatibleBot +from models.openai_compatible_bot import OpenAICompatibleBot from bridge.bridge import Bridge from bridge.context import Context from bridge.reply import Reply, ReplyType @@ -180,6 +181,7 @@ class AgentBridge: self.agents = {} # session_id -> Agent instance mapping self.default_agent = None # For backward compatibility (no session_id) self.agent: Optional[Agent] = None + self.scheduler_initialized = False def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent: """ Create the super agent with COW integration @@ -228,12 +230,7 @@ class AgentBridge: # Log skill loading details if agent.skill_manager: - logger.info(f"[AgentBridge] SkillManager initialized:") - logger.info(f"[AgentBridge] - Managed dir: {agent.skill_manager.managed_skills_dir}") - logger.info(f"[AgentBridge] - Workspace dir: {agent.skill_manager.workspace_dir}") - logger.info(f"[AgentBridge] - Total skills: {len(agent.skill_manager.skills)}") - for skill_name in agent.skill_manager.skills.keys(): - logger.info(f"[AgentBridge] * {skill_name}") + logger.debug(f"[AgentBridge] SkillManager 
initialized with {len(agent.skill_manager.skills)} skills") return agent @@ -268,6 +265,21 @@ class AgentBridge: # Get workspace from config workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow")) + # Migrate API keys from config.json to environment variables (if not already set) + self._migrate_config_to_env(workspace_root) + + # Load environment variables from secure .env file location + env_file = os.path.expanduser("~/.cow/.env") + if os.path.exists(env_file): + try: + from dotenv import load_dotenv + load_dotenv(env_file, override=True) + logger.info(f"[AgentBridge] Loaded environment variables from {env_file}") + except ImportError: + logger.warning("[AgentBridge] python-dotenv not installed, skipping .env file loading") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to load .env file: {e}") + # Initialize workspace and create template files from agent.prompt import ensure_workspace, load_context_files, PromptBuilder @@ -283,14 +295,53 @@ class AgentBridge: from agent.memory import MemoryManager, MemoryConfig from agent.tools import MemorySearchTool, MemoryGetTool - memory_config = MemoryConfig( - workspace_root=workspace_root, - embedding_provider="local", # Use local embedding (no API key needed) - embedding_model="all-MiniLM-L6-v2" - ) + # 从 config.json 读取 OpenAI 配置 + openai_api_key = conf().get("open_ai_api_key", "") + openai_api_base = conf().get("open_ai_api_base", "") - # Create memory manager with the config - memory_manager = MemoryManager(memory_config) + # 尝试初始化 OpenAI embedding provider + embedding_provider = None + if openai_api_key: + try: + from agent.memory import create_embedding_provider + embedding_provider = create_embedding_provider( + provider="openai", + model="text-embedding-3-small", + api_key=openai_api_key, + api_base=openai_api_base or "https://api.openai.com/v1" + ) + logger.info(f"[AgentBridge] OpenAI embedding initialized") + except Exception as embed_error: + logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}") + logger.info(f"[AgentBridge] Using keyword-only search") + else: + logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search") + + # 创建 memory config + memory_config = MemoryConfig(workspace_root=workspace_root) + + # 创建 memory manager + memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider) + + # 初始化时执行一次 sync,确保数据库有数据 + import asyncio + try: + # 尝试在当前事件循环中执行 + loop = asyncio.get_event_loop() + if loop.is_running(): + # 如果事件循环正在运行,创建任务 + asyncio.create_task(memory_manager.sync()) + logger.info("[AgentBridge] Memory sync scheduled") + else: + # 如果没有运行的循环,直接执行 + loop.run_until_complete(memory_manager.sync()) + logger.info("[AgentBridge] Memory synced successfully") + except RuntimeError: + # 没有事件循环,创建新的 + asyncio.run(memory_manager.sync()) + logger.info("[AgentBridge] Memory synced successfully") + except Exception as e: + logger.warning(f"[AgentBridge] Memory sync failed: {e}") # Create memory tools memory_tools = [ @@ -318,7 +369,15 @@ class AgentBridge: for tool_name in tool_manager.tool_classes.keys(): try: - tool = tool_manager.create_tool(tool_name) + # Special handling for EnvConfig tool - pass agent_bridge reference + if tool_name == "env_config": + from agent.tools import EnvConfig + tool = EnvConfig({ + "agent_bridge": self # Pass self reference for hot reload + }) + else: + tool = tool_manager.create_tool(tool_name) + if tool: # Apply workspace config to file operation tools if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 
'find', 'ls']: @@ -326,12 +385,6 @@ class AgentBridge: tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None) if 'memory_manager' in file_config: tool.memory_manager = file_config['memory_manager'] - # Apply API key for bocha_search tool - elif tool_name == 'bocha_search': - bocha_api_key = conf().get("bocha_api_key", "") - if bocha_api_key: - tool.config = {"bocha_api_key": bocha_api_key} - tool.api_key = bocha_api_key tools.append(tool) logger.debug(f"[AgentBridge] Loaded tool: {tool_name}") except Exception as e: @@ -342,6 +395,36 @@ class AgentBridge: tools.extend(memory_tools) logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools") + # Initialize scheduler service (once) + if not self.scheduler_initialized: + try: + from agent.tools.scheduler.integration import init_scheduler + if init_scheduler(self): + self.scheduler_initialized = True + logger.info("[AgentBridge] Scheduler service initialized") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to initialize scheduler: {e}") + + # Inject scheduler dependencies into SchedulerTool instances + if self.scheduler_initialized: + try: + from agent.tools.scheduler.integration import get_task_store, get_scheduler_service + from agent.tools import SchedulerTool + + task_store = get_task_store() + scheduler_service = get_scheduler_service() + + for tool in tools: + if isinstance(tool, SchedulerTool): + tool.task_store = task_store + tool.scheduler_service = scheduler_service + if not tool.config: + tool.config = {} + tool.config["channel_type"] = conf().get("channel_type", "unknown") + logger.debug("[AgentBridge] Injected scheduler dependencies into SchedulerTool") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies: {e}") + logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}") # Load context files (SOUL.md, USER.md, etc.) 
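The scheduler wiring above is repeated in both agent-construction paths. A stripped-down sketch of the injection pattern, using a stand-in dataclass instead of the real `SchedulerTool` and plain objects in place of the globals returned by `get_task_store` / `get_scheduler_service`:

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class FakeSchedulerTool:
    """Stand-in for agent.tools.SchedulerTool (illustrative only)."""
    name: str = "scheduler"
    config: dict = field(default_factory=dict)
    task_store: Any = None
    scheduler_service: Any = None


def inject_scheduler_deps(tools: List[Any], task_store: Any,
                          scheduler_service: Any, channel_type: str) -> None:
    """Give every scheduler tool its store, service and channel type."""
    for tool in tools:
        if getattr(tool, "name", "") == "scheduler":
            tool.task_store = task_store
            tool.scheduler_service = scheduler_service
            tool.config = tool.config or {}
            tool.config["channel_type"] = channel_type


if __name__ == "__main__":
    tools = [FakeSchedulerTool()]
    inject_scheduler_deps(tools, task_store=object(), scheduler_service=object(), channel_type="web")
    print(tools[0].config)  # {'channel_type': 'web'}
```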
@@ -381,14 +464,19 @@ class AgentBridge: logger.info("[AgentBridge] System prompt built successfully") + # Get cost control parameters from config + max_steps = conf().get("agent_max_steps", 20) + max_context_tokens = conf().get("agent_max_context_tokens", 50000) + # Create agent with configured tools and workspace agent = self.create_agent( system_prompt=system_prompt, tools=tools, - max_steps=50, + max_steps=max_steps, output_mode="logger", workspace_dir=workspace_root, # Pass workspace to agent for skills loading - enable_skills=True # Enable skills auto-loading + enable_skills=True, # Enable skills auto-loading + max_context_tokens=max_context_tokens ) # Attach memory manager to agent if available @@ -410,6 +498,24 @@ class AgentBridge: # Get workspace from config workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow")) + # Migrate API keys from config.json to environment variables (if not already set) + self._migrate_config_to_env(workspace_root) + + # Load environment variables from secure .env file location + env_file = os.path.expanduser("~/.cow/.env") + if os.path.exists(env_file): + try: + from dotenv import load_dotenv + load_dotenv(env_file, override=True) + logger.debug(f"[AgentBridge] Loaded environment variables from {env_file} for session {session_id}") + except ImportError: + logger.warning(f"[AgentBridge] python-dotenv not installed, skipping .env file loading for session {session_id}") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to load .env file for session {session_id}: {e}") + + # Migrate API keys from config.json to environment variables (if not already set) + self._migrate_config_to_env(workspace_root) + # Initialize workspace from agent.prompt import ensure_workspace, load_context_files, PromptBuilder @@ -420,23 +526,65 @@ class AgentBridge: memory_tools = [] try: - from agent.memory import MemoryManager, MemoryConfig + from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider from agent.tools import MemorySearchTool, MemoryGetTool - memory_config = MemoryConfig( - workspace_root=workspace_root, - embedding_provider="local", - embedding_model="all-MiniLM-L6-v2" - ) + # 从 config.json 读取 OpenAI 配置 + openai_api_key = conf().get("open_ai_api_key", "") + openai_api_base = conf().get("open_ai_api_base", "") + + # 尝试初始化 OpenAI embedding provider + embedding_provider = None + if openai_api_key: + try: + embedding_provider = create_embedding_provider( + provider="openai", + model="text-embedding-3-small", + api_key=openai_api_key, + api_base=openai_api_base or "https://api.openai.com/v1" + ) + logger.debug(f"[AgentBridge] OpenAI embedding initialized for session {session_id}") + except Exception as embed_error: + logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}") + logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}") + else: + logger.debug(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}") + + # 创建 memory config + memory_config = MemoryConfig(workspace_root=workspace_root) + + # 创建 memory manager + memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider) + + # 初始化时执行一次 sync,确保数据库有数据 + import asyncio + try: + # 尝试在当前事件循环中执行 + loop = asyncio.get_event_loop() + if loop.is_running(): + # 如果事件循环正在运行,创建任务 + asyncio.create_task(memory_manager.sync()) + logger.debug(f"[AgentBridge] Memory sync scheduled for session {session_id}") + else: + # 如果没有运行的循环,直接执行 + 
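The surrounding block runs `memory_manager.sync()` differently depending on whether an event loop is already active. A generic sketch of that dispatch, with a trivial coroutine standing in for the real sync call:

```python
import asyncio


async def sync() -> str:
    """Stand-in for memory_manager.sync()."""
    return "synced"


def run_sync_safely() -> None:
    """Run the coroutine whether or not an event loop is already running."""
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Inside a running loop: schedule it and let it finish in the background
            asyncio.create_task(sync())
        else:
            loop.run_until_complete(sync())
    except RuntimeError:
        # No usable loop in this thread: create one just for this call
        asyncio.run(sync())


if __name__ == "__main__":
    run_sync_safely()
```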
loop.run_until_complete(memory_manager.sync()) + logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}") + except RuntimeError: + # 没有事件循环,创建新的 + asyncio.run(memory_manager.sync()) + logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}") + except Exception as sync_error: + logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}") - memory_manager = MemoryManager(memory_config) memory_tools = [ MemorySearchTool(memory_manager), MemoryGetTool(memory_manager) ] except Exception as e: - logger.debug(f"[AgentBridge] Memory system not available for session {session_id}: {e}") + logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}") + import traceback + logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}") # Load tools from agent.tools import ToolManager @@ -458,17 +606,42 @@ class AgentBridge: tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None) if 'memory_manager' in file_config: tool.memory_manager = file_config['memory_manager'] - elif tool_name == 'bocha_search': - bocha_api_key = conf().get("bocha_api_key", "") - if bocha_api_key: - tool.config = {"bocha_api_key": bocha_api_key} - tool.api_key = bocha_api_key tools.append(tool) except Exception as e: logger.warning(f"[AgentBridge] Failed to load tool {tool_name} for session {session_id}: {e}") if memory_tools: tools.extend(memory_tools) + + # Initialize scheduler service (once, if not already initialized) + if not self.scheduler_initialized: + try: + from agent.tools.scheduler.integration import init_scheduler + if init_scheduler(self): + self.scheduler_initialized = True + logger.debug(f"[AgentBridge] Scheduler service initialized for session {session_id}") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to initialize scheduler for session {session_id}: {e}") + + # Inject scheduler dependencies into SchedulerTool instances + if self.scheduler_initialized: + try: + from agent.tools.scheduler.integration import get_task_store, get_scheduler_service + from agent.tools import SchedulerTool + + task_store = get_task_store() + scheduler_service = get_scheduler_service() + + for tool in tools: + if isinstance(tool, SchedulerTool): + tool.task_store = task_store + tool.scheduler_service = scheduler_service + if not tool.config: + tool.config = {} + tool.config["channel_type"] = conf().get("channel_type", "unknown") + logger.debug(f"[AgentBridge] Injected scheduler dependencies for session {session_id}") + except Exception as e: + logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies for session {session_id}: {e}") # Load context files context_files = load_context_files(workspace_root) @@ -478,7 +651,7 @@ class AgentBridge: try: from agent.skills import SkillManager skill_manager = SkillManager(workspace_dir=workspace_root) - logger.info(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}") + logger.debug(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}") except Exception as e: logger.warning(f"[AgentBridge] Failed to initialize SkillManager for session {session_id}: {e}") @@ -543,15 +716,20 @@ class AgentBridge: if is_first: mark_conversation_started(workspace_root) + # Get cost control parameters from config + max_steps = conf().get("agent_max_steps", 20) + max_context_tokens = conf().get("agent_max_context_tokens", 50000) + # 
@@ -543,15 +716,20 @@ class AgentBridge:
         if is_first:
             mark_conversation_started(workspace_root)
 
+        # Get cost control parameters from config
+        max_steps = conf().get("agent_max_steps", 20)
+        max_context_tokens = conf().get("agent_max_context_tokens", 50000)
+
         # Create agent for this session
         agent = self.create_agent(
             system_prompt=system_prompt,
             tools=tools,
-            max_steps=50,
+            max_steps=max_steps,
             output_mode="logger",
             workspace_dir=workspace_root,
             skill_manager=skill_manager,
-            enable_skills=True
+            enable_skills=True,
+            max_context_tokens=max_context_tokens
         )
 
         if memory_manager:
@@ -586,12 +764,52 @@ class AgentBridge:
             if not agent:
                 return Reply(ReplyType.ERROR, "Failed to initialize super agent")
 
-            # Use agent's run_stream method
-            response = agent.run_stream(
-                user_message=query,
-                on_event=on_event,
-                clear_history=clear_history
-            )
+            # Filter tools based on context
+            original_tools = agent.tools
+            filtered_tools = original_tools
+
+            # If this is a scheduled task execution, exclude scheduler tool to prevent recursion
+            if context and context.get("is_scheduled_task"):
+                filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
+                agent.tools = filtered_tools
+                logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
+            else:
+                # Attach context to scheduler tool if present
+                if context and agent.tools:
+                    for tool in agent.tools:
+                        if tool.name == "scheduler":
+                            try:
+                                from agent.tools.scheduler.integration import attach_scheduler_to_tool
+                                attach_scheduler_to_tool(tool, context)
+                            except Exception as e:
+                                logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
+                            break
+
+            try:
+                # Use agent's run_stream method
+                response = agent.run_stream(
+                    user_message=query,
+                    on_event=on_event,
+                    clear_history=clear_history
+                )
+            finally:
+                # Restore original tools
+                if context and context.get("is_scheduled_task"):
+                    agent.tools = original_tools
+
+            # Check if there are files to send (from read tool)
+            if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
+                files_to_send = agent.stream_executor.files_to_send
+                if files_to_send:
+                    # Send the first file (for now, handle one file at a time)
+                    file_info = files_to_send[0]
+                    logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")
+
+                    # Clear files_to_send for next request
+                    agent.stream_executor.files_to_send = []
+
+                    # Return file reply based on file type
+                    return self._create_file_reply(file_info, response, context)
 
             return Reply(ReplyType.TEXT, response)
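The scheduler exclusion is scoped to a single run: the tool list is swapped before run_stream and restored in the finally block, so a failing scheduled task cannot leave the agent permanently without its scheduler tool. Stripped down to the essentials (the Tool/Agent stand-ins below are hypothetical):

    class Tool:
        def __init__(self, name):
            self.name = name

    class Agent:
        def __init__(self, tools):
            self.tools = tools

    agent = Agent([Tool("scheduler"), Tool("read")])
    is_scheduled_task = True

    original_tools = agent.tools
    try:
        if is_scheduled_task:
            # A scheduled task must not be able to schedule further tasks
            agent.tools = [t for t in original_tools if t.name != "scheduler"]
        print([t.name for t in agent.tools])  # ['read']
    finally:
        agent.tools = original_tools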
@@ -599,6 +817,120 @@ class AgentBridge:
             logger.error(f"Agent reply error: {e}")
             return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
 
+    def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
+        """
+        Create a reply for sending files
+
+        Args:
+            file_info: File metadata from read tool
+            text_response: Text response from agent
+            context: Context object
+
+        Returns:
+            Reply object for file sending
+        """
+        file_type = file_info.get("file_type", "file")
+        file_path = file_info.get("path")
+
+        # For images, use IMAGE_URL type (channel will handle upload)
+        if file_type == "image":
+            # Convert local path to file:// URL for channel processing
+            file_url = f"file://{file_path}"
+            logger.info(f"[AgentBridge] Sending image: {file_url}")
+            reply = Reply(ReplyType.IMAGE_URL, file_url)
+            # Attach text message if present (for channels that support text+image)
+            if text_response:
+                reply.text_content = text_response  # Store accompanying text
+            return reply
+
+        # For documents (PDF, Excel, Word, PPT), use FILE type
+        if file_type == "document":
+            file_url = f"file://{file_path}"
+            logger.info(f"[AgentBridge] Sending document: {file_url}")
+            reply = Reply(ReplyType.FILE, file_url)
+            reply.file_name = file_info.get("file_name", os.path.basename(file_path))
+            return reply
+
+        # For other files (video, audio), we need channel-specific handling
+        # For now, return text with file info
+        # TODO: Implement video/audio sending when channel supports it
+        message = text_response or file_info.get("message", "文件已准备")
+        message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
+        return Reply(ReplyType.TEXT, message)
+
+    def _migrate_config_to_env(self, workspace_root: str):
+        """
+        Migrate API keys from config.json to .env file if not already set
+
+        Args:
+            workspace_root: Workspace directory path (not used, kept for compatibility)
+        """
+        from config import conf
+        import os
+
+        # Mapping from config.json keys to environment variable names
+        key_mapping = {
+            "open_ai_api_key": "OPENAI_API_KEY",
+            "open_ai_api_base": "OPENAI_API_BASE",
+            "gemini_api_key": "GEMINI_API_KEY",
+            "claude_api_key": "CLAUDE_API_KEY",
+            "linkai_api_key": "LINKAI_API_KEY",
+        }
+
+        # Use fixed secure location for .env file
+        env_file = os.path.expanduser("~/.cow/.env")
+
+        # Read existing env vars from .env file
+        existing_env_vars = {}
+        if os.path.exists(env_file):
+            try:
+                with open(env_file, 'r', encoding='utf-8') as f:
+                    for line in f:
+                        line = line.strip()
+                        if line and not line.startswith('#') and '=' in line:
+                            key, _ = line.split('=', 1)
+                            existing_env_vars[key.strip()] = True
+            except Exception as e:
+                logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
+
+        # Check which keys need to be migrated
+        keys_to_migrate = {}
+        for config_key, env_key in key_mapping.items():
+            # Skip if already in .env file
+            if env_key in existing_env_vars:
+                continue
+
+            # Get value from config.json
+            value = conf().get(config_key, "")
+            if value and value.strip():  # Only migrate non-empty values
+                keys_to_migrate[env_key] = value.strip()
+
+        # Log summary if there are keys to skip
+        if existing_env_vars:
+            logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
+
+        # Write new keys to .env file
+        if keys_to_migrate:
+            try:
+                # Ensure ~/.cow directory and .env file exist
+                env_dir = os.path.dirname(env_file)
+                if not os.path.exists(env_dir):
+                    os.makedirs(env_dir, exist_ok=True)
+                if not os.path.exists(env_file):
+                    open(env_file, 'a').close()
+
+                # Append new keys
+                with open(env_file, 'a', encoding='utf-8') as f:
+                    f.write('\n# Auto-migrated from config.json\n')
+                    for key, value in keys_to_migrate.items():
+                        f.write(f'{key}={value}\n')
+                        # Also set in current process
+                        os.environ[key] = value
+
+                logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
+            except Exception as e:
+                logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
+
     def clear_session(self, session_id: str):
         """
         Clear a specific session's agent and conversation history
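The migration is append-only and idempotent: a key that already appears in ~/.cow/.env is never rewritten, so manual edits to the file survive later runs. The skip logic reduces to roughly this (self-contained sketch against a temporary file; key and value are placeholders):

    import os
    import tempfile

    def existing_keys(path):
        keys = set()
        if os.path.exists(path):
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#") and "=" in line:
                        keys.add(line.split("=", 1)[0].strip())
        return keys

    def migrate(path, candidates):
        new = {k: v for k, v in candidates.items() if v and k not in existing_keys(path)}
        with open(path, "a", encoding="utf-8") as f:
            for k, v in new.items():
                f.write(f"{k}={v}\n")
        return new

    env_file = os.path.join(tempfile.mkdtemp(), ".env")
    print(migrate(env_file, {"OPENAI_API_KEY": "sk-placeholder"}))  # migrated on the first run
    print(migrate(env_file, {"OPENAI_API_KEY": "sk-placeholder"}))  # {} on the second run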
@@ -614,4 +946,43 @@ class AgentBridge:
         """Clear all agent sessions"""
         logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
         self.agents.clear()
-        self.default_agent = None
\ No newline at end of file
+        self.default_agent = None
+
+    def refresh_all_skills(self) -> int:
+        """
+        Refresh skills in all agent instances after environment variable changes.
+        This allows hot-reload of skills without restarting the agent.
+
+        Returns:
+            Number of agent instances refreshed
+        """
+        import os
+        from dotenv import load_dotenv
+        from config import conf
+
+        # Reload environment variables from .env file
+        workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
+        env_file = os.path.join(workspace_root, '.env')
+
+        if os.path.exists(env_file):
+            load_dotenv(env_file, override=True)
+            logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
+
+        refreshed_count = 0
+
+        # Refresh default agent
+        if self.default_agent and hasattr(self.default_agent, 'skill_manager'):
+            self.default_agent.skill_manager.refresh_skills()
+            refreshed_count += 1
+            logger.info("[AgentBridge] Refreshed skills in default agent")
+
+        # Refresh all session agents
+        for session_id, agent in self.agents.items():
+            if hasattr(agent, 'skill_manager'):
+                agent.skill_manager.refresh_skills()
+                refreshed_count += 1
+
+        if refreshed_count > 0:
+            logger.info(f"[AgentBridge] Refreshed skills in {refreshed_count} agent instance(s)")
+
+        return refreshed_count
\ No newline at end of file
diff --git a/bridge/bridge.py b/bridge/bridge.py
index a7b93c4..4c686f9 100644
--- a/bridge/bridge.py
+++ b/bridge/bridge.py
@@ -1,4 +1,4 @@
-from bot.bot_factory import create_bot
+from models.bot_factory import create_bot
 from bridge.context import Context
 from bridge.reply import Reply
 from common import const
diff --git a/channel/chat_channel.py b/channel/chat_channel.py
index 1523f67..af3607d 100644
--- a/channel/chat_channel.py
+++ b/channel/chat_channel.py
@@ -64,15 +64,22 @@ class ChatChannel(Channel):
                         check_contain(group_name, group_name_keyword_white_list),
                     ]
                 ):
-                    group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
-                    session_id = cmsg.actual_user_id
-                    if any(
-                        [
-                            group_name in group_chat_in_one_session,
-                            "ALL_GROUP" in group_chat_in_one_session,
-                        ]
-                    ):
+                    # Check global group_shared_session config first
+                    group_shared_session = conf().get("group_shared_session", True)
+                    if group_shared_session:
+                        # All users in the group share the same session
                         session_id = group_id
+                    else:
+                        # Check group-specific whitelist (legacy behavior)
+                        group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
+                        session_id = cmsg.actual_user_id
+                        if any(
+                            [
+                                group_name in group_chat_in_one_session,
+                                "ALL_GROUP" in group_chat_in_one_session,
+                            ]
+                        ):
+                            session_id = group_id
                 else:
                     logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
                     return None
@@ -166,11 +173,11 @@ class ChatChannel(Channel):
     def _handle(self, context: Context):
         if context is None or not context.content:
            return
-        logger.debug("[chat_channel] ready to handle context: {}".format(context))
+        logger.debug("[chat_channel] handling context: {}".format(context))
 
        # reply的构建步骤
        reply = self._generate_reply(context)
 
-        logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
+        logger.debug("[chat_channel] decorating reply: {}".format(reply))
 
        # reply的包装步骤
        if reply and reply.content:
@@ -188,7 +195,7 @@ class ChatChannel(Channel):
            )
            reply = e_context["reply"]
            if not e_context.is_pass():
-                logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
+                logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
                if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # 文字和图片消息
                    context["channel"] = e_context["channel"]
                    reply = super().build_reply_content(context.content, context)
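With group_shared_session defaulting to True, every member of a group now shares one session; per-user sessions are only used when the flag is set to False and the group is not covered by the legacy group_chat_in_one_session whitelist. The decision reduces to the following (illustrative helper; conf_values stands in for conf()):

    def resolve_session_id(conf_values, group_id, actual_user_id, group_name):
        if conf_values.get("group_shared_session", True):
            return group_id  # whole group shares one session
        one_session = conf_values.get("group_chat_in_one_session", [])
        if group_name in one_session or "ALL_GROUP" in one_session:
            return group_id  # legacy whitelist behaviour
        return actual_user_id  # per-user session

    print(resolve_session_id({}, "g-1", "u-42", "devs"))                               # g-1
    print(resolve_session_id({"group_shared_session": False}, "g-1", "u-42", "devs"))  # u-42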
@@ -282,7 +289,100 @@ class ChatChannel(Channel):
            )
            reply = e_context["reply"]
            if not e_context.is_pass() and reply and reply.type:
-                logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
+                logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
+
+                # For text replies, try to extract and send any referenced images
+                if reply.type == ReplyType.TEXT:
+                    self._extract_and_send_images(reply, context)
+                # For image replies that carry text content, send the text first, then the image
+                elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
+                    # Send the text first
+                    text_reply = Reply(ReplyType.TEXT, reply.text_content)
+                    self._send(text_reply, context)
+                    # Send the image after a short delay
+                    time.sleep(0.3)
+                    self._send(reply, context)
+                else:
+                    self._send(reply, context)
+
+    def _extract_and_send_images(self, reply: Reply, context: Context):
+        """
+        Extract image/video URLs from a text reply and send them separately.
+        Supported formats: [图片: /path/to/image.png], [视频: /path/to/video.mp4], ![](url), <img>, <video>
+        Sends at most 5 media files.
+        """
+        content = reply.content
+        media_items = []  # [(url, type), ...]
+
+        # Extract media URLs in the supported formats with regexes
+        patterns = [
+            (r'\[图片:\s*([^\]]+)\]', 'image'),                 # [图片: /path/to/image.png]
+            (r'\[视频:\s*([^\]]+)\]', 'video'),                 # [视频: /path/to/video.mp4]
+            (r'!\[.*?\]\(([^\)]+)\)', 'image'),                 # ![alt](url) - defaults to image
+            (r'<img[^>]+src=["\']([^"\']+)["\']', 'image'),     # <img src="...">
+            (r'<video[^>]+src=["\']([^"\']+)["\']', 'video'),   # <video src="...">