diff --git a/agent/memory/__init__.py b/agent/memory/__init__.py
index 4179bea..f638a9d 100644
--- a/agent/memory/__init__.py
+++ b/agent/memory/__init__.py
@@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword)
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
+from agent.memory.embedding import create_embedding_provider
-__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
+__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
diff --git a/agent/memory/config.py b/agent/memory/config.py
index 758611d..63d945b 100644
--- a/agent/memory/config.py
+++ b/agent/memory/config.py
@@ -41,6 +41,10 @@ class MemoryConfig:
enable_auto_sync: bool = True
sync_on_search: bool = True
+ # Memory flush config (independent of the model's context window)
+ flush_token_threshold: int = 50000 # flush is triggered at 50K tokens
+ flush_turn_threshold: int = 20 # flush is triggered after 20 conversation turns (one user message + one AI reply = one turn)
+
def get_workspace(self) -> Path:
"""Get workspace root directory"""
return Path(self.workspace_root)
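The two thresholds are plain dataclass fields, so they can be tuned per deployment. A minimal sketch of overriding them, assuming `MemoryConfig` accepts `workspace_root` as a constructor argument (the field is referenced in `get_workspace` above, but its default is not shown in this diff):

```python
# Minimal sketch: tuning the new flush thresholds on MemoryConfig.
# workspace_root is assumed to be a constructor argument of the dataclass.
from agent.memory.config import MemoryConfig

config = MemoryConfig(
    workspace_root="./workspace",
    flush_token_threshold=50_000,   # flush once the session reaches ~50K tokens
    flush_turn_threshold=20,        # ...or after 20 user/AI exchanges
)
print(config.flush_token_threshold, config.flush_turn_threshold)
```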
diff --git a/agent/memory/embedding.py b/agent/memory/embedding.py
index 4a71828..509370b 100644
--- a/agent/memory/embedding.py
+++ b/agent/memory/embedding.py
@@ -4,20 +4,19 @@ Embedding providers for memory
Supports OpenAI and local embedding models
"""
-from typing import List, Optional
-from abc import ABC, abstractmethod
import hashlib
-import json
+from abc import ABC, abstractmethod
+from typing import List, Optional
class EmbeddingProvider(ABC):
"""Base class for embedding providers"""
-
+
@abstractmethod
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
pass
-
+
@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
@@ -31,7 +30,7 @@ class EmbeddingProvider(ABC):
class OpenAIEmbeddingProvider(EmbeddingProvider):
- """OpenAI embedding provider"""
+ """OpenAI embedding provider using REST API"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
"""
@@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
-
- # Lazy import to avoid dependency issues
- try:
- from openai import OpenAI
- self.client = OpenAI(api_key=api_key, base_url=api_base)
- except ImportError:
- raise ImportError("OpenAI package not installed. Install with: pip install openai")
-
+
+ if not self.api_key:
+ raise ValueError("OpenAI API key is required")
+
# Set dimensions based on model
self._dimensions = 1536 if "small" in model else 3072
-
+
+ def _call_api(self, input_data):
+ """Call OpenAI embedding API using requests"""
+ import requests
+
+ url = f"{self.api_base}/embeddings"
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.api_key}"
+ }
+ data = {
+ "input": input_data,
+ "model": self.model
+ }
+
+ response = requests.post(url, headers=headers, json=data, timeout=30)
+ response.raise_for_status()
+ return response.json()
+
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
- response = self.client.embeddings.create(
- input=text,
- model=self.model
- )
- return response.data[0].embedding
-
+ result = self._call_api(text)
+ return result["data"][0]["embedding"]
+
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
if not texts:
return []
-
- response = self.client.embeddings.create(
- input=texts,
- model=self.model
- )
- return [item.embedding for item in response.data]
-
+
+ result = self._call_api(texts)
+ return [item["embedding"] for item in result["data"]]
+
@property
def dimensions(self) -> int:
return self._dimensions
-class LocalEmbeddingProvider(EmbeddingProvider):
- """Local embedding provider using sentence-transformers"""
-
- def __init__(self, model: str = "all-MiniLM-L6-v2"):
- """
- Initialize local embedding provider
-
- Args:
- model: Model name from sentence-transformers
- """
- self.model_name = model
-
- try:
- from sentence_transformers import SentenceTransformer
- self.model = SentenceTransformer(model)
- self._dimensions = self.model.get_sentence_embedding_dimension()
- except ImportError:
- raise ImportError(
- "sentence-transformers not installed. "
- "Install with: pip install sentence-transformers"
- )
-
- def embed(self, text: str) -> List[float]:
- """Generate embedding for text"""
- embedding = self.model.encode(text, convert_to_numpy=True)
- return embedding.tolist()
-
- def embed_batch(self, texts: List[str]) -> List[List[float]]:
- """Generate embeddings for multiple texts"""
- if not texts:
- return []
-
- embeddings = self.model.encode(texts, convert_to_numpy=True)
- return embeddings.tolist()
-
- @property
- def dimensions(self) -> int:
- return self._dimensions
+# LocalEmbeddingProvider removed - use the OpenAI embedding provider, or fall back to keyword-only search
class EmbeddingCache:
"""Cache for embeddings to avoid recomputation"""
-
+
def __init__(self):
self.cache = {}
-
+
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
"""Get cached embedding"""
key = self._compute_key(text, provider, model)
@@ -156,20 +126,23 @@ def create_embedding_provider(
"""
Factory function to create embedding provider
+ Only the OpenAI embedding API (called via REST) is supported.
+ If initialization fails, the caller should fall back to keyword-only search.
+
Args:
- provider: Provider name ("openai" or "local")
- model: Model name (provider-specific)
- api_key: API key for remote providers
- api_base: API base URL for remote providers
+ provider: Provider name (only "openai" is supported)
+ model: Model name (default: text-embedding-3-small)
+ api_key: OpenAI API key (required)
+ api_base: API base URL (default: https://api.openai.com/v1)
Returns:
EmbeddingProvider instance
+
+ Raises:
+ ValueError: If provider is not "openai" or api_key is missing
"""
- if provider == "openai":
- model = model or "text-embedding-3-small"
- return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
- elif provider == "local":
- model = model or "all-MiniLM-L6-v2"
- return LocalEmbeddingProvider(model=model)
- else:
- raise ValueError(f"Unknown embedding provider: {provider}")
+ if provider != "openai":
+ raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
+
+ model = model or "text-embedding-3-small"
+ return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
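For reference, the request and response shape that the new `_call_api` helper relies on can be exercised standalone with `requests`. The endpoint and payload mirror the diff; the API key is read from the environment rather than hard-coded:

```python
# Standalone sketch of the REST call OpenAIEmbeddingProvider._call_api performs.
# Set OPENAI_API_KEY in your environment before running.
import os
import requests

api_base = "https://api.openai.com/v1"
payload = {"input": ["hello world"], "model": "text-embedding-3-small"}
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
}
resp = requests.post(f"{api_base}/embeddings", headers=headers, json=payload, timeout=30)
resp.raise_for_status()
vectors = [item["embedding"] for item in resp.json()["data"]]
print(len(vectors), len(vectors[0]))  # 1 vector, 1536 dimensions for the "small" model
```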
diff --git a/agent/memory/manager.py b/agent/memory/manager.py
index b313e82..1e47811 100644
--- a/agent/memory/manager.py
+++ b/agent/memory/manager.py
@@ -70,8 +70,9 @@ class MemoryManager:
except Exception as e:
# Embedding provider failed, but that's OK
# We can still use keyword search and file operations
- print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
- print(f"ℹ️ Memory will work with keyword search only (no semantic search)")
+ from common.log import logger
+ logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
+ logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
# Initialize memory flush manager
workspace_dir = self.config.get_workspace()
@@ -135,13 +136,19 @@ class MemoryManager:
# Perform vector search (if embedding provider available)
vector_results = []
if self.embedding_provider:
- query_embedding = self.embedding_provider.embed(query)
- vector_results = self.storage.search_vector(
- query_embedding=query_embedding,
- user_id=user_id,
- scopes=scopes,
- limit=max_results * 2 # Get more candidates for merging
- )
+ try:
+ from common.log import logger
+ query_embedding = self.embedding_provider.embed(query)
+ vector_results = self.storage.search_vector(
+ query_embedding=query_embedding,
+ user_id=user_id,
+ scopes=scopes,
+ limit=max_results * 2 # Get more candidates for merging
+ )
+ logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
+ except Exception as e:
+ from common.log import logger
+ logger.warning(f"[MemoryManager] Vector search failed: {e}")
# Perform keyword search
keyword_results = self.storage.search_keyword(
@@ -150,6 +157,8 @@ class MemoryManager:
scopes=scopes,
limit=max_results * 2
)
+ from common.log import logger
+ logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
# Merge results
merged = self._merge_results(
@@ -356,30 +365,30 @@ class MemoryManager:
def should_flush_memory(
self,
- current_tokens: int,
- context_window: int = 128000,
- reserve_tokens: int = 20000,
- soft_threshold: int = 4000
+ current_tokens: int = 0
) -> bool:
"""
Check if memory flush should be triggered
+ Standalone flush trigger that does not depend on the model's context window.
+ Uses the thresholds from config: flush_token_threshold and flush_turn_threshold.
+
Args:
current_tokens: Current session token count
- context_window: Model's context window size (default: 128K)
- reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
- soft_threshold: Trigger N tokens before threshold (default: 4K)
Returns:
True if memory flush should run
"""
return self.flush_manager.should_flush(
current_tokens=current_tokens,
- context_window=context_window,
- reserve_tokens=reserve_tokens,
- soft_threshold=soft_threshold
+ token_threshold=self.config.flush_token_threshold,
+ turn_threshold=self.config.flush_turn_threshold
)
+ def increment_turn(self):
+ """增加对话轮数计数(每次用户消息+AI回复算一轮)"""
+ self.flush_manager.increment_turn()
+
async def execute_memory_flush(
self,
agent_executor,
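The guarded vector search plus unconditional keyword search above reduces to a simple best-effort pattern. A condensed sketch, where `embedding_provider` and `storage` are stand-ins for the real `MemoryManager` attributes (the actual `search_vector`/`search_keyword` signatures also take `user_id` and `scopes`, and merging is done by `_merge_results`):

```python
# Condensed sketch of the hybrid-search guard: vector search is best-effort,
# keyword search always runs, and the two result sets are merged by chunk id.
import logging

logger = logging.getLogger("memory")

def hybrid_search(query, embedding_provider, storage, max_results=5):
    vector_results = []
    if embedding_provider is not None:
        try:
            query_embedding = embedding_provider.embed(query)
            vector_results = storage.search_vector(
                query_embedding=query_embedding, limit=max_results * 2
            )
        except Exception as exc:  # an embedding/API failure must not break search
            logger.warning("Vector search failed: %s", exc)

    keyword_results = storage.search_keyword(query=query, limit=max_results * 2)

    # Deduplicate by chunk id, preferring vector hits when both found the same chunk.
    merged = {r["id"]: r for r in keyword_results}
    merged.update({r["id"]: r for r in vector_results})
    return list(merged.values())[:max_results]
```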
diff --git a/agent/memory/storage.py b/agent/memory/storage.py
index b8fccf0..373512b 100644
--- a/agent/memory/storage.py
+++ b/agent/memory/storage.py
@@ -46,14 +46,32 @@ class MemoryStorage:
def __init__(self, db_path: Path):
self.db_path = db_path
self.conn: Optional[sqlite3.Connection] = None
+ self.fts5_available = False # Track FTS5 availability
self._init_db()
+ def _check_fts5_support(self) -> bool:
+ """Check if SQLite has FTS5 support"""
+ try:
+ self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
+ self.conn.execute("DROP TABLE IF EXISTS fts5_test")
+ return True
+ except sqlite3.OperationalError as e:
+ if "no such module: fts5" in str(e):
+ return False
+ raise
+
def _init_db(self):
"""Initialize database with schema"""
try:
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
+ # Check FTS5 support
+ self.fts5_available = self._check_fts5_support()
+ if not self.fts5_available:
+ from common.log import logger
+ logger.warning("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
+
# Check database integrity
try:
result = self.conn.execute("PRAGMA integrity_check").fetchone()
@@ -125,43 +143,44 @@ class MemoryStorage:
ON chunks(path, hash)
""")
- # Create FTS5 virtual table for keyword search
- # Use default unicode61 tokenizer (stable and compatible)
- # For CJK support, we'll use LIKE queries as fallback
- self.conn.execute("""
- CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
- text,
- id UNINDEXED,
- user_id UNINDEXED,
- path UNINDEXED,
- source UNINDEXED,
- scope UNINDEXED,
- content='chunks',
- content_rowid='rowid'
- )
- """)
-
- # Create triggers to keep FTS in sync
- self.conn.execute("""
- CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
- INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
- VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
- END
- """)
-
- self.conn.execute("""
- CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
- DELETE FROM chunks_fts WHERE rowid = old.rowid;
- END
- """)
-
- self.conn.execute("""
- CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
- UPDATE chunks_fts SET text = new.text, id = new.id,
- user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
- WHERE rowid = new.rowid;
- END
- """)
+ # Create FTS5 virtual table for keyword search (only if supported)
+ if self.fts5_available:
+ # Use default unicode61 tokenizer (stable and compatible)
+ # For CJK support, we'll use LIKE queries as fallback
+ self.conn.execute("""
+ CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
+ text,
+ id UNINDEXED,
+ user_id UNINDEXED,
+ path UNINDEXED,
+ source UNINDEXED,
+ scope UNINDEXED,
+ content='chunks',
+ content_rowid='rowid'
+ )
+ """)
+
+ # Create triggers to keep FTS in sync
+ self.conn.execute("""
+ CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
+ INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
+ VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
+ END
+ """)
+
+ self.conn.execute("""
+ CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
+ DELETE FROM chunks_fts WHERE rowid = old.rowid;
+ END
+ """)
+
+ self.conn.execute("""
+ CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
+ UPDATE chunks_fts SET text = new.text, id = new.id,
+ user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
+ WHERE rowid = new.rowid;
+ END
+ """)
# Create files metadata table
self.conn.execute("""
@@ -301,21 +320,22 @@ class MemoryStorage:
Keyword search using FTS5 + LIKE fallback
Strategy:
- 1. Try FTS5 search first (good for English and word-based languages)
- 2. If no results and query contains CJK characters, use LIKE search
+ 1. If FTS5 is available: try FTS5 search first (works well for English and other word-delimited languages)
+ 2. If FTS5 is unavailable, or it returns nothing and the query contains CJK: use LIKE search
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
- # Try FTS5 search first
- fts_results = self._search_fts5(query, user_id, scopes, limit)
- if fts_results:
- return fts_results
+ # Try FTS5 search first (if available)
+ if self.fts5_available:
+ fts_results = self._search_fts5(query, user_id, scopes, limit)
+ if fts_results:
+ return fts_results
- # Fallback to LIKE search for CJK characters
- if MemoryStorage._contains_cjk(query):
+ # Fallback to LIKE search (always for CJK, or if FTS5 not available)
+ if not self.fts5_available or MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit)
return []
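The FTS5 capability probe can be reproduced outside the class; a minimal sketch of the same check against an in-memory database:

```python
# Minimal sketch of the FTS5 availability probe used in MemoryStorage._check_fts5_support.
import sqlite3

def has_fts5(conn: sqlite3.Connection) -> bool:
    try:
        conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_probe USING fts5(text)")
        conn.execute("DROP TABLE IF EXISTS fts5_probe")
        return True
    except sqlite3.OperationalError as exc:
        if "no such module: fts5" in str(exc):
            return False
        raise

conn = sqlite3.connect(":memory:")
print("FTS5 available:", has_fts5(conn))  # when False, keyword search falls back to LIKE
```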
diff --git a/agent/memory/summarizer.py b/agent/memory/summarizer.py
index 4b102ab..46b2b59 100644
--- a/agent/memory/summarizer.py
+++ b/agent/memory/summarizer.py
@@ -41,46 +41,42 @@ class MemoryFlushManager:
# Tracking
self.last_flush_token_count: Optional[int] = None
self.last_flush_timestamp: Optional[datetime] = None
+ self.turn_count: int = 0 # conversation turn counter
def should_flush(
self,
- current_tokens: int,
- context_window: int,
- reserve_tokens: int = 20000,
- soft_threshold: int = 4000
+ current_tokens: int = 0,
+ token_threshold: int = 50000,
+ turn_threshold: int = 20
) -> bool:
"""
Determine if memory flush should be triggered
- Similar to clawdbot's shouldRunMemoryFlush logic:
- threshold = contextWindow - reserveTokens - softThreshold
+ Standalone flush trigger that does not depend on the model's context window:
+ - Token threshold: triggers once 50K tokens are reached
+ - Turn threshold: triggers after 20 conversation turns
Args:
current_tokens: Current session token count
- context_window: Model's context window size
- reserve_tokens: Reserve tokens for compaction overhead
- soft_threshold: Trigger flush N tokens before threshold
+ token_threshold: Token threshold to trigger flush (default: 50K)
+ turn_threshold: Turn threshold to trigger flush (default: 20)
Returns:
True if flush should run
"""
- if current_tokens <= 0:
- return False
+ # Check the token threshold
+ if current_tokens > 0 and current_tokens >= token_threshold:
+ # Avoid flushing again right after a flush
+ if self.last_flush_token_count is not None:
+ if current_tokens <= self.last_flush_token_count + 5000:
+ return False
+ return True
- threshold = max(0, context_window - reserve_tokens - soft_threshold)
- if threshold <= 0:
- return False
+ # Check the turn threshold
+ if self.turn_count >= turn_threshold:
+ return True
- # Check if we've crossed the threshold
- if current_tokens < threshold:
- return False
-
- # Avoid duplicate flush in same compaction cycle
- if self.last_flush_token_count is not None:
- if current_tokens <= self.last_flush_token_count + soft_threshold:
- return False
-
- return True
+ return False
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
@@ -130,7 +126,12 @@ class MemoryFlushManager:
f"Pre-compaction memory flush. "
f"Store durable memories now (use memory/{today}.md for daily notes; "
f"create memory/ if needed). "
- f"If nothing to store, reply with NO_REPLY."
+ f"\n\n"
+ f"重要提示:\n"
+ f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
+ f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
+ f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
+ f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
)
def create_flush_system_prompt(self) -> str:
@@ -142,6 +143,20 @@ class MemoryFlushManager:
return (
"Pre-compaction memory flush turn. "
"The session is near auto-compaction; capture durable memories to disk. "
+ "\n\n"
+ "记忆写入原则:\n"
+ "1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
+ " - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
+ " - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
+ "\n"
+ "2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
+ " - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
+ "\n"
+ "3. 判断标准:\n"
+ " - 这个信息未来会经常用到吗?→ MEMORY.md\n"
+ " - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
+ " - 这是临时性的、不重要的内容吗?→ 不记录\n"
+ "\n"
"You may reply, but usually NO_REPLY is correct."
)
@@ -180,6 +195,7 @@ class MemoryFlushManager:
# Track flush
self.last_flush_token_count = current_tokens
self.last_flush_timestamp = datetime.now()
+ self.turn_count = 0 # reset the turn counter
return True
@@ -187,6 +203,10 @@ class MemoryFlushManager:
print(f"Memory flush failed: {e}")
return False
+ def increment_turn(self):
+ """增加对话轮数计数"""
+ self.turn_count += 1
+
def get_status(self) -> dict:
"""Get memory flush status"""
return {
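The new trigger reduces to two independent checks; a standalone sketch of the decision logic, where the 5,000-token re-flush guard mirrors the hard-coded value in this diff:

```python
# Standalone sketch of the dual-threshold flush decision in MemoryFlushManager.should_flush.
def should_flush(current_tokens, turn_count, last_flush_tokens=None,
                 token_threshold=50_000, turn_threshold=20):
    # Token threshold: flush once the session is large enough,
    # but not again until it has grown another ~5K tokens.
    if current_tokens > 0 and current_tokens >= token_threshold:
        if last_flush_tokens is not None and current_tokens <= last_flush_tokens + 5_000:
            return False
        return True
    # Turn threshold: flush after enough user/AI exchanges.
    return turn_count >= turn_threshold

print(should_flush(current_tokens=12_000, turn_count=21))   # True (turn threshold)
print(should_flush(current_tokens=55_000, turn_count=3))    # True (token threshold)
print(should_flush(current_tokens=52_000, turn_count=3,
                   last_flush_tokens=50_000))                # False (just flushed)
```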
diff --git a/agent/prompt/builder.py b/agent/prompt/builder.py
index 6d8fc80..c9ae35b 100644
--- a/agent/prompt/builder.py
+++ b/agent/prompt/builder.py
@@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
tool_map = {}
tool_descriptions = {
"read": "读取文件内容",
- "write": "创建或覆盖文件",
- "edit": "精确编辑文件内容",
+ "write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)。注意:单次 write 内容不要超过 10KB,超大文件请分步创建",
+ "edit": "精确编辑文件(追加、修改、删除部分内容)",
"ls": "列出目录内容",
"grep": "在文件中搜索内容",
"find": "按照模式查找文件",
@@ -237,11 +237,13 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
"叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。",
"",
"完成标准:",
- "- 确保用户的需求得到实际解决,而不仅仅是制定计划",
- "- 当任务需要多次工具调用时,持续推进直到完成",
+ "- 确保用户的需求得到实际解决,而不仅仅是制定计划。",
+ "- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题",
"- 每次工具调用后,评估是否已获得足够信息来推进或完成任务",
"- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求",
"",
+ "**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。",
+ "",
])
return lines
@@ -305,17 +307,21 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
"",
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
"",
- "1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
- "2. 然后使用 `memory_get` 只拉取需要的行",
- "3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
+ "1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
+ "2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
+ "3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
"",
"**记忆文件结构**:",
- "- `MEMORY.md`: 长期记忆,包含重要的背景信息",
- "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
+ "- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
+ "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
"",
- "**使用原则**:",
- "- 自然使用记忆,就像你本来就知道",
- "- 不要主动提起或列举记忆,除非用户明确询问",
+ "**写入记忆**:",
+ "- 追加内容 → `edit` 工具,oldText 留空",
+ "- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
+ "- 新建文件 → `write` 工具",
+ "- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
+ "",
+ "**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
"",
]
@@ -385,8 +391,8 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_convers
"",
"**交流规范**:",
"",
- "- 在所有对话中,无需提及技术细节(如 SOUL.md、USER.md 等文件名,工具名称,配置等),除非用户明确询问",
- "- 用自然表达如「我已记住」而非「已更新 SOUL.md」",
+ "- 在对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问",
+ "- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」",
"",
]
diff --git a/agent/prompt/workspace.py b/agent/prompt/workspace.py
index a096706..37d9a96 100644
--- a/agent/prompt/workspace.py
+++ b/agent/prompt/workspace.py
@@ -64,7 +64,7 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
_create_template_if_missing(agents_path, _get_agents_template())
_create_template_if_missing(memory_path, _get_memory_template())
- logger.info(f"[Workspace] Initialized workspace at: {workspace_dir}")
+ logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
return WorkspaceFiles(
soul_path=soul_path,
@@ -270,14 +270,9 @@ def _get_agents_template() -> str:
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
-**重要**:
-- 爱好(唱歌、篮球等)→ MEMORY.md,不是 USER.md
-- 近期计划(下周要做什么)→ MEMORY.md,不是 USER.md
-- USER.md 只存放不会变的基本信息
-
## 安全
-- 永远不要泄露私人数据
+- 永远不要泄露密钥等私人数据
- 不要在未经询问的情况下运行破坏性命令
- 当有疑问时,先问
diff --git a/agent/protocol/agent.py b/agent/protocol/agent.py
index 5c4f994..1b031e0 100644
--- a/agent/protocol/agent.py
+++ b/agent/protocol/agent.py
@@ -1,5 +1,6 @@
import json
import time
+import threading
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
@@ -43,6 +44,7 @@ class Agent:
self.output_mode = output_mode
self.last_usage = None # Store last API response usage info
self.messages = [] # Unified message history for stream mode
+ self.messages_lock = threading.Lock() # Lock for thread-safe message operations
self.memory_manager = memory_manager # Memory manager for auto memory flush
self.workspace_dir = workspace_dir # Workspace directory
self.enable_skills = enable_skills # Skills enabled flag
@@ -57,7 +59,7 @@ class Agent:
try:
from agent.skills import SkillManager
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
- logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
+ logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
except Exception as e:
logger.warning(f"Failed to initialize SkillManager: {e}")
@@ -335,7 +337,8 @@ class Agent:
"""
# Clear history if requested
if clear_history:
- self.messages = []
+ with self.messages_lock:
+ self.messages = []
# Get model to use
if not self.model:
@@ -344,7 +347,17 @@ class Agent:
# Get full system prompt with skills
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
- # Create stream executor with agent's message history
+ # Create a copy of messages for this execution to avoid concurrent modification
+ # Record the original length to track which messages are new
+ with self.messages_lock:
+ messages_copy = self.messages.copy()
+ original_length = len(self.messages)
+
+ # Get max_context_turns from config
+ from config import conf
+ max_context_turns = conf().get("agent_max_context_turns", 30)
+
+ # Create stream executor with copied message history
executor = AgentStreamExecutor(
agent=self,
model=self.model,
@@ -352,14 +365,21 @@ class Agent:
tools=self.tools,
max_turns=self.max_steps,
on_event=on_event,
- messages=self.messages # Pass agent's message history
+ messages=messages_copy, # Pass copied message history
+ max_context_turns=max_context_turns
)
# Execute
response = executor.run_stream(user_message)
- # Update agent's message history from executor
- self.messages = executor.messages
+ # Append only the NEW messages from this execution (thread-safe)
+ # This allows concurrent requests to both contribute to history
+ with self.messages_lock:
+ new_messages = executor.messages[original_length:]
+ self.messages.extend(new_messages)
+
+ # Store executor reference for agent_bridge to access files_to_send
+ self.stream_executor = executor
# Execute all post-process tools
self._execute_post_process_tools()
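The concurrency fix follows a copy-on-run, append-new-only pattern: snapshot the history under the lock, let the executor work on the copy, then append only the messages produced by this run. An illustrative sketch with a plain list and lock (the class and method names are stand-ins, not the real `Agent` API):

```python
# Illustrative sketch of the copy-then-extend pattern used to keep Agent.messages
# consistent when two requests run concurrently.
import threading

class History:
    def __init__(self):
        self._lock = threading.Lock()
        self.messages = []

    def run(self, produce_messages):
        # Snapshot under the lock so the executor works on a stable copy.
        with self._lock:
            snapshot = self.messages.copy()
            original_length = len(self.messages)

        result_messages = produce_messages(snapshot)  # executor appends to its copy

        # Append only the messages added by this run; concurrent runs both contribute.
        with self._lock:
            self.messages.extend(result_messages[original_length:])

h = History()
h.run(lambda msgs: msgs + [{"role": "user", "content": "hi"},
                           {"role": "assistant", "content": "hello"}])
print(len(h.messages))  # 2
```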
diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py
index 3ad6eb5..2b1c883 100644
--- a/agent/protocol/agent_stream.py
+++ b/agent/protocol/agent_stream.py
@@ -7,9 +7,9 @@ import json
import time
from typing import List, Dict, Any, Optional, Callable
-from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult
+from common.log import logger
class AgentStreamExecutor:
@@ -31,7 +31,8 @@ class AgentStreamExecutor:
tools: List[BaseTool],
max_turns: int = 50,
on_event: Optional[Callable] = None,
- messages: Optional[List[Dict]] = None
+ messages: Optional[List[Dict]] = None,
+ max_context_turns: int = 30
):
"""
Initialize stream executor
@@ -44,6 +45,7 @@ class AgentStreamExecutor:
max_turns: Maximum number of turns
on_event: Event callback function
messages: Optional existing message history (for persistent conversations)
+ max_context_turns: Maximum number of conversation turns to keep in context
"""
self.agent = agent
self.model = model
@@ -52,12 +54,16 @@ class AgentStreamExecutor:
self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
self.max_turns = max_turns
self.on_event = on_event
+ self.max_context_turns = max_context_turns
# Message history - use provided messages or create new list
self.messages = messages if messages is not None else []
# Tool failure tracking for retry protection
self.tool_failure_history = [] # List of (tool_name, args_hash, success) tuples
+
+ # Track files to send (populated by read tool)
+ self.files_to_send = [] # List of file metadata dicts
def _emit_event(self, event_type: str, data: dict = None):
"""Emit event"""
@@ -78,12 +84,15 @@ class AgentStreamExecutor:
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
return hashlib.md5(args_str.encode()).hexdigest()[:8]
- def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str]:
+ def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str, bool]:
"""
Check if tool has failed too many times consecutively
Returns:
- (should_stop, reason)
+ (should_stop, reason, is_critical)
+ - should_stop: Whether to stop tool execution
+ - reason: Reason for stopping
+ - is_critical: Whether to abort entire conversation (True for 8+ failures)
"""
args_hash = self._hash_args(args)
@@ -99,7 +108,7 @@ class AgentStreamExecutor:
break # Different tool or args, stop counting
if same_args_failures >= 3:
- return True, f"Tool '{tool_name}' with same arguments failed {same_args_failures} times consecutively. Stopping to prevent infinite loop."
+ return True, f"工具 '{tool_name}' 使用相同参数连续失败 {same_args_failures} 次,停止执行以防止无限循环", False
# Count consecutive failures for same tool (any args)
same_tool_failures = 0
@@ -112,10 +121,15 @@ class AgentStreamExecutor:
else:
break # Different tool, stop counting
- if same_tool_failures >= 6:
- return True, f"Tool '{tool_name}' failed {same_tool_failures} times consecutively (with any arguments). Stopping to prevent infinite loop."
+ # Hard stop at 8 failures - abort with critical message
+ if same_tool_failures >= 8:
+ return True, f"抱歉,我没能完成这个任务。可能是我理解有误或者当前方法不太合适。\n\n建议你:\n• 换个方式描述需求试试\n• 把任务拆分成更小的步骤\n• 或者换个思路来解决", True
- return False, ""
+ # Warning at 6 failures
+ if same_tool_failures >= 6:
+ return True, f"工具 '{tool_name}' 连续失败 {same_tool_failures} 次(使用不同参数),停止执行以防止无限循环", False
+
+ return False, "", False
def _record_tool_result(self, tool_name: str, args: dict, success: bool):
"""Record tool execution result for failure tracking"""
@@ -136,10 +150,7 @@ class AgentStreamExecutor:
Final response text
"""
# Log user message with model info
- logger.info(f"{'='*50}")
- logger.info(f"🤖 Model: {self.model.model}")
- logger.info(f"👤 用户: {user_message}")
- logger.info(f"{'='*50}")
+ logger.info(f"🤖 {self.model.model} | 👤 {user_message}")
# Add user message (Claude format - use content blocks for consistency)
self.messages.append({
@@ -160,54 +171,74 @@ class AgentStreamExecutor:
try:
while turn < self.max_turns:
turn += 1
- logger.info(f"第 {turn} 轮")
+ logger.debug(f"第 {turn} 轮")
self._emit_event("turn_start", {"turn": turn})
# Check if memory flush is needed (before calling LLM)
+ # Use the standalone flush thresholds (50K tokens or 20 turns)
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
usage = self.agent.last_usage
if usage and 'input_tokens' in usage:
current_tokens = usage.get('input_tokens', 0)
- context_window = self.agent._get_model_context_window()
- # Use configured reserve_tokens or calculate based on context window
- reserve_tokens = self.agent._get_context_reserve_tokens()
- # Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens)
- soft_threshold = 10000 # Trigger 10K tokens before limit
if self.agent.memory_manager.should_flush_memory(
- current_tokens=current_tokens,
- context_window=context_window,
- reserve_tokens=reserve_tokens,
- soft_threshold=soft_threshold
+ current_tokens=current_tokens
):
self._emit_event("memory_flush_start", {
"current_tokens": current_tokens,
- "threshold": context_window - reserve_tokens - soft_threshold
+ "turn_count": self.agent.memory_manager.flush_manager.turn_count
})
# TODO: Execute memory flush in background
# This would require async support
- logger.info(f"Memory flush recommended at {current_tokens} tokens")
+ logger.info(
+ f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
- # Call LLM
- assistant_msg, tool_calls = self._call_llm_stream()
+ # Call LLM (enable retry_on_empty for better reliability)
+ assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True)
final_response = assistant_msg
# No tool calls, end loop
if not tool_calls:
# 检查是否返回了空响应
if not assistant_msg:
- logger.warning(f"[Agent] LLM returned empty response (no content and no tool calls)")
+ logger.warning(f"[Agent] LLM returned empty response after retry (no content and no tool calls)")
+ logger.info(f"[Agent] This usually happens when LLM thinks the task is complete after tool execution")
- # 生成通用的友好提示
- final_response = (
- "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
- )
- logger.info(f"Generated fallback response for empty LLM output")
+ # If tools were already called, explicitly ask the LLM for a text reply
+ if turn > 1:
+ logger.info(f"[Agent] Requesting explicit response from LLM...")
+
+ # Append a message that explicitly asks the model to answer the user
+ self.messages.append({
+ "role": "user",
+ "content": [{
+ "type": "text",
+ "text": "请向用户说明刚才工具执行的结果或回答用户的问题。"
+ }]
+ })
+
+ # Call the LLM one more time
+ assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=False)
+ final_response = assistant_msg
+
+ # Only fall back if the response is still empty
+ if not assistant_msg and not tool_calls:
+ logger.warning(f"[Agent] Still empty after explicit request")
+ final_response = (
+ "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
+ )
+ logger.info(f"Generated fallback response for empty LLM output")
+ else:
+ # Empty reply on the first turn: fall back directly
+ final_response = (
+ "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
+ )
+ logger.info(f"Generated fallback response for empty LLM output")
else:
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
- logger.info(f"✅ 完成 (无工具调用)")
+ logger.debug(f"✅ 完成 (无工具调用)")
self._emit_event("turn_end", {
"turn": turn,
"has_tool_calls": False
@@ -233,6 +264,20 @@ class AgentStreamExecutor:
result = self._execute_tool(tool_call)
tool_results.append(result)
+ # Check if this is a file to send (from read tool)
+ if result.get("status") == "success" and isinstance(result.get("result"), dict):
+ result_data = result.get("result")
+ if result_data.get("type") == "file_to_send":
+ # Store file metadata for later sending
+ self.files_to_send.append(result_data)
+ logger.info(f"📎 检测到待发送文件: {result_data.get('file_name', result_data.get('path'))}")
+
+ # Check for critical error - abort entire conversation
+ if result.get("status") == "critical_error":
+ logger.error(f"💥 检测到严重错误,终止对话")
+ final_response = result.get('result', '任务执行失败')
+ return final_response
+
# Log tool result in compact format
status_emoji = "✅" if result.get("status") == "success" else "❌"
result_data = result.get('result', '')
@@ -305,11 +350,37 @@ class AgentStreamExecutor:
})
if turn >= self.max_turns:
- logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
- if not final_response:
+ logger.warning(f"⚠️ 已达到最大决策步数限制: {self.max_turns}")
+
+ # Force model to summarize without tool calls
+ logger.info(f"[Agent] Requesting summary from LLM after reaching max steps...")
+
+ # Add a system message to force summary
+ self.messages.append({
+ "role": "user",
+ "content": [{
+ "type": "text",
+ "text": f"你已经执行了{turn}个决策步骤,达到了单次运行的最大步数限制。请总结一下你目前的执行过程和结果,告诉用户当前的进展情况。不要再调用工具,直接用文字回复。"
+ }]
+ })
+
+ # Call LLM one more time to get summary (without retry to avoid loops)
+ try:
+ summary_response, summary_tools = self._call_llm_stream(retry_on_empty=False)
+ if summary_response:
+ final_response = summary_response
+ logger.info(f"💭 Summary: {summary_response[:150]}{'...' if len(summary_response) > 150 else ''}")
+ else:
+ # Fallback if model still doesn't respond
+ final_response = (
+ f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
+ "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get summary from LLM: {e}")
final_response = (
- "抱歉,我在处理你的请求时遇到了一些困难,尝试了多次仍未能完成。"
- "请尝试简化你的问题,或换一种方式描述。"
+ f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
+ "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
)
except Exception as e:
@@ -318,9 +389,13 @@ class AgentStreamExecutor:
raise
finally:
- logger.info(f"🏁 完成({turn}轮)")
+ logger.debug(f"🏁 完成({turn}轮)")
self._emit_event("agent_end", {"final_response": final_response})
+ # Increment the turn counter after each exchange (user message + AI reply = 1 turn)
+ if self.agent.memory_manager:
+ self.agent.memory_manager.increment_turn()
+
return final_response
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
@@ -380,6 +455,7 @@ class AgentStreamExecutor:
# Streaming response
full_content = ""
tool_calls_buffer = {} # {index: {id, name, arguments}}
+ stop_reason = None # Track why the stream stopped
try:
stream = self.model.call_stream(request)
@@ -392,21 +468,47 @@ class AgentStreamExecutor:
if isinstance(error_data, dict):
error_msg = error_data.get("message", chunk.get("message", "Unknown error"))
error_code = error_data.get("code", "")
+ error_type = error_data.get("type", "")
else:
error_msg = chunk.get("message", str(error_data))
error_code = ""
+ error_type = ""
status_code = chunk.get("status_code", "N/A")
- logger.error(f"API Error: {error_msg} (Status: {status_code}, Code: {error_code})")
- logger.error(f"Full error chunk: {chunk}")
- # Raise exception with full error message for retry logic
- raise Exception(f"{error_msg} (Status: {status_code})")
+ # Log error with all available information
+ logger.error(f"🔴 Stream API Error:")
+ logger.error(f" Message: {error_msg}")
+ logger.error(f" Status Code: {status_code}")
+ logger.error(f" Error Code: {error_code}")
+ logger.error(f" Error Type: {error_type}")
+ logger.error(f" Full chunk: {chunk}")
+
+ # Check if this is a context overflow error (keyword-based, works for all models)
+ # Don't rely on specific status codes as different providers use different codes
+ error_msg_lower = error_msg.lower()
+ is_overflow = any(keyword in error_msg_lower for keyword in [
+ 'context length exceeded', 'maximum context length', 'prompt is too long',
+ 'context overflow', 'context window', 'too large', 'exceeds model context',
+ 'request_too_large', 'request exceeds the maximum size', 'tokens exceed'
+ ])
+
+ if is_overflow:
+ # Mark as context overflow for special handling
+ raise Exception(f"[CONTEXT_OVERFLOW] {error_msg} (Status: {status_code})")
+ else:
+ # Raise exception with full error message for retry logic
+ raise Exception(f"{error_msg} (Status: {status_code}, Code: {error_code}, Type: {error_type})")
# Parse chunk
if isinstance(chunk, dict) and "choices" in chunk:
choice = chunk["choices"][0]
delta = choice.get("delta", {})
+
+ # Capture finish_reason if present
+ finish_reason = choice.get("finish_reason")
+ if finish_reason:
+ stop_reason = finish_reason
# Handle text content
if "content" in delta and delta["content"]:
@@ -437,9 +539,46 @@ class AgentStreamExecutor:
tool_calls_buffer[index]["arguments"] += func["arguments"]
except Exception as e:
- error_str = str(e).lower()
+ error_str = str(e)
+ error_str_lower = error_str.lower()
+
+ # Check if error is context overflow (non-retryable, needs session reset)
+ # Method 1: Check for special marker (set in stream error handling above)
+ is_context_overflow = '[context_overflow]' in error_str_lower
+
+ # Method 2: Fallback to keyword matching for non-stream errors
+ if not is_context_overflow:
+ is_context_overflow = any(keyword in error_str_lower for keyword in [
+ 'context length exceeded', 'maximum context length', 'prompt is too long',
+ 'context overflow', 'context window', 'too large', 'exceeds model context',
+ 'request_too_large', 'request exceeds the maximum size'
+ ])
+
+ # Check if error is message format error (incomplete tool_use/tool_result pairs)
+ # This happens when previous conversation had tool failures
+ is_message_format_error = any(keyword in error_str_lower for keyword in [
+ 'tool_use', 'tool_result', 'without', 'immediately after',
+ 'corresponding', 'must have', 'each'
+ ]) and 'status: 400' in error_str_lower
+
+ if is_context_overflow or is_message_format_error:
+ error_type = "context overflow" if is_context_overflow else "message format error"
+ logger.error(f"💥 {error_type} detected: {e}")
+ # Clear message history to recover
+ logger.warning("🔄 Clearing conversation history to recover")
+ self.messages.clear()
+ # Raise special exception with user-friendly message
+ if is_context_overflow:
+ raise Exception(
+ "抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。"
+ )
+ else:
+ raise Exception(
+ "抱歉,之前的对话出现了问题。我已清空历史记录,请重新发送你的消息。"
+ )
+
# Check if error is retryable (timeout, connection, rate limit, server busy, etc.)
- is_retryable = any(keyword in error_str for keyword in [
+ is_retryable = any(keyword in error_str_lower for keyword in [
'timeout', 'timed out', 'connection', 'network',
'rate limit', 'overloaded', 'unavailable', 'busy', 'retry',
'429', '500', '502', '503', '504', '512'
@@ -469,15 +608,19 @@ class AgentStreamExecutor:
try:
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
+ args_preview = tc['arguments'][:200] if len(tc['arguments']) > 200 else tc['arguments']
+ logger.error(f"Failed to parse tool arguments for {tc['name']}")
+ logger.error(f"Arguments length: {len(tc['arguments'])} chars")
+ logger.error(f"Arguments preview: {args_preview}...")
logger.error(f"JSON decode error: {e}")
+
# Return a clear error message to the LLM instead of empty dict
# This helps the LLM understand what went wrong
tool_calls.append({
"id": tc["id"],
"name": tc["name"],
"arguments": {},
- "_parse_error": f"Invalid JSON in tool arguments: {tc['arguments'][:100]}... Error: {str(e)}"
+ "_parse_error": f"Invalid JSON in tool arguments: {args_preview}... Error: {str(e)}. Tip: For large content, consider splitting into smaller chunks or using a different approach."
})
continue
@@ -489,11 +632,12 @@ class AgentStreamExecutor:
# Check for empty response and retry once if enabled
if retry_on_empty and not full_content and not tool_calls:
- logger.warning(f"⚠️ LLM returned empty response, retrying once...")
+ logger.warning(f"⚠️ LLM returned empty response (stop_reason: {stop_reason}), retrying once...")
self._emit_event("message_end", {
"content": "",
"tool_calls": [],
- "empty_retry": True
+ "empty_retry": True,
+ "stop_reason": stop_reason
})
# Retry without retry flag to avoid infinite loop
return self._call_llm_stream(
@@ -560,16 +704,25 @@ class AgentStreamExecutor:
return result
# Check for consecutive failures (retry protection)
- should_stop, stop_reason = self._check_consecutive_failures(tool_name, arguments)
+ should_stop, stop_reason, is_critical = self._check_consecutive_failures(tool_name, arguments)
if should_stop:
logger.error(f"🛑 {stop_reason}")
self._record_tool_result(tool_name, arguments, False)
- # 返回错误给 LLM,让它尝试其他方法
- result = {
- "status": "error",
- "result": f"{stop_reason}\n\nThis approach is not working. Please try a completely different method or ask the user for more information/clarification.",
- "execution_time": 0
- }
+
+ if is_critical:
+ # Critical failure - abort entire conversation
+ result = {
+ "status": "critical_error",
+ "result": stop_reason,
+ "execution_time": 0
+ }
+ else:
+ # Normal failure - let LLM try different approach
+ result = {
+ "status": "error",
+ "result": f"{stop_reason}\n\n当前方法行不通,请尝试完全不同的方法或向用户询问更多信息。",
+ "execution_time": 0
+ }
return result
self._emit_event("tool_execution_start", {
@@ -656,52 +809,174 @@ class AgentStreamExecutor:
logger.warning(f"⚠️ Removing incomplete tool_use message from history")
self.messages.pop()
+ def _identify_complete_turns(self) -> List[Dict]:
+ """
+ Identify complete conversation turns
+
+ A complete turn consists of:
+ 1. A user message (text)
+ 2. The AI reply (which may contain tool_use)
+ 3. Tool results (tool_result, if any)
+ 4. Follow-up AI replies (if any)
+
+ Returns:
+ List of turns, each turn is a dict with 'messages' list
+ """
+ turns = []
+ current_turn = {'messages': []}
+
+ for msg in self.messages:
+ role = msg.get('role')
+ content = msg.get('content', [])
+
+ if role == 'user':
+ # Check whether this is a user query (not a tool result)
+ is_user_query = False
+ if isinstance(content, list):
+ is_user_query = any(
+ block.get('type') == 'text'
+ for block in content
+ if isinstance(block, dict)
+ )
+ elif isinstance(content, str):
+ is_user_query = True
+
+ if is_user_query:
+ # Start a new turn
+ if current_turn['messages']:
+ turns.append(current_turn)
+ current_turn = {'messages': [msg]}
+ else:
+ # Tool result: belongs to the current turn
+ current_turn['messages'].append(msg)
+ else:
+ # AI reply: belongs to the current turn
+ current_turn['messages'].append(msg)
+
+ # Append the last turn
+ if current_turn['messages']:
+ turns.append(current_turn)
+
+ return turns
+
+ def _estimate_turn_tokens(self, turn: Dict) -> int:
+ """估算一个轮次的 tokens"""
+ return sum(
+ self.agent._estimate_message_tokens(msg)
+ for msg in turn['messages']
+ )
+
def _trim_messages(self):
"""
- Trim message history to stay within context limits.
- Uses agent's context management configuration.
+ Trim the message history intelligently while keeping conversations intact
+
+ Complete turns are the unit of trimming, which guarantees:
+ 1. The history is never cut in the middle of an exchange
+ 2. Tool call chains (tool_use + tool_result) stay complete
+ 3. Every kept turn is whole (user message + AI reply + tool calls)
"""
if not self.messages or not self.agent:
return
- # Get context window and reserve tokens from agent
+ # Step 1: identify complete turns
+ turns = self._identify_complete_turns()
+
+ if not turns:
+ return
+
+ # Step 2: turn limit - keep only the most recent N turns
+ if len(turns) > self.max_context_turns:
+ removed_turns = len(turns) - self.max_context_turns
+ turns = turns[-self.max_context_turns:] # keep the most recent turns
+
+ logger.info(
+ f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns},"
+ f"移除最早的 {removed_turns} 轮完整对话"
+ )
+
+ # Step 3: token limit - keep complete turns only
+ # Get context window from agent (based on model)
context_window = self.agent._get_model_context_window()
- reserve_tokens = self.agent._get_context_reserve_tokens()
- max_tokens = context_window - reserve_tokens
- # Estimate current tokens
- current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
+ # Use configured max_context_tokens if available
+ if hasattr(self.agent, 'max_context_tokens') and self.agent.max_context_tokens:
+ max_tokens = self.agent.max_context_tokens
+ else:
+ # Reserve 10% for response generation
+ reserve_tokens = int(context_window * 0.1)
+ max_tokens = context_window - reserve_tokens
- # Add system prompt tokens
+ # Estimate system prompt tokens
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
- current_tokens += system_tokens
+ available_tokens = max_tokens - system_tokens
- # If under limit, no need to trim
- if current_tokens <= max_tokens:
+ # Calculate current tokens
+ current_tokens = sum(self._estimate_turn_tokens(turn) for turn in turns)
+
+ # If under limit, reconstruct messages and return
+ if current_tokens + system_tokens <= max_tokens:
+ # Reconstruct message list from turns
+ new_messages = []
+ for turn in turns:
+ new_messages.extend(turn['messages'])
+
+ old_count = len(self.messages)
+ self.messages = new_messages
+
+ # Log if we removed messages due to turn limit
+ if old_count > len(self.messages):
+ logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
return
- # Keep messages from newest, accumulating tokens
- available_tokens = max_tokens - system_tokens
- kept_messages = []
+ # Token limit exceeded - keep complete turns from newest
+ logger.info(
+ f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens},"
+ f"将按完整轮次移除最早的对话"
+ )
+
+ # Accumulate from the newest turn backwards (keeping turns whole)
+ kept_turns = []
accumulated_tokens = 0
-
- for msg in reversed(self.messages):
- msg_tokens = self.agent._estimate_message_tokens(msg)
- if accumulated_tokens + msg_tokens <= available_tokens:
- kept_messages.insert(0, msg)
- accumulated_tokens += msg_tokens
+ min_turns = 3 # try to keep at least 3 turns, but not at the cost of the token limit
+
+ for i, turn in enumerate(reversed(turns)):
+ turn_tokens = self._estimate_turn_tokens(turn)
+ turns_from_end = i + 1
+
+ # Check whether the budget is exceeded
+ if accumulated_tokens + turn_tokens <= available_tokens:
+ kept_turns.insert(0, turn)
+ accumulated_tokens += turn_tokens
else:
+ # Over budget
+ # If we have not kept enough turns yet and this is the last chance, try to keep this one
+ if len(kept_turns) < min_turns and turns_from_end <= min_turns:
+ # Check how far over budget we are (give up if more than 20% over)
+ overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens
+ if overflow_ratio < 0.2: # 允许最多超出 20%
+ kept_turns.insert(0, turn)
+ accumulated_tokens += turn_tokens
+ logger.debug(f" 为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%")
+ continue
+ # Stop keeping earlier turns
break
-
+
+ # Rebuild the message list
+ new_messages = []
+ for turn in kept_turns:
+ new_messages.extend(turn['messages'])
+
old_count = len(self.messages)
- self.messages = kept_messages
+ old_turn_count = len(turns)
+ self.messages = new_messages
new_count = len(self.messages)
-
+ new_turn_count = len(kept_turns)
+
if old_count > new_count:
logger.info(
- f"Context trimmed: {old_count} -> {new_count} messages "
- f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
- f"limit: {max_tokens})"
+ f" 移除了 {old_turn_count - new_turn_count} 轮对话 "
+ f"({old_count} -> {new_count} 条消息,"
+ f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)"
)
def _prepare_messages(self) -> List[Dict[str, Any]]:
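The turn-based trimming above boils down to: group messages into turns starting at each real user query, then keep whole turns from the newest side until the token budget is spent. A condensed sketch with illustrative token counts (the real code uses the agent's token estimator, a configurable turn cap, and a 20% overflow allowance when keeping a minimum number of turns):

```python
# Condensed sketch of turn-based context trimming: group messages into complete
# turns, then keep whole turns from the newest side within a token budget.
def group_turns(messages):
    turns, current = [], []
    for msg in messages:
        is_user_query = msg["role"] == "user" and msg.get("kind") != "tool_result"
        if is_user_query and current:
            turns.append(current)
            current = []
        current.append(msg)
    if current:
        turns.append(current)
    return turns

def trim(messages, max_tokens, estimate=lambda m: m["tokens"]):
    turns = group_turns(messages)
    kept, used = [], 0
    for turn in reversed(turns):                      # newest turn first
        cost = sum(estimate(m) for m in turn)
        if used + cost > max_tokens:
            break                                     # drop older turns, keep tool pairs intact
        kept.insert(0, turn)
        used += cost
    return [m for turn in kept for m in turn]

msgs = [
    {"role": "user", "tokens": 10}, {"role": "assistant", "tokens": 40},
    {"role": "user", "kind": "tool_result", "tokens": 30}, {"role": "assistant", "tokens": 20},
    {"role": "user", "tokens": 10}, {"role": "assistant", "tokens": 15},
]
print(len(trim(msgs, max_tokens=30)))  # 2: only the newest complete turn fits
```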
diff --git a/agent/skills/config.py b/agent/skills/config.py
index bbcd9f6..21e1737 100644
--- a/agent/skills/config.py
+++ b/agent/skills/config.py
@@ -70,33 +70,27 @@ def should_include_skill(
entry: SkillEntry,
config: Optional[Dict] = None,
current_platform: Optional[str] = None,
- lenient: bool = True,
) -> bool:
"""
Determine if a skill should be included based on requirements.
- Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode:
- - In lenient mode (default): Only check explicit disable and platform, ignore missing requirements
- - In strict mode: Check all requirements (binary, env vars, config)
+ Simple rule: Skills are auto-enabled if their requirements are met.
+ - Has required API keys → enabled
+ - Missing API keys → disabled
+ - Wrong keys → enabled but will fail at runtime (LLM will handle error)
:param entry: SkillEntry to check
- :param config: Configuration dictionary
+ :param config: Configuration dictionary (currently unused, reserved for future)
:param current_platform: Current platform (default: auto-detect)
- :param lenient: If True, ignore missing requirements and load all skills (default: True)
:return: True if skill should be included
"""
metadata = entry.metadata
- skill_name = entry.skill.name
- skill_config = get_skill_config(config, skill_name)
-
- # Always check if skill is explicitly disabled in config
- if skill_config and skill_config.get('enabled') is False:
- return False
+ # No metadata = always include (no requirements)
if not metadata:
return True
- # Always check platform requirements (can't work on wrong platform)
+ # Check platform requirements (can't work on wrong platform)
if metadata.os:
platform_name = current_platform or resolve_runtime_platform()
# Map common platform names
@@ -114,12 +108,7 @@ def should_include_skill(
if metadata.always:
return True
- # In lenient mode, skip requirement checks and load all skills
- # Skills will fail gracefully at runtime if requirements are missing
- if lenient:
- return True
-
- # Strict mode: Check all requirements
+ # Check requirements
if metadata.requires:
# Check required binaries (all must be present)
required_bins = metadata.requires.get('bins', [])
@@ -133,29 +122,13 @@ def should_include_skill(
if not has_any_binary(any_bins):
return False
- # Check environment variables (with config fallback)
+ # Check environment variables (API keys)
+ # Simple rule: All required env vars must be set
required_env = metadata.requires.get('env', [])
if required_env:
for env_name in required_env:
- # Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv)
- if has_env_var(env_name):
- continue
- if skill_config:
- # Check skill config env dict
- skill_env = skill_config.get('env', {})
- if isinstance(skill_env, dict) and env_name in skill_env:
- continue
- # Check skill config apiKey (if this is the primaryEnv)
- if metadata.primary_env == env_name and skill_config.get('apiKey'):
- continue
- # Requirement not satisfied
- return False
-
- # Check config paths
- required_config = metadata.requires.get('config', [])
- if required_config and config:
- for config_path in required_config:
- if not is_config_path_truthy(config, config_path):
+ if not has_env_var(env_name):
+ # Missing required API key → disable skill
return False
return True
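Under the new rule, a skill's availability is purely a function of the runtime environment. A minimal sketch of the env-var gate (the metadata shape is simplified, and `TAVILY_API_KEY` is just an example variable name):

```python
# Minimal sketch of the simplified skill gating: every required env var must be set.
import os

def should_include(requires: dict) -> bool:
    for env_name in requires.get("env", []):
        if not os.environ.get(env_name):
            return False          # missing API key -> skill is excluded
    return True

print(should_include({"env": ["TAVILY_API_KEY"]}))  # False unless the key is exported
print(should_include({"env": []}))                  # True: no requirements
```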
diff --git a/agent/skills/formatter.py b/agent/skills/formatter.py
index e77d7d6..7868e09 100644
--- a/agent/skills/formatter.py
+++ b/agent/skills/formatter.py
@@ -34,6 +34,7 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
lines.append(f" {_escape_xml(skill.name)}")
lines.append(f" {_escape_xml(skill.description)}")
lines.append(f" {_escape_xml(skill.file_path)}")
+ lines.append(f" {_escape_xml(skill.base_dir)}")
lines.append(" ")
lines.append("")
diff --git a/agent/skills/frontmatter.py b/agent/skills/frontmatter.py
index 565c1f7..9905e29 100644
--- a/agent/skills/frontmatter.py
+++ b/agent/skills/frontmatter.py
@@ -23,7 +23,22 @@ def parse_frontmatter(content: str) -> Dict[str, Any]:
frontmatter_text = match.group(1)
- # Simple YAML-like parsing (supports key: value format)
+ # Try to use PyYAML for proper YAML parsing
+ try:
+ import yaml
+ frontmatter = yaml.safe_load(frontmatter_text)
+ if not isinstance(frontmatter, dict):
+ frontmatter = {}
+ return frontmatter
+ except ImportError:
+ # Fallback to simple parsing if PyYAML not available
+ pass
+ except Exception:
+ # If YAML parsing fails, fall back to simple parsing
+ pass
+
+ # Simple YAML-like parsing (supports key: value format only)
+ # This is a fallback for when PyYAML is not available
for line in frontmatter_text.split('\n'):
line = line.strip()
if not line or line.startswith('#'):
@@ -72,10 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
if not isinstance(metadata_raw, dict):
return None
- # Support both 'moltbot' and 'cow' keys for compatibility
- meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow')
- if not meta_obj or not isinstance(meta_obj, dict):
- return None
+ # Use metadata_raw directly (COW format)
+ meta_obj = metadata_raw
# Parse install specs
install_specs = []
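With PyYAML available, the frontmatter block is parsed as real YAML (nested keys, lists); without it, only flat `key: value` lines are understood. A small sketch of the parse-with-fallback behaviour on a typical SKILL.md header (the field names and regex here are illustrative, not the project's exact schema):

```python
# Small sketch of frontmatter parsing with PyYAML and a plain key: value fallback.
import re

SKILL_MD = """---
name: web-clipper
description: Save pages to the workspace
metadata:
  requires:
    env: [CLIPPER_API_KEY]
---
# instructions follow here
"""

match = re.match(r"^---\n(.*?)\n---", SKILL_MD, re.DOTALL)
frontmatter_text = match.group(1)

try:
    import yaml
    frontmatter = yaml.safe_load(frontmatter_text) or {}
except ImportError:
    # Fallback: only flat "key: value" lines are understood, nesting is lost.
    frontmatter = {}
    for line in frontmatter_text.splitlines():
        if ":" in line and not line.lstrip().startswith("#"):
            key, _, value = line.partition(":")
            frontmatter[key.strip()] = value.strip()

print(frontmatter["name"], frontmatter.get("metadata"))
```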
diff --git a/agent/skills/loader.py b/agent/skills/loader.py
index 0bc8f4a..cc77d32 100644
--- a/agent/skills/loader.py
+++ b/agent/skills/loader.py
@@ -137,6 +137,10 @@ class SkillLoader:
name = frontmatter.get('name', parent_dir_name)
description = frontmatter.get('description', '')
+ # Special handling for linkai-agent: dynamically load apps from config.json
+ if name == 'linkai-agent':
+ description = self._load_linkai_agent_description(skill_dir, description)
+
if not description or not description.strip():
diagnostics.append(f"Skill {name} has no description: {file_path}")
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
@@ -161,6 +165,45 @@ class SkillLoader:
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
+ def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
+ """
+ Dynamically load LinkAI agent description from config.json
+
+ :param skill_dir: Skill directory
+ :param default_description: Default description from SKILL.md
+ :return: Dynamic description with app list
+ """
+ import json
+
+ config_path = os.path.join(skill_dir, "config.json")
+ template_path = os.path.join(skill_dir, "config.json.template")
+
+ # Try to load config.json or fallback to template
+ config_file = config_path if os.path.exists(config_path) else template_path
+
+ if not os.path.exists(config_file):
+ return default_description
+
+ try:
+ with open(config_file, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+
+ apps = config.get("apps", [])
+ if not apps:
+ return default_description
+
+ # Build dynamic description with app details
+ app_descriptions = "; ".join([
+ f"{app['app_name']}({app['app_code']}: {app['app_description']})"
+ for app in apps
+ ])
+
+ return f"Call LinkAI apps/workflows. {app_descriptions}"
+
+ except Exception as e:
+ logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
+ return default_description
+
def load_all_skills(
self,
managed_dir: Optional[str] = None,
@@ -216,7 +259,7 @@ class SkillLoader:
for diag in all_diagnostics[:5]: # Log first 5
logger.debug(f" - {diag}")
- logger.info(f"Loaded {len(skill_map)} skills from all sources")
+ logger.debug(f"Loaded {len(skill_map)} skills from all sources")
return skill_map
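The dynamic description is just the configured app list flattened into one sentence. A sketch with a hypothetical config.json; only the `app_name`/`app_code`/`app_description` keys follow the diff, the values are made up:

```python
# Sketch of how the linkai-agent skill description is built from config.json.
import json

config = json.loads("""
{
  "apps": [
    {"app_name": "Daily Report", "app_code": "daily_report",
     "app_description": "summarize yesterday's messages"},
    {"app_name": "Translator", "app_code": "translate",
     "app_description": "translate text between languages"}
  ]
}
""")

apps = config.get("apps", [])
app_descriptions = "; ".join(
    f"{app['app_name']}({app['app_code']}: {app['app_description']})" for app in apps
)
print(f"Call LinkAI apps/workflows. {app_descriptions}")
```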
diff --git a/agent/skills/manager.py b/agent/skills/manager.py
index 580dc0f..bf9593f 100644
--- a/agent/skills/manager.py
+++ b/agent/skills/manager.py
@@ -59,7 +59,7 @@ class SkillManager:
extra_dirs=self.extra_dirs,
)
- logger.info(f"SkillManager: Loaded {len(self.skills)} skills")
+ logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
def get_skill(self, name: str) -> Optional[SkillEntry]:
"""
@@ -82,32 +82,24 @@ class SkillManager:
self,
skill_filter: Optional[List[str]] = None,
include_disabled: bool = False,
- check_requirements: bool = False, # Changed default to False for lenient loading
- lenient: bool = True, # New parameter for lenient mode
) -> List[SkillEntry]:
"""
Filter skills based on criteria.
- By default (lenient=True), all skills are loaded regardless of missing requirements.
- Skills will fail gracefully at runtime if requirements are not met.
+ Simple rule: Skills are auto-enabled if requirements are met.
+ - Has required API keys → included
+ - Missing API keys → excluded
:param skill_filter: List of skill names to include (None = all)
:param include_disabled: Whether to include skills with disable_model_invocation=True
- :param check_requirements: Whether to check skill requirements (default: False)
- :param lenient: If True, ignore missing requirements (default: True)
:return: Filtered list of skill entries
"""
from agent.skills.config import should_include_skill
entries = list(self.skills.values())
- # Check requirements (platform, explicit disable, etc.)
- # In lenient mode, only checks platform and explicit disable
- if check_requirements or not lenient:
- entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)]
- else:
- # Lenient mode: only check explicit disable and platform
- entries = [e for e in entries if should_include_skill(e, self.config, lenient=True)]
+ # Check requirements (platform, binaries, env vars)
+ entries = [e for e in entries if should_include_skill(e, self.config)]
# Apply skill filter
if skill_filter is not None:
diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py
index 76f7e2e..3cba117 100644
--- a/agent/tools/__init__.py
+++ b/agent/tools/__init__.py
@@ -2,55 +2,57 @@
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager
-# Import basic tools (no external dependencies)
-from agent.tools.calculator.calculator import Calculator
-
# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
-from agent.tools.grep.grep import Grep
-from agent.tools.find.find import Find
from agent.tools.ls.ls import Ls
+from agent.tools.send.send import Send
# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
-# Import web tools
-from agent.tools.web_fetch.web_fetch import WebFetch
-
# Import tools with optional dependencies
def _import_optional_tools():
"""Import tools that have optional dependencies"""
+ from common.log import logger
tools = {}
- # Google Search (requires requests)
+ # EnvConfig Tool (requires python-dotenv)
try:
- from agent.tools.google_search.google_search import GoogleSearch
- tools['GoogleSearch'] = GoogleSearch
- except ImportError:
- pass
+ from agent.tools.env_config.env_config import EnvConfig
+ tools['EnvConfig'] = EnvConfig
+ except ImportError as e:
+ logger.error(
+ f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n"
+ f" To enable environment variable management, run:\n"
+ f" pip install python-dotenv>=1.0.0"
+ )
+ except Exception as e:
+ logger.error(f"[Tools] EnvConfig tool failed to load: {e}")
- # File Save (may have dependencies)
+ # Scheduler Tool (requires croniter)
try:
- from agent.tools.file_save.file_save import FileSave
- tools['FileSave'] = FileSave
- except ImportError:
- pass
+ from agent.tools.scheduler.scheduler_tool import SchedulerTool
+ tools['SchedulerTool'] = SchedulerTool
+ except ImportError as e:
+ logger.error(
+ f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n"
+ f" To enable scheduled tasks, run:\n"
+ f" pip install croniter>=2.0.0"
+ )
+ except Exception as e:
+ logger.error(f"[Tools] Scheduler tool failed to load: {e}")
- # Terminal (basic, should work)
- try:
- from agent.tools.terminal.terminal import Terminal
- tools['Terminal'] = Terminal
- except ImportError:
- pass
return tools
# Load optional tools
_optional_tools = _import_optional_tools()
+EnvConfig = _optional_tools.get('EnvConfig')
+SchedulerTool = _optional_tools.get('SchedulerTool')
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
@@ -74,28 +76,24 @@ def _import_browser_tool():
# Dynamically set BrowserTool
-BrowserTool = _import_browser_tool()
+# BrowserTool = _import_browser_tool()
# Export all tools (including optional ones that might be None)
__all__ = [
'BaseTool',
'ToolManager',
- 'Calculator',
'Read',
'Write',
'Edit',
'Bash',
- 'Grep',
- 'Find',
'Ls',
+ 'Send',
'MemorySearchTool',
'MemoryGetTool',
- 'WebFetch',
+ 'EnvConfig',
+ 'SchedulerTool',
# Optional tools (may be None if dependencies not available)
- 'GoogleSearch',
- 'FileSave',
- 'Terminal',
- 'BrowserTool'
+ # 'BrowserTool'
]
"""
diff --git a/agent/tools/bash/bash.py b/agent/tools/bash/bash.py
index e9b6ca0..4d7e564 100644
--- a/agent/tools/bash/bash.py
+++ b/agent/tools/bash/bash.py
@@ -3,12 +3,14 @@ Bash tool - Execute bash commands
"""
import os
+import sys
import subprocess
import tempfile
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
+from common.log import logger
class Bash(BaseTool):
@@ -60,6 +62,12 @@ IMPORTANT SAFETY GUIDELINES:
if not command:
return ToolResult.fail("Error: command parameter is required")
+ # Security check: Prevent accessing sensitive config files
+        if "~/.cow" in command:
+ return ToolResult.fail(
+ "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
+ )
+
# Optional safety check - only warn about extremely dangerous commands
if self.safety_mode:
warning = self._get_safety_warning(command)
@@ -68,7 +76,31 @@ IMPORTANT SAFETY GUIDELINES:
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
try:
- # Execute command
+ # Prepare environment with .env file variables
+ env = os.environ.copy()
+
+ # Load environment variables from ~/.cow/.env if it exists
+ env_file = os.path.expanduser("~/.cow/.env")
+ if os.path.exists(env_file):
+ try:
+ from dotenv import dotenv_values
+ env_vars = dotenv_values(env_file)
+ env.update(env_vars)
+ logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
+ except ImportError:
+ logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
+ except Exception as e:
+ logger.debug(f"[Bash] Failed to load .env: {e}")
+
+ # Debug logging
+ logger.debug(f"[Bash] CWD: {self.cwd}")
+ logger.debug(f"[Bash] Command: {command[:500]}")
+ logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}")
+ logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}")
+ logger.debug(f"[Bash] Python executable: {sys.executable}")
+            logger.debug(f"[Bash] Process UID: {os.getuid() if hasattr(os, 'getuid') else 'n/a'}")
+
+ # Execute command with inherited environment variables
result = subprocess.run(
command,
shell=True,
@@ -76,8 +108,50 @@ IMPORTANT SAFETY GUIDELINES:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- timeout=timeout
+ timeout=timeout,
+ env=env
)
+
+ logger.debug(f"[Bash] Exit code: {result.returncode}")
+ logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
+ logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")
+
+ # Workaround for exit code 126 with no output
+ if result.returncode == 126 and not result.stdout and not result.stderr:
+ logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
+ # Try using argument list instead of shell=True
+ import shlex
+ try:
+ parts = shlex.split(command)
+ if len(parts) > 0:
+ logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
+ retry_result = subprocess.run(
+ parts,
+ cwd=self.cwd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ timeout=timeout,
+ env=env
+ )
+ logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")
+
+ # If retry succeeded, use retry result
+ if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
+ result = retry_result
+ else:
+ # Both attempts failed - check if this is openai-image-vision skill
+ if 'openai-image-vision' in command or 'vision.sh' in command:
+ # Create a mock result with helpful error message
+ from types import SimpleNamespace
+ result = SimpleNamespace(
+ returncode=1,
+ stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
+ stderr=''
+ )
+ logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
+ except Exception as retry_err:
+ logger.warning(f"[Bash] Retry failed: {retry_err}")
# Combine stdout and stderr
output = result.stdout
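Taken together, the hunk above makes the bash tool merge `~/.cow/.env` into the subprocess environment before running the command. A minimal usage sketch, assuming `Bash` takes a `config` dict with a `cwd` like the other file tools (the command and path are illustrative):

```python
# Usage sketch for the updated bash tool (command and cwd are illustrative).
from agent.tools.bash.bash import Bash

bash = Bash(config={"cwd": "/tmp"})
result = bash.execute({"command": "echo \"key present: ${OPENAI_API_KEY:+yes}\""})

# If ~/.cow/.env exists and python-dotenv is installed, its variables are merged
# into the subprocess environment, so the command above can see OPENAI_API_KEY
# even when it is not exported in the parent shell.
print(result)
```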
diff --git a/agent/tools/calculator/calculator.py b/agent/tools/calculator/calculator.py
deleted file mode 100644
index 092343d..0000000
--- a/agent/tools/calculator/calculator.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import math
-
-from agent.tools.base_tool import BaseTool, ToolResult
-
-
-class Calculator(BaseTool):
- name: str = "calculator"
- description: str = "A tool to perform basic mathematical calculations."
- params: dict = {
- "type": "object",
- "properties": {
- "expression": {
- "type": "string",
- "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
- "Ensure your input is a valid Python expression, it will be evaluated directly."
- }
- },
- "required": ["expression"]
- }
- config: dict = {}
-
- def execute(self, args: dict) -> ToolResult:
- try:
- # Get the expression
- expression = args["expression"]
-
- # Create a safe local environment containing only basic math functions
- safe_locals = {
- "abs": abs,
- "round": round,
- "max": max,
- "min": min,
- "pow": pow,
- "sqrt": math.sqrt,
- "sin": math.sin,
- "cos": math.cos,
- "tan": math.tan,
- "pi": math.pi,
- "e": math.e,
- "log": math.log,
- "log10": math.log10,
- "exp": math.exp,
- "floor": math.floor,
- "ceil": math.ceil
- }
-
- # Safely evaluate the expression
- result = eval(expression, {"__builtins__": {}}, safe_locals)
-
- return ToolResult.success({
- "result": result,
- "expression": expression
- })
- except Exception as e:
- return ToolResult.success({
- "error": str(e),
- "expression": args.get("expression", "")
- })
diff --git a/agent/tools/edit/edit.py b/agent/tools/edit/edit.py
index 566309b..e17e624 100644
--- a/agent/tools/edit/edit.py
+++ b/agent/tools/edit/edit.py
@@ -22,7 +22,7 @@ class Edit(BaseTool):
"""Tool for precise file editing"""
name: str = "edit"
- description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."
+ description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."
params: dict = {
"type": "object",
@@ -33,7 +33,7 @@ class Edit(BaseTool):
},
"oldText": {
"type": "string",
- "description": "Exact text to find and replace (must match exactly)"
+ "description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
},
"newText": {
"type": "string",
@@ -89,34 +89,45 @@ class Edit(BaseTool):
normalized_old_text = normalize_to_lf(old_text)
normalized_new_text = normalize_to_lf(new_text)
- # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
- match_result = fuzzy_find_text(normalized_content, normalized_old_text)
-
- if not match_result.found:
- return ToolResult.fail(
- f"Error: Could not find the exact text in {path}. "
- "The old text must match exactly including all whitespace and newlines."
+ # Special case: empty oldText means append to end of file
+ if not old_text or not old_text.strip():
+ # Append mode: add newText to the end
+ # Add newline before newText if file doesn't end with one
+ if normalized_content and not normalized_content.endswith('\n'):
+ new_content = normalized_content + '\n' + normalized_new_text
+ else:
+ new_content = normalized_content + normalized_new_text
+ base_content = normalized_content # For verification
+ else:
+ # Normal edit mode: find and replace
+ # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
+ match_result = fuzzy_find_text(normalized_content, normalized_old_text)
+
+ if not match_result.found:
+ return ToolResult.fail(
+ f"Error: Could not find the exact text in {path}. "
+ "The old text must match exactly including all whitespace and newlines."
+ )
+
+ # Calculate occurrence count (use fuzzy normalized content for consistency)
+ fuzzy_content = normalize_for_fuzzy_match(normalized_content)
+ fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
+ occurrences = fuzzy_content.count(fuzzy_old_text)
+
+ if occurrences > 1:
+ return ToolResult.fail(
+ f"Error: Found {occurrences} occurrences of the text in {path}. "
+ "The text must be unique. Please provide more context to make it unique."
+ )
+
+ # Execute replacement (use matched text position)
+ base_content = match_result.content_for_replacement
+ new_content = (
+ base_content[:match_result.index] +
+ normalized_new_text +
+ base_content[match_result.index + match_result.match_length:]
)
- # Calculate occurrence count (use fuzzy normalized content for consistency)
- fuzzy_content = normalize_for_fuzzy_match(normalized_content)
- fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
- occurrences = fuzzy_content.count(fuzzy_old_text)
-
- if occurrences > 1:
- return ToolResult.fail(
- f"Error: Found {occurrences} occurrences of the text in {path}. "
- "The text must be unique. Please provide more context to make it unique."
- )
-
- # Execute replacement (use matched text position)
- base_content = match_result.content_for_replacement
- new_content = (
- base_content[:match_result.index] +
- normalized_new_text +
- base_content[match_result.index + match_result.match_length:]
- )
-
# Verify replacement actually changed content
if base_content == new_content:
return ToolResult.fail(
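With the change above, an empty `oldText` switches the tool into append mode. A short sketch of both modes, assuming `Edit` is constructed the same way as the other file tools (paths and text are illustrative):

```python
# Sketch of the two edit modes (paths and content are illustrative).
from agent.tools.edit.edit import Edit

edit = Edit(config={"cwd": "/tmp"})

# Append mode: empty oldText adds newText at the end of the file,
# inserting a newline first if the file does not already end with one.
edit.execute({"path": "notes.md", "oldText": "", "newText": "- follow up tomorrow"})

# Replace mode: oldText must match (exactly, or via fuzzy_find_text) and be
# unique in the file, otherwise the tool returns a failure result.
edit.execute({"path": "notes.md",
              "oldText": "- follow up tomorrow",
              "newText": "- followed up on 2024-01-02"})
```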
diff --git a/agent/tools/env_config/__init__.py b/agent/tools/env_config/__init__.py
new file mode 100644
index 0000000..2e5822f
--- /dev/null
+++ b/agent/tools/env_config/__init__.py
@@ -0,0 +1,3 @@
+from agent.tools.env_config.env_config import EnvConfig
+
+__all__ = ['EnvConfig']
diff --git a/agent/tools/env_config/env_config.py b/agent/tools/env_config/env_config.py
new file mode 100644
index 0000000..f0a10fe
--- /dev/null
+++ b/agent/tools/env_config/env_config.py
@@ -0,0 +1,284 @@
+"""
+Environment Configuration Tool - Manage API keys and environment variables
+"""
+
+import os
+import re
+from typing import Dict, Any
+from pathlib import Path
+
+from agent.tools.base_tool import BaseTool, ToolResult
+from common.log import logger
+
+
+# API key registry: common environment variables and what they are used for
+API_KEY_REGISTRY = {
+    # AI model providers
+ "OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
+ "GEMINI_API_KEY": "Google Gemini API 密钥",
+ "CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
+ "LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
+    # Search services
+ "BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
+}
+
+class EnvConfig(BaseTool):
+ """Tool for managing environment variables (API keys, etc.)"""
+
+ name: str = "env_config"
+ description: str = (
+ "Manage API keys and skill configurations securely. "
+ "Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
+ "view configured keys, or manage skill settings. "
+ "Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
+ "Values are automatically masked for security. Changes take effect immediately via hot reload."
+ )
+
+ params: dict = {
+ "type": "object",
+ "properties": {
+ "action": {
+ "type": "string",
+ "description": "Action to perform: 'set', 'get', 'list', 'delete'",
+ "enum": ["set", "get", "list", "delete"]
+ },
+ "key": {
+ "type": "string",
+ "description": (
+ "Environment variable key name. Common keys:\n"
+ "- OPENAI_API_KEY: OpenAI API (GPT models)\n"
+ "- OPENAI_API_BASE: OpenAI API base URL\n"
+ "- CLAUDE_API_KEY: Anthropic Claude API\n"
+ "- GEMINI_API_KEY: Google Gemini API\n"
+ "- LINKAI_API_KEY: LinkAI platform\n"
+ "- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
+ "Use exact key names (case-sensitive, all uppercase with underscores)"
+ )
+ },
+ "value": {
+ "type": "string",
+ "description": "Value to set for the environment variable (for 'set' action)"
+ }
+ },
+ "required": ["action"]
+ }
+
+ def __init__(self, config: dict = None):
+ self.config = config or {}
+ # Store env config in ~/.cow directory (outside workspace for security)
+ self.env_dir = os.path.expanduser("~/.cow")
+ self.env_path = os.path.join(self.env_dir, '.env')
+ self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
+ # Don't create .env file in __init__ to avoid issues during tool discovery
+ # It will be created on first use in execute()
+
+ def _ensure_env_file(self):
+ """Ensure the .env file exists"""
+ # Create ~/.cow directory if it doesn't exist
+ os.makedirs(self.env_dir, exist_ok=True)
+
+ if not os.path.exists(self.env_path):
+ Path(self.env_path).touch()
+ logger.info(f"[EnvConfig] Created .env file at {self.env_path}")
+
+ def _mask_value(self, value: str) -> str:
+ """Mask sensitive parts of a value for logging"""
+ if not value or len(value) <= 10:
+ return "***"
+ return f"{value[:6]}***{value[-4:]}"
+
+ def _read_env_file(self) -> Dict[str, str]:
+ """Read all key-value pairs from .env file"""
+ env_vars = {}
+ if os.path.exists(self.env_path):
+ with open(self.env_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ # Skip empty lines and comments
+ if not line or line.startswith('#'):
+ continue
+ # Parse KEY=VALUE
+ match = re.match(r'^([^=]+)=(.*)$', line)
+ if match:
+ key, value = match.groups()
+ env_vars[key.strip()] = value.strip()
+ return env_vars
+
+ def _write_env_file(self, env_vars: Dict[str, str]):
+ """Write all key-value pairs to .env file"""
+ with open(self.env_path, 'w', encoding='utf-8') as f:
+ f.write("# Environment variables for agent skills\n")
+ f.write("# Auto-managed by env_config tool\n\n")
+ for key, value in sorted(env_vars.items()):
+ f.write(f"{key}={value}\n")
+
+ def _reload_env(self):
+ """Reload environment variables from .env file"""
+ env_vars = self._read_env_file()
+ for key, value in env_vars.items():
+ os.environ[key] = value
+ logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")
+
+ def _refresh_skills(self):
+ """Refresh skills after environment variable changes"""
+ if self.agent_bridge:
+ try:
+ # Reload .env file
+ self._reload_env()
+
+ # Refresh skills in all agent instances
+ refreshed = self.agent_bridge.refresh_all_skills()
+ logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
+ return True
+ except Exception as e:
+ logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
+ return False
+ return False
+
+ def execute(self, args: Dict[str, Any]) -> ToolResult:
+ """
+ Execute environment configuration operation
+
+ :param args: Contains action, key, and value parameters
+ :return: Result of the operation
+ """
+ # Ensure .env file exists on first use
+ self._ensure_env_file()
+
+ action = args.get("action")
+ key = args.get("key")
+ value = args.get("value")
+
+ try:
+ if action == "set":
+ if not key or not value:
+ return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")
+
+ # Read current env vars
+ env_vars = self._read_env_file()
+
+ # Update the key
+ env_vars[key] = value
+
+ # Write back to file
+ self._write_env_file(env_vars)
+
+ # Update current process env
+ os.environ[key] = value
+
+ logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")
+
+ # Try to refresh skills immediately
+ refreshed = self._refresh_skills()
+
+ result = {
+ "message": f"Successfully set {key}",
+ "key": key,
+ "value": self._mask_value(value),
+ }
+
+ if refreshed:
+ result["note"] = "✅ Skills refreshed automatically - changes are now active"
+ else:
+ result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"
+
+ return ToolResult.success(result)
+
+ elif action == "get":
+ if not key:
+ return ToolResult.fail("Error: 'key' is required for 'get' action.")
+
+ # Check in file first, then in current env
+ env_vars = self._read_env_file()
+ value = env_vars.get(key) or os.getenv(key)
+
+ # Get description from registry
+ description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")
+
+ if value is not None:
+ logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
+ return ToolResult.success({
+ "key": key,
+ "value": self._mask_value(value),
+ "description": description,
+ "exists": True
+ })
+ else:
+ return ToolResult.success({
+ "key": key,
+ "description": description,
+ "exists": False,
+ "message": f"Environment variable '{key}' is not set"
+ })
+
+ elif action == "list":
+ env_vars = self._read_env_file()
+
+ # Build detailed variable list with descriptions
+ variables_with_info = {}
+ for key, value in env_vars.items():
+ variables_with_info[key] = {
+ "value": self._mask_value(value),
+ "description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
+ }
+
+ logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")
+
+ if not env_vars:
+ return ToolResult.success({
+ "message": "No environment variables configured",
+ "variables": {},
+ "note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
+ })
+
+ return ToolResult.success({
+ "message": f"Found {len(env_vars)} environment variable(s)",
+ "variables": variables_with_info
+ })
+
+ elif action == "delete":
+ if not key:
+ return ToolResult.fail("Error: 'key' is required for 'delete' action.")
+
+ # Read current env vars
+ env_vars = self._read_env_file()
+
+ if key not in env_vars:
+ return ToolResult.success({
+ "message": f"Environment variable '{key}' was not set",
+ "key": key
+ })
+
+ # Remove the key
+ del env_vars[key]
+
+ # Write back to file
+ self._write_env_file(env_vars)
+
+ # Remove from current process env
+ if key in os.environ:
+ del os.environ[key]
+
+ logger.info(f"[EnvConfig] Deleted {key}")
+
+ # Try to refresh skills immediately
+ refreshed = self._refresh_skills()
+
+ result = {
+ "message": f"Successfully deleted {key}",
+ "key": key,
+ }
+
+ if refreshed:
+ result["note"] = "✅ Skills refreshed automatically - changes are now active"
+ else:
+ result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"
+
+ return ToolResult.success(result)
+
+ else:
+ return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")
+
+ except Exception as e:
+ logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
+ return ToolResult.fail(f"EnvConfig tool error: {str(e)}")
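For reference, a minimal sketch of the four actions when the tool is called directly; in normal operation the model invokes it through the tool-calling loop, and the key/value below are placeholders:

```python
# Direct-call sketch of the env_config tool (values are placeholders).
from agent.tools.env_config.env_config import EnvConfig

env_tool = EnvConfig()

# Persist a key to ~/.cow/.env, export it to os.environ, and attempt a hot reload.
env_tool.execute({"action": "set", "key": "BOCHA_API_KEY", "value": "sk-placeholder"})

# Read it back; long values are masked (first 6 and last 4 characters kept).
env_tool.execute({"action": "get", "key": "BOCHA_API_KEY"})

# List everything currently configured, with masked values and descriptions.
env_tool.execute({"action": "list"})

# Remove it from the file and from the current process environment.
env_tool.execute({"action": "delete", "key": "BOCHA_API_KEY"})
```

Without an `agent_bridge` in the config, the hot reload is skipped and the result notes that a restart is needed, which matches the fallback path in `_refresh_skills`.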
diff --git a/agent/tools/find/__init__.py b/agent/tools/find/__init__.py
deleted file mode 100644
index f2af14f..0000000
--- a/agent/tools/find/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .find import Find
-
-__all__ = ['Find']
diff --git a/agent/tools/find/find.py b/agent/tools/find/find.py
deleted file mode 100644
index 7a2c4a1..0000000
--- a/agent/tools/find/find.py
+++ /dev/null
@@ -1,177 +0,0 @@
-"""
-Find tool - Search for files by glob pattern
-"""
-
-import os
-import glob as glob_module
-from typing import Dict, Any, List
-
-from agent.tools.base_tool import BaseTool, ToolResult
-from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
-
-
-DEFAULT_LIMIT = 1000
-
-
-class Find(BaseTool):
- """Tool for finding files by pattern"""
-
- name: str = "find"
- description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
-
- params: dict = {
- "type": "object",
- "properties": {
- "pattern": {
- "type": "string",
- "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
- },
- "path": {
- "type": "string",
- "description": "Directory to search in (default: current directory)"
- },
- "limit": {
- "type": "integer",
- "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
- }
- },
- "required": ["pattern"]
- }
-
- def __init__(self, config: dict = None):
- self.config = config or {}
- self.cwd = self.config.get("cwd", os.getcwd())
-
- def execute(self, args: Dict[str, Any]) -> ToolResult:
- """
- Execute file search
-
- :param args: Search parameters
- :return: Search results or error
- """
- pattern = args.get("pattern", "").strip()
- search_path = args.get("path", ".").strip()
- limit = args.get("limit", DEFAULT_LIMIT)
-
- if not pattern:
- return ToolResult.fail("Error: pattern parameter is required")
-
- # Resolve search path
- absolute_path = self._resolve_path(search_path)
-
- if not os.path.exists(absolute_path):
- return ToolResult.fail(f"Error: Path not found: {search_path}")
-
- if not os.path.isdir(absolute_path):
- return ToolResult.fail(f"Error: Not a directory: {search_path}")
-
- try:
- # Load .gitignore patterns
- ignore_patterns = self._load_gitignore(absolute_path)
-
- # Search for files
- results = []
- search_pattern = os.path.join(absolute_path, pattern)
-
- # Use glob with recursive support
- for file_path in glob_module.glob(search_pattern, recursive=True):
- # Skip if matches ignore patterns
- if self._should_ignore(file_path, absolute_path, ignore_patterns):
- continue
-
- # Get relative path
- relative_path = os.path.relpath(file_path, absolute_path)
-
- # Add trailing slash for directories
- if os.path.isdir(file_path):
- relative_path += '/'
-
- results.append(relative_path)
-
- if len(results) >= limit:
- break
-
- if not results:
- return ToolResult.success({"message": "No files found matching pattern", "files": []})
-
- # Sort results
- results.sort()
-
- # Format output
- raw_output = '\n'.join(results)
- truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
-
- output = truncation.content
- details = {}
- notices = []
-
- result_limit_reached = len(results) >= limit
- if result_limit_reached:
- notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
- details["result_limit_reached"] = limit
-
- if truncation.truncated:
- notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
- details["truncation"] = truncation.to_dict()
-
- if notices:
- output += f"\n\n[{'. '.join(notices)}]"
-
- return ToolResult.success({
- "output": output,
- "file_count": len(results),
- "details": details if details else None
- })
-
- except Exception as e:
- return ToolResult.fail(f"Error executing find: {str(e)}")
-
- def _resolve_path(self, path: str) -> str:
- """Resolve path to absolute path"""
- # Expand ~ to user home directory
- path = os.path.expanduser(path)
- if os.path.isabs(path):
- return path
- return os.path.abspath(os.path.join(self.cwd, path))
-
- def _load_gitignore(self, directory: str) -> List[str]:
- """Load .gitignore patterns from directory"""
- patterns = []
- gitignore_path = os.path.join(directory, '.gitignore')
-
- if os.path.exists(gitignore_path):
- try:
- with open(gitignore_path, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.strip()
- if line and not line.startswith('#'):
- patterns.append(line)
- except:
- pass
-
- # Add common ignore patterns
- patterns.extend([
- '.git',
- '__pycache__',
- '*.pyc',
- 'node_modules',
- '.DS_Store'
- ])
-
- return patterns
-
- def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
- """Check if file should be ignored based on patterns"""
- relative_path = os.path.relpath(file_path, base_path)
-
- for pattern in patterns:
- # Simple pattern matching
- if pattern in relative_path:
- return True
-
- # Check if it's a directory pattern
- if pattern.endswith('/'):
- if relative_path.startswith(pattern.rstrip('/')):
- return True
-
- return False
diff --git a/agent/tools/grep/__init__.py b/agent/tools/grep/__init__.py
deleted file mode 100644
index e4d57b0..0000000
--- a/agent/tools/grep/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .grep import Grep
-
-__all__ = ['Grep']
diff --git a/agent/tools/grep/grep.py b/agent/tools/grep/grep.py
deleted file mode 100644
index 1e7d95e..0000000
--- a/agent/tools/grep/grep.py
+++ /dev/null
@@ -1,248 +0,0 @@
-"""
-Grep tool - Search file contents for patterns
-Uses ripgrep (rg) for fast searching
-"""
-
-import os
-import re
-import subprocess
-import json
-from typing import Dict, Any, List, Optional
-
-from agent.tools.base_tool import BaseTool, ToolResult
-from agent.tools.utils.truncate import (
- truncate_head, truncate_line, format_size,
- DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
-)
-
-
-DEFAULT_LIMIT = 100
-
-
-class Grep(BaseTool):
- """Tool for searching file contents"""
-
- name: str = "grep"
- description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."
-
- params: dict = {
- "type": "object",
- "properties": {
- "pattern": {
- "type": "string",
- "description": "Search pattern (regex or literal string)"
- },
- "path": {
- "type": "string",
- "description": "Directory or file to search (default: current directory)"
- },
- "glob": {
- "type": "string",
- "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
- },
- "ignoreCase": {
- "type": "boolean",
- "description": "Case-insensitive search (default: false)"
- },
- "literal": {
- "type": "boolean",
- "description": "Treat pattern as literal string instead of regex (default: false)"
- },
- "context": {
- "type": "integer",
- "description": "Number of lines to show before and after each match (default: 0)"
- },
- "limit": {
- "type": "integer",
- "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
- }
- },
- "required": ["pattern"]
- }
-
- def __init__(self, config: dict = None):
- self.config = config or {}
- self.cwd = self.config.get("cwd", os.getcwd())
- self.rg_path = self._find_ripgrep()
-
- def _find_ripgrep(self) -> Optional[str]:
- """Find ripgrep executable"""
- try:
- result = subprocess.run(['which', 'rg'], capture_output=True, text=True)
- if result.returncode == 0:
- return result.stdout.strip()
- except:
- pass
- return None
-
- def execute(self, args: Dict[str, Any]) -> ToolResult:
- """
- Execute grep search
-
- :param args: Search parameters
- :return: Search results or error
- """
- if not self.rg_path:
- return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")
-
- pattern = args.get("pattern", "").strip()
- search_path = args.get("path", ".").strip()
- glob = args.get("glob")
- ignore_case = args.get("ignoreCase", False)
- literal = args.get("literal", False)
- context = args.get("context", 0)
- limit = args.get("limit", DEFAULT_LIMIT)
-
- if not pattern:
- return ToolResult.fail("Error: pattern parameter is required")
-
- # Resolve search path
- absolute_path = self._resolve_path(search_path)
-
- if not os.path.exists(absolute_path):
- return ToolResult.fail(f"Error: Path not found: {search_path}")
-
- # Build ripgrep command
- cmd = [
- self.rg_path,
- '--json',
- '--line-number',
- '--color=never',
- '--hidden'
- ]
-
- if ignore_case:
- cmd.append('--ignore-case')
-
- if literal:
- cmd.append('--fixed-strings')
-
- if glob:
- cmd.extend(['--glob', glob])
-
- cmd.extend([pattern, absolute_path])
-
- try:
- # Execute ripgrep
- result = subprocess.run(
- cmd,
- cwd=self.cwd,
- capture_output=True,
- text=True,
- timeout=30
- )
-
- # Parse JSON output
- matches = []
- match_count = 0
-
- for line in result.stdout.splitlines():
- if not line.strip():
- continue
-
- try:
- event = json.loads(line)
- if event.get('type') == 'match':
- data = event.get('data', {})
- file_path = data.get('path', {}).get('text')
- line_number = data.get('line_number')
-
- if file_path and line_number:
- matches.append({
- 'file': file_path,
- 'line': line_number
- })
- match_count += 1
-
- if match_count >= limit:
- break
- except json.JSONDecodeError:
- continue
-
- if match_count == 0:
- return ToolResult.success({"message": "No matches found", "matches": []})
-
- # Format output with context
- output_lines = []
- lines_truncated = False
- is_directory = os.path.isdir(absolute_path)
-
- for match in matches:
- file_path = match['file']
- line_number = match['line']
-
- # Format file path
- if is_directory:
- relative_path = os.path.relpath(file_path, absolute_path)
- else:
- relative_path = os.path.basename(file_path)
-
- # Read file and get context
- try:
- with open(file_path, 'r', encoding='utf-8') as f:
- file_lines = f.read().split('\n')
-
- # Calculate context range
- start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
- end = min(len(file_lines), line_number + context) if context > 0 else line_number
-
- # Format lines with context
- for i in range(start, end):
- line_text = file_lines[i].replace('\r', '')
-
- # Truncate long lines
- truncated_text, was_truncated = truncate_line(line_text)
- if was_truncated:
- lines_truncated = True
-
- # Format output
- current_line = i + 1
- if current_line == line_number:
- output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
- else:
- output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")
-
- except Exception:
- output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
-
- # Apply byte truncation
- raw_output = '\n'.join(output_lines)
- truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
-
- output = truncation.content
- details = {}
- notices = []
-
- if match_count >= limit:
- notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
- details["match_limit_reached"] = limit
-
- if truncation.truncated:
- notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
- details["truncation"] = truncation.to_dict()
-
- if lines_truncated:
- notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
- details["lines_truncated"] = True
-
- if notices:
- output += f"\n\n[{'. '.join(notices)}]"
-
- return ToolResult.success({
- "output": output,
- "match_count": match_count,
- "details": details if details else None
- })
-
- except subprocess.TimeoutExpired:
- return ToolResult.fail("Error: Search timed out after 30 seconds")
- except Exception as e:
- return ToolResult.fail(f"Error executing grep: {str(e)}")
-
- def _resolve_path(self, path: str) -> str:
- """Resolve path to absolute path"""
- # Expand ~ to user home directory
- path = os.path.expanduser(path)
- if os.path.isabs(path):
- return path
- return os.path.abspath(os.path.join(self.cwd, path))
diff --git a/agent/tools/ls/ls.py b/agent/tools/ls/ls.py
index d3e5330..d6517b3 100644
--- a/agent/tools/ls/ls.py
+++ b/agent/tools/ls/ls.py
@@ -50,6 +50,13 @@ class Ls(BaseTool):
# Resolve path
absolute_path = self._resolve_path(path)
+ # Security check: Prevent accessing sensitive config directory
+ env_config_dir = os.path.expanduser("~/.cow")
+ if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
+ return ToolResult.fail(
+ "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
+ )
+
if not os.path.exists(absolute_path):
# Provide helpful hint if using relative path
if not os.path.isabs(path) and not path.startswith('~'):
diff --git a/agent/tools/memory/memory_get.py b/agent/tools/memory/memory_get.py
index d828386..5febb10 100644
--- a/agent/tools/memory/memory_get.py
+++ b/agent/tools/memory/memory_get.py
@@ -4,8 +4,6 @@ Memory get tool
Allows agents to read specific sections from memory files
"""
-from typing import Dict, Any
-from pathlib import Path
from agent.tools.base_tool import BaseTool
@@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool):
"properties": {
"path": {
"type": "string",
- "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
+ "description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
},
"start_line": {
"type": "integer",
@@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool):
workspace_dir = self.memory_manager.config.get_workspace()
# Auto-prepend memory/ if not present and not absolute path
- if not path.startswith('memory/') and not path.startswith('/'):
+ # Exception: MEMORY.md is in the root directory
+ if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
path = f'memory/{path}'
file_path = workspace_dir / path
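The path rule above is easy to misread, so here is the same logic as a standalone sketch with a few illustrative inputs:

```python
# Sketch of the path normalization rule used by memory_get (inputs are illustrative).
def normalize_memory_path(path: str) -> str:
    # MEMORY.md lives in the workspace root; everything else is assumed
    # to live under memory/ unless already prefixed or absolute.
    if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
        path = f'memory/{path}'
    return path

assert normalize_memory_path('MEMORY.md') == 'MEMORY.md'
assert normalize_memory_path('2026-01-01.md') == 'memory/2026-01-01.md'
assert normalize_memory_path('memory/2026-01-01.md') == 'memory/2026-01-01.md'
```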
diff --git a/agent/tools/read/read.py b/agent/tools/read/read.py
index 4810890..f88bc50 100644
--- a/agent/tools/read/read.py
+++ b/agent/tools/read/read.py
@@ -15,7 +15,7 @@ class Read(BaseTool):
"""Tool for reading file contents"""
name: str = "read"
- description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."
+ description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."
params: dict = {
"type": "object",
@@ -26,7 +26,7 @@ class Read(BaseTool):
},
"offset": {
"type": "integer",
- "description": "Line number to start reading from (1-indexed, optional)"
+ "description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
},
"limit": {
"type": "integer",
@@ -39,10 +39,25 @@ class Read(BaseTool):
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
- # Supported image formats
- self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
- # Supported PDF format
+
+ # File type categories
+ self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
+ self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
+ self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
+ self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
+ self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
self.pdf_extensions = {'.pdf'}
+
+ # Readable text formats (will be read with truncation)
+ self.text_extensions = {
+ '.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
+ '.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
+ '.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
+ '.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
+ '.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
+ '.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
+ '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
+ }
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
@@ -61,6 +76,13 @@ class Read(BaseTool):
# Resolve path
absolute_path = self._resolve_path(path)
+ # Security check: Prevent reading sensitive config files
+ env_config_path = os.path.expanduser("~/.cow/.env")
+ if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
+ return ToolResult.fail(
+ "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
+ )
+
# Check if file exists
if not os.path.exists(absolute_path):
# Provide helpful hint if using relative path
@@ -78,16 +100,25 @@ class Read(BaseTool):
# Check file type
file_ext = Path(absolute_path).suffix.lower()
+ file_size = os.path.getsize(absolute_path)
- # Check if image
+ # Check if image - return metadata for sending
if file_ext in self.image_extensions:
return self._read_image(absolute_path, file_ext)
+ # Check if video/audio/binary/archive - return metadata only
+ if file_ext in self.video_extensions:
+ return self._return_file_metadata(absolute_path, "video", file_size)
+ if file_ext in self.audio_extensions:
+ return self._return_file_metadata(absolute_path, "audio", file_size)
+ if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
+ return self._return_file_metadata(absolute_path, "binary", file_size)
+
# Check if PDF
if file_ext in self.pdf_extensions:
return self._read_pdf(absolute_path, path, offset, limit)
- # Read text file
+ # Read text file (with truncation for large files)
return self._read_text(absolute_path, path, offset, limit)
def _resolve_path(self, path: str) -> str:
@@ -103,25 +134,56 @@ class Read(BaseTool):
return path
return os.path.abspath(os.path.join(self.cwd, path))
+ def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
+ """
+ Return file metadata for non-readable files (video, audio, binary, etc.)
+
+ :param absolute_path: Absolute path to the file
+ :param file_type: Type of file (video, audio, binary, etc.)
+ :param file_size: File size in bytes
+ :return: File metadata
+ """
+ file_name = Path(absolute_path).name
+ file_ext = Path(absolute_path).suffix.lower()
+
+ # Determine MIME type
+ mime_types = {
+ # Video
+ '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime',
+ '.mkv': 'video/x-matroska', '.webm': 'video/webm',
+ # Audio
+ '.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg',
+ '.m4a': 'audio/mp4', '.flac': 'audio/flac',
+ # Binary
+ '.zip': 'application/zip', '.tar': 'application/x-tar',
+ '.gz': 'application/gzip', '.rar': 'application/x-rar-compressed',
+ }
+ mime_type = mime_types.get(file_ext, 'application/octet-stream')
+
+ result = {
+ "type": f"{file_type}_metadata",
+ "file_type": file_type,
+ "path": absolute_path,
+ "file_name": file_name,
+ "mime_type": mime_type,
+ "size": file_size,
+ "size_formatted": format_size(file_size),
+ "message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
+ }
+
+ return ToolResult.success(result)
+
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
"""
- Read image file
+ Read image file - always return metadata only (images should be sent, not read into context)
:param absolute_path: Absolute path to the image file
:param file_ext: File extension
- :return: Result containing image information
+ :return: Result containing image metadata for sending
"""
try:
- # Read image file
- with open(absolute_path, 'rb') as f:
- image_data = f.read()
-
# Get file size
- file_size = len(image_data)
-
- # Return image information (actual image data can be base64 encoded when needed)
- import base64
- base64_data = base64.b64encode(image_data).decode('utf-8')
+ file_size = os.path.getsize(absolute_path)
# Determine MIME type
mime_type_map = {
@@ -133,12 +195,15 @@ class Read(BaseTool):
}
mime_type = mime_type_map.get(file_ext, 'image/jpeg')
+ # Return metadata for images (NOT file_to_send - use send tool to actually send)
result = {
- "type": "image",
+ "type": "image_metadata",
+ "file_type": "image",
+ "path": absolute_path,
"mime_type": mime_type,
"size": file_size,
"size_formatted": format_size(file_size),
- "data": base64_data # Base64 encoded image data
+ "message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
}
return ToolResult.success(result)
@@ -157,21 +222,49 @@ class Read(BaseTool):
:return: File content or error message
"""
try:
+ # Check file size first
+ file_size = os.path.getsize(absolute_path)
+ MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
+
+ if file_size > MAX_FILE_SIZE:
+ # File too large, return metadata only
+ return ToolResult.success({
+ "type": "file_to_send",
+ "file_type": "document",
+ "path": absolute_path,
+ "size": file_size,
+ "size_formatted": format_size(file_size),
+ "message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
+ })
+
# Read file
with open(absolute_path, 'r', encoding='utf-8') as f:
content = f.read()
+ # Truncate content if too long (20K characters max for model context)
+ MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
+ content_truncated = False
+ if len(content) > MAX_CONTENT_CHARS:
+ content = content[:MAX_CONTENT_CHARS]
+ content_truncated = True
+
all_lines = content.split('\n')
total_file_lines = len(all_lines)
# Apply offset (if specified)
start_line = 0
if offset is not None:
- start_line = max(0, offset - 1) # Convert to 0-indexed
- if start_line >= total_file_lines:
- return ToolResult.fail(
- f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
- )
+ if offset < 0:
+ # Negative offset: read from end
+ # -20 means "last 20 lines" → start from (total - 20)
+ start_line = max(0, total_file_lines + offset)
+ else:
+ # Positive offset: read from start (1-indexed)
+ start_line = max(0, offset - 1) # Convert to 0-indexed
+ if start_line >= total_file_lines:
+ return ToolResult.fail(
+ f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
+ )
start_line_display = start_line + 1 # For display (1-indexed)
@@ -191,6 +284,10 @@ class Read(BaseTool):
output_text = ""
details = {}
+ # Add truncation warning if content was truncated
+ if content_truncated:
+ output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
+
if truncation.first_line_exceeds_limit:
# First line exceeds 30KB limit
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
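A short sketch of the new offset semantics and the metadata-only behavior for media files; the file names are illustrative and `limit` is assumed to cap the number of returned lines:

```python
# Sketch of positive vs. negative offsets in the updated read tool.
from agent.tools.read.read import Read

reader = Read(config={"cwd": "/var/log"})

# Positive offset: 1-indexed line to start reading from.
reader.execute({"path": "app.log", "offset": 100, "limit": 50})

# Negative offset: count back from the end of the file (here, the last 20 lines).
reader.execute({"path": "app.log", "offset": -20})

# Images, video, audio and archives now return metadata only; to deliver the
# file to the user the agent is expected to call the send tool instead.
reader.execute({"path": "photo.png"})
```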
diff --git a/agent/tools/scheduler/README.md b/agent/tools/scheduler/README.md
new file mode 100644
index 0000000..55be2f9
--- /dev/null
+++ b/agent/tools/scheduler/README.md
@@ -0,0 +1,287 @@
+# Scheduler Tool
+
+## Overview
+
+The scheduler tool lets the Agent create, manage and run scheduled tasks. It supports:
+
+- ⏰ **Timed reminders**: send a message at a specified time
+- 🔄 **Recurring tasks**: repeat on a fixed interval or a cron expression
+- 🔧 **Dynamic tool calls**: run another tool on a schedule and send its result (e.g. search the news, check the weather)
+- 📋 **Task management**: list, enable, disable and delete tasks
+
+## Installing the dependency
+
+```bash
+pip install croniter>=2.0.0
+```
+
+## Usage
+
+### 1. Creating a scheduled task
+
+The Agent can create scheduled tasks from natural language. Two task types are supported:
+
+#### 1.1 Static message task
+
+Sends a predefined message:
+
+**Example conversation:**
+```
+User: Remind me about the meeting at 9am every day
+Agent: [calls the scheduler tool]
+  action: create
+  name: Daily meeting reminder
+  message: Time for the meeting!
+  schedule_type: cron
+  schedule_value: 0 9 * * *
+```
+
+#### 1.2 Dynamic tool-call task
+
+Runs a tool on a schedule and sends its result:
+
+**Example conversation:**
+```
+User: Read my schedule for the day every morning at 8am
+Agent: [calls the scheduler tool]
+  action: create
+  name: Daily schedule
+  tool_call:
+    tool_name: read
+    tool_params:
+      path: ~/cow/schedule.txt
+    result_prefix: 📅 Today's schedule
+  schedule_type: cron
+  schedule_value: 0 8 * * *
+```
+
+**Tool-call parameters:**
+- `tool_name`: name of the tool to call (built-in tools such as `bash`, `read`, `write`)
+- `tool_params`: parameters for the tool (a dict); for `read` the parameter is `path`
+- `result_prefix`: optional text prepended to the result
+
+**Note:** To use skills (such as bocha-search), invoke the skill script through the `bash` tool.
+
+### 2. Supported schedule types
+
+#### Cron expression (`cron`)
+Uses a standard cron expression (see the croniter sketch at the end of this section):
+
+```
+0 9 * * *       # every day at 9:00
+0 */2 * * *     # every 2 hours
+30 8 * * 1-5    # weekdays at 8:30
+0 0 1 * *       # the 1st of every month
+```
+
+#### Fixed interval (`interval`)
+Interval in seconds:
+
+```
+3600    # every hour
+86400   # every day
+1800    # every 30 minutes
+```
+
+#### One-off task (`once`)
+A specific time in ISO format:
+
+```
+2024-12-25T09:00:00
+2024-12-31T23:59:59
+```
+
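+A minimal sketch of how the next run time can be derived from any of the expressions above with `croniter` (the dates are illustrative):
+
+```python
+from datetime import datetime
+from croniter import croniter
+
+base = datetime(2024, 1, 1, 7, 30)
+itr = croniter("0 9 * * *", base)   # every day at 09:00
+print(itr.get_next(datetime))       # 2024-01-01 09:00:00
+print(itr.get_next(datetime))       # 2024-01-02 09:00:00
+```
+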
+### 3. Listing tasks
+
+```
+User: Show my scheduled tasks
+Agent: [calls the scheduler tool]
+  action: list
+```
+
+### 4. Viewing task details
+
+```
+User: Show the details of task abc123
+Agent: [calls the scheduler tool]
+  action: get
+  task_id: abc123
+```
+
+### 5. Deleting a task
+
+```
+User: Delete task abc123
+Agent: [calls the scheduler tool]
+  action: delete
+  task_id: abc123
+```
+
+### 6. Enabling/disabling a task
+
+```
+User: Pause task abc123
+Agent: [calls the scheduler tool]
+  action: disable
+  task_id: abc123
+
+User: Resume task abc123
+Agent: [calls the scheduler tool]
+  action: enable
+  task_id: abc123
+```
+
+## Task storage
+
+Tasks are persisted in a JSON file:
+```
+~/cow/scheduler/tasks.json
+```
+
+Task data structure:
+
+**Static message task:**
+```json
+{
+  "id": "abc123",
+  "name": "Daily reminder",
+  "enabled": true,
+  "created_at": "2024-01-01T10:00:00",
+  "updated_at": "2024-01-01T10:00:00",
+  "schedule": {
+    "type": "cron",
+    "expression": "0 9 * * *"
+  },
+  "action": {
+    "type": "send_message",
+    "content": "Time for the meeting!",
+    "receiver": "wxid_xxx",
+    "receiver_name": "张三",
+    "is_group": false,
+    "channel_type": "wechat"
+  },
+  "next_run_at": "2024-01-02T09:00:00",
+  "last_run_at": "2024-01-01T09:00:00"
+}
+```
+
+**Dynamic tool-call task:**
+```json
+{
+  "id": "def456",
+  "name": "Daily schedule",
+  "enabled": true,
+  "created_at": "2024-01-01T10:00:00",
+  "updated_at": "2024-01-01T10:00:00",
+  "schedule": {
+    "type": "cron",
+    "expression": "0 8 * * *"
+  },
+  "action": {
+    "type": "tool_call",
+    "tool_name": "read",
+    "tool_params": {
+      "path": "~/cow/schedule.txt"
+    },
+    "result_prefix": "📅 Today's schedule",
+    "receiver": "wxid_xxx",
+    "receiver_name": "张三",
+    "is_group": false,
+    "channel_type": "wechat"
+  },
+  "next_run_at": "2024-01-02T08:00:00"
+}
+```
+
+## Background service
+
+Scheduled tasks are managed by the background `SchedulerService`:
+
+- checks for due tasks every 30 seconds
+- executes due tasks automatically
+- computes the next run time
+- records execution history and errors
+
+The service starts automatically when the Agent is initialized; no manual configuration is needed (a minimal sketch of such a polling loop follows below).
+
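+For orientation only, a stripped-down sketch of a 30-second polling loop like the one described above (the function and field names here are illustrative, not the actual `SchedulerService` interface; the real service also recomputes `next_run_at` after each run):
+
+```python
+import threading, time
+from datetime import datetime
+
+def run_pending(load_tasks, execute, interval=30):
+    """Illustrative loop: every `interval` seconds, run tasks whose next_run_at is due."""
+    def loop():
+        while True:
+            now = datetime.now()
+            for task in load_tasks():                  # callable returning enabled tasks
+                due = task.get("next_run_at")
+                if due and datetime.fromisoformat(due) <= now:
+                    try:
+                        execute(task)                  # delegate to the execute callback
+                    except Exception as e:             # record the error, keep other tasks running
+                        task["last_error"] = str(e)
+            time.sleep(interval)
+    threading.Thread(target=loop, daemon=True).start()
+```
+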
+## Determining the receiver
+
+A scheduled task sends its output to **whoever the task was created for**:
+
+- created in a private chat → sent to that user
+- created in a group chat → sent to that group
+- the receiver info is saved automatically at creation time
+
+## Common use cases
+
+### 1. Daily reminder (static message)
+```
+User: Remind me to take my medication at 8am every day
+Agent: ✅ Scheduled task created
+   Task ID: a1b2c3d4
+   Schedule: every day at 8:00
+   Message: Time to take your medication!
+```
+
+### 2. Weekday reminder (static message)
+```
+User: Remind me to leave work at 6pm on weekdays
+Agent: [creates cron: 0 18 * * 1-5]
+    Message: Time to head home!
+```
+
+### 3. Countdown reminder (static message)
+```
+User: Remind me in one hour
+Agent: [creates interval: 3600]
+```
+
+### 4. Daily schedule push (dynamic tool call)
+```
+User: Read my schedule for the day every morning at 8am
+Agent: ✅ Scheduled task created
+   Task ID: schedule001
+   Schedule: every day at 8:00
+   Tool: read(path='~/cow/schedule.txt')
+   Prefix: 📅 Today's schedule
+```
+
+### 5. Scheduled file backup (dynamic tool call)
+```
+User: Back up my work files at 11pm every night
+Agent: [creates cron: 0 23 * * *]
+    Tool: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
+    Prefix: ✅ File backed up
+```
+
+### 6. Weekly report reminder (static message)
+```
+User: Remind me to write the weekly report at 5pm every Friday
+Agent: [creates cron: 0 17 * * 5]
+    Message: 📊 Time to write the weekly report!
+```
+
+### 7. Reminder on a specific date
+```
+User: Send me a Merry Christmas reminder at 9am on December 25
+Agent: [creates once: 2024-12-25T09:00:00]
+```
+
+## Notes
+
+1. **Time zone**: the system's local time zone is used
+2. **Precision**: tasks are checked every 30 seconds, so actual execution may be off by up to ±30 seconds
+3. **Persistence**: tasks are stored on disk and restored automatically after a restart
+4. **One-off tasks**: disabled automatically after they run; they are not deleted (delete them manually if needed)
+5. **Error handling**: a failed run is recorded and does not affect other tasks
+
+## Implementation
+
+- **TaskStore**: persistent task storage
+- **SchedulerService**: background scheduling service
+- **SchedulerTool**: the tool interface exposed to the Agent
+- **Integration**: wiring into AgentBridge
+
+## Dependencies
+
+- `croniter`: cron expression parsing (lightweight, ~50KB)
diff --git a/agent/tools/scheduler/__init__.py b/agent/tools/scheduler/__init__.py
new file mode 100644
index 0000000..dafc8b6
--- /dev/null
+++ b/agent/tools/scheduler/__init__.py
@@ -0,0 +1,7 @@
+"""
+Scheduler tool for managing scheduled tasks
+"""
+
+from .scheduler_tool import SchedulerTool
+
+__all__ = ["SchedulerTool"]
diff --git a/agent/tools/scheduler/integration.py b/agent/tools/scheduler/integration.py
new file mode 100644
index 0000000..1882195
--- /dev/null
+++ b/agent/tools/scheduler/integration.py
@@ -0,0 +1,447 @@
+"""
+Integration module for scheduler with AgentBridge
+"""
+
+import os
+from typing import Optional
+from config import conf
+from common.log import logger
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+
+# Global scheduler service instance
+_scheduler_service = None
+_task_store = None
+
+
+def init_scheduler(agent_bridge) -> bool:
+ """
+ Initialize scheduler service
+
+ Args:
+ agent_bridge: AgentBridge instance
+
+ Returns:
+ True if initialized successfully
+ """
+ global _scheduler_service, _task_store
+
+ try:
+ from agent.tools.scheduler.task_store import TaskStore
+ from agent.tools.scheduler.scheduler_service import SchedulerService
+
+ # Get workspace from config
+ workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
+ store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
+
+ # Create task store
+ _task_store = TaskStore(store_path)
+ logger.debug(f"[Scheduler] Task store initialized: {store_path}")
+
+ # Create execute callback
+ def execute_task_callback(task: dict):
+ """Callback to execute a scheduled task"""
+ try:
+ action = task.get("action", {})
+ action_type = action.get("type")
+
+ if action_type == "agent_task":
+ _execute_agent_task(task, agent_bridge)
+ elif action_type == "send_message":
+ # Legacy support for old tasks
+ _execute_send_message(task, agent_bridge)
+ elif action_type == "tool_call":
+ # Legacy support for old tasks
+ _execute_tool_call(task, agent_bridge)
+ elif action_type == "skill_call":
+ # Legacy support for old tasks
+ _execute_skill_call(task, agent_bridge)
+ else:
+ logger.warning(f"[Scheduler] Unknown action type: {action_type}")
+ except Exception as e:
+ logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
+
+ # Create scheduler service
+ _scheduler_service = SchedulerService(_task_store, execute_task_callback)
+ _scheduler_service.start()
+
+ logger.debug("[Scheduler] Scheduler service initialized and started")
+ return True
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
+ return False
+
+
+def get_task_store():
+ """Get the global task store instance"""
+ return _task_store
+
+
+def get_scheduler_service():
+ """Get the global scheduler service instance"""
+ return _scheduler_service
+
+
+def _execute_agent_task(task: dict, agent_bridge):
+ """
+ Execute an agent_task action - let Agent handle the task
+
+ Args:
+ task: Task dictionary
+ agent_bridge: AgentBridge instance
+ """
+ try:
+ action = task.get("action", {})
+ task_description = action.get("task_description")
+ receiver = action.get("receiver")
+ is_group = action.get("is_group", False)
+ channel_type = action.get("channel_type", "unknown")
+
+ if not task_description:
+ logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
+ return
+
+ if not receiver:
+ logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
+ return
+
+ # Check for unsupported channels
+ if channel_type == "dingtalk":
+ logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")
+
+ logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
+
+ # Create context for Agent
+ context = Context(ContextType.TEXT, task_description)
+ context["receiver"] = receiver
+ context["isgroup"] = is_group
+ context["session_id"] = receiver
+
+ # Channel-specific setup
+ if channel_type == "web":
+ import uuid
+ request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+ context["request_id"] = request_id
+ elif channel_type == "feishu":
+ context["receive_id_type"] = "chat_id" if is_group else "open_id"
+ context["msg"] = None
+ elif channel_type == "dingtalk":
+ # DingTalk requires msg object, set to None for scheduled tasks
+ context["msg"] = None
+            # For one-on-one chats, sender_staff_id must be passed along
+ if not is_group:
+ sender_staff_id = action.get("dingtalk_sender_staff_id")
+ if sender_staff_id:
+ context["dingtalk_sender_staff_id"] = sender_staff_id
+
+ # Use Agent to execute the task
+ # Mark this as a scheduled task execution to prevent recursive task creation
+ context["is_scheduled_task"] = True
+
+ try:
+ reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True)
+
+ if reply and reply.content:
+ # Send the reply via channel
+ from channel.channel_factory import create_channel
+
+ try:
+ channel = create_channel(channel_type)
+ if channel:
+ # For web channel, register request_id
+ if channel_type == "web" and hasattr(channel, 'request_to_session'):
+ request_id = context.get("request_id")
+ if request_id:
+ channel.request_to_session[request_id] = receiver
+ logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
+
+ # Send the reply
+ channel.send(reply, context)
+ logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
+ else:
+ logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to send result: {e}")
+ else:
+ logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+
+def _execute_send_message(task: dict, agent_bridge):
+ """
+ Execute a send_message action
+
+ Args:
+ task: Task dictionary
+ agent_bridge: AgentBridge instance
+ """
+ try:
+ action = task.get("action", {})
+ content = action.get("content", "")
+ receiver = action.get("receiver")
+ is_group = action.get("is_group", False)
+ channel_type = action.get("channel_type", "unknown")
+
+ if not receiver:
+ logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
+ return
+
+ # Create context for sending message
+ context = Context(ContextType.TEXT, content)
+ context["receiver"] = receiver
+ context["isgroup"] = is_group
+ context["session_id"] = receiver
+
+ # Channel-specific context setup
+ if channel_type == "web":
+ # Web channel needs request_id
+ import uuid
+ request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+ context["request_id"] = request_id
+ logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
+ elif channel_type == "feishu":
+ # Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
+ # Use chat_id for groups, open_id for private chats
+ context["receive_id_type"] = "chat_id" if is_group else "open_id"
+ # Keep isgroup as is, but set msg to None (no original message to reply to)
+ # Feishu channel will detect this and send as new message instead of reply
+ context["msg"] = None
+ logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
+ elif channel_type == "dingtalk":
+ # DingTalk channel setup
+ context["msg"] = None
+            # For one-on-one chats, sender_staff_id must be passed along
+ if not is_group:
+ sender_staff_id = action.get("dingtalk_sender_staff_id")
+ if sender_staff_id:
+ context["dingtalk_sender_staff_id"] = sender_staff_id
+ logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
+ else:
+ logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
+
+ # Create reply
+ reply = Reply(ReplyType.TEXT, content)
+
+ # Get channel and send
+ from channel.channel_factory import create_channel
+
+ try:
+ channel = create_channel(channel_type)
+ if channel:
+ # For web channel, register the request_id to session mapping
+ if channel_type == "web" and hasattr(channel, 'request_to_session'):
+ channel.request_to_session[request_id] = receiver
+ logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
+
+ channel.send(reply, context)
+ logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
+ else:
+ logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to send message: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+
+def _execute_tool_call(task: dict, agent_bridge):
+ """
+ Execute a tool_call action
+
+ Args:
+ task: Task dictionary
+ agent_bridge: AgentBridge instance
+ """
+ try:
+ action = task.get("action", {})
+ # Support both old and new field names
+ tool_name = action.get("call_name") or action.get("tool_name")
+ tool_params = action.get("call_params") or action.get("tool_params", {})
+ result_prefix = action.get("result_prefix", "")
+ receiver = action.get("receiver")
+ is_group = action.get("is_group", False)
+ channel_type = action.get("channel_type", "unknown")
+
+ if not tool_name:
+ logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
+ return
+
+ if not receiver:
+ logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
+ return
+
+ # Get tool manager and create tool instance
+ from agent.tools.tool_manager import ToolManager
+ tool_manager = ToolManager()
+ tool = tool_manager.create_tool(tool_name)
+
+ if not tool:
+ logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
+ return
+
+ # Execute tool
+ logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
+ result = tool.execute(tool_params)
+
+ # Get result content
+ if hasattr(result, 'result'):
+ content = result.result
+ else:
+ content = str(result)
+
+ # Add prefix if specified
+ if result_prefix:
+ content = f"{result_prefix}\n\n{content}"
+
+ # Send result as message
+ context = Context(ContextType.TEXT, content)
+ context["receiver"] = receiver
+ context["isgroup"] = is_group
+ context["session_id"] = receiver
+
+ # Channel-specific context setup
+ if channel_type == "web":
+ # Web channel needs request_id
+ import uuid
+ request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+ context["request_id"] = request_id
+ logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
+ elif channel_type == "feishu":
+ # Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
+ context["receive_id_type"] = "chat_id" if is_group else "open_id"
+ context["msg"] = None
+ logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
+
+ reply = Reply(ReplyType.TEXT, content)
+
+ # Get channel and send
+ from channel.channel_factory import create_channel
+
+ try:
+ channel = create_channel(channel_type)
+ if channel:
+ # For web channel, register the request_id to session mapping
+ if channel_type == "web" and hasattr(channel, 'request_to_session'):
+ channel.request_to_session[request_id] = receiver
+ logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
+
+ channel.send(reply, context)
+ logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
+ else:
+ logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to send tool result: {e}")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
+
+
+def _execute_skill_call(task: dict, agent_bridge):
+ """
+ Execute a skill_call action by asking Agent to run the skill
+
+ Args:
+ task: Task dictionary
+ agent_bridge: AgentBridge instance
+ """
+ try:
+ action = task.get("action", {})
+ # Support both old and new field names
+ skill_name = action.get("call_name") or action.get("skill_name")
+ skill_params = action.get("call_params") or action.get("skill_params", {})
+ result_prefix = action.get("result_prefix", "")
+ receiver = action.get("receiver")
+ is_group = action.get("isgroup", False)
+ channel_type = action.get("channel_type", "unknown")
+
+ if not skill_name:
+ logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
+ return
+
+ if not receiver:
+ logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
+ return
+
+ logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
+
+ # Build a natural language query for the Agent to execute the skill
+ # Format: "Use skill-name to do something with params"
+ param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
+ query = f"Use {skill_name} skill"
+ if param_str:
+ query += f" with {param_str}"
+
+ # Create context for Agent
+ context = Context(ContextType.TEXT, query)
+ context["receiver"] = receiver
+ context["isgroup"] = is_group
+ context["session_id"] = receiver
+
+ # Channel-specific setup
+ if channel_type == "web":
+ import uuid
+ request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
+ context["request_id"] = request_id
+ elif channel_type == "feishu":
+ context["receive_id_type"] = "chat_id" if is_group else "open_id"
+ context["msg"] = None
+
+ # Use Agent to execute the skill
+ try:
+ reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True)
+
+            if reply and reply.content:
+                content = reply.content
+
+                # Add prefix if specified
+                if result_prefix:
+                    content = f"{result_prefix}\n\n{content}"
+
+                # Send the skill result through the channel (same pattern as the other handlers)
+                from channel.channel_factory import create_channel
+                channel = create_channel(channel_type)
+                if channel:
+                    # For web channel, register the request_id to session mapping
+                    if channel_type == "web" and hasattr(channel, 'request_to_session'):
+                        channel.request_to_session[request_id] = receiver
+                    channel.send(Reply(ReplyType.TEXT, content), context)
+                    logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
+                else:
+                    logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
+            else:
+                logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+ except Exception as e:
+ logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
+ import traceback
+ logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
+
+
+def attach_scheduler_to_tool(tool, context: Context = None):
+ """
+ Attach scheduler components to a SchedulerTool instance
+
+ Args:
+ tool: SchedulerTool instance
+ context: Current context (optional)
+ """
+ if _task_store:
+ tool.task_store = _task_store
+
+ if context:
+ tool.current_context = context
+
+ # Also set channel_type from config
+ channel_type = conf().get("channel_type", "unknown")
+ if not tool.config:
+ tool.config = {}
+ tool.config["channel_type"] = channel_type
diff --git a/agent/tools/scheduler/scheduler_service.py b/agent/tools/scheduler/scheduler_service.py
new file mode 100644
index 0000000..286fbc6
--- /dev/null
+++ b/agent/tools/scheduler/scheduler_service.py
@@ -0,0 +1,220 @@
+"""
+Background scheduler service for executing scheduled tasks
+"""
+
+import time
+import threading
+from datetime import datetime, timedelta
+from typing import Callable, Optional
+from croniter import croniter
+from common.log import logger
+
+
+class SchedulerService:
+ """
+ Background service that executes scheduled tasks
+ """
+
+ def __init__(self, task_store, execute_callback: Callable):
+ """
+ Initialize scheduler service
+
+ Args:
+ task_store: TaskStore instance
+ execute_callback: Function to call when executing a task
+ """
+ self.task_store = task_store
+ self.execute_callback = execute_callback
+ self.running = False
+ self.thread = None
+ self._lock = threading.Lock()
+
+ def start(self):
+ """Start the scheduler service"""
+ with self._lock:
+ if self.running:
+ logger.warning("[Scheduler] Service already running")
+ return
+
+ self.running = True
+ self.thread = threading.Thread(target=self._run_loop, daemon=True)
+ self.thread.start()
+ logger.debug("[Scheduler] Service started")
+
+ def stop(self):
+ """Stop the scheduler service"""
+ with self._lock:
+ if not self.running:
+ return
+
+ self.running = False
+ if self.thread:
+ self.thread.join(timeout=5)
+ logger.info("[Scheduler] Service stopped")
+
+ def _run_loop(self):
+ """Main scheduler loop"""
+ logger.debug("[Scheduler] Scheduler loop started")
+
+ while self.running:
+ try:
+ self._check_and_execute_tasks()
+ except Exception as e:
+ logger.error(f"[Scheduler] Error in scheduler loop: {e}")
+
+ # Sleep for 30 seconds between checks
+ time.sleep(30)
+
+ def _check_and_execute_tasks(self):
+ """Check for due tasks and execute them"""
+ now = datetime.now()
+ tasks = self.task_store.list_tasks(enabled_only=True)
+
+ for task in tasks:
+ try:
+ # Check if task is due
+ if self._is_task_due(task, now):
+ logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
+ self._execute_task(task)
+
+ # Update next run time
+ next_run = self._calculate_next_run(task, now)
+ if next_run:
+ self.task_store.update_task(task['id'], {
+ "next_run_at": next_run.isoformat(),
+ "last_run_at": now.isoformat()
+ })
+ else:
+ # One-time task, disable it
+ self.task_store.update_task(task['id'], {
+ "enabled": False,
+ "last_run_at": now.isoformat()
+ })
+ logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
+ except Exception as e:
+ logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
+
+ def _is_task_due(self, task: dict, now: datetime) -> bool:
+ """
+ Check if a task is due to run
+
+ Args:
+ task: Task dictionary
+ now: Current datetime
+
+ Returns:
+ True if task should run now
+ """
+ next_run_str = task.get("next_run_at")
+ if not next_run_str:
+ # Calculate initial next_run_at
+ next_run = self._calculate_next_run(task, now)
+            if next_run:
+                self.task_store.update_task(task['id'], {
+                    "next_run_at": next_run.isoformat()
+                })
+            return False
+
+ try:
+ next_run = datetime.fromisoformat(next_run_str)
+
+ # Check if task is overdue (e.g., service restart)
+ if next_run < now:
+ time_diff = (now - next_run).total_seconds()
+
+ # If overdue by more than 5 minutes, skip this run and schedule next
+ if time_diff > 300: # 5 minutes
+ logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
+
+ # For one-time tasks, disable them
+ schedule = task.get("schedule", {})
+ if schedule.get("type") == "once":
+ self.task_store.update_task(task['id'], {
+ "enabled": False,
+ "last_run_at": now.isoformat()
+ })
+ logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
+ return False
+
+ # For recurring tasks, calculate next run from now
+ next_next_run = self._calculate_next_run(task, now)
+ if next_next_run:
+ self.task_store.update_task(task['id'], {
+ "next_run_at": next_next_run.isoformat()
+ })
+ logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
+ return False
+
+ return now >= next_run
+        except Exception as e:
+            logger.error(f"[Scheduler] Invalid next_run_at for task {task.get('id')}: {e}")
+            return False
+
+ def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
+ """
+ Calculate next run time for a task
+
+ Args:
+ task: Task dictionary
+ from_time: Calculate from this time
+
+ Returns:
+ Next run datetime or None for one-time tasks
+ """
+ schedule = task.get("schedule", {})
+ schedule_type = schedule.get("type")
+
+ if schedule_type == "cron":
+ # Cron expression
+ expression = schedule.get("expression")
+ if not expression:
+ return None
+
+ try:
+ cron = croniter(expression, from_time)
+ return cron.get_next(datetime)
+ except Exception as e:
+ logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
+ return None
+
+ elif schedule_type == "interval":
+ # Interval in seconds
+ seconds = schedule.get("seconds", 0)
+ if seconds <= 0:
+ return None
+ return from_time + timedelta(seconds=seconds)
+
+ elif schedule_type == "once":
+ # One-time task at specific time
+ run_at_str = schedule.get("run_at")
+ if not run_at_str:
+ return None
+
+ try:
+ run_at = datetime.fromisoformat(run_at_str)
+ # Only return if in the future
+ if run_at > from_time:
+ return run_at
+            except (ValueError, TypeError) as e:
+                logger.error(f"[Scheduler] Invalid run_at '{run_at_str}': {e}")
+ return None
+
+ return None
+
+ def _execute_task(self, task: dict):
+ """
+ Execute a task
+
+ Args:
+ task: Task dictionary
+ """
+ try:
+ # Call the execute callback
+ self.execute_callback(task)
+ except Exception as e:
+ logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
+ # Update task with error
+ self.task_store.update_task(task['id'], {
+ "last_error": str(e),
+ "last_error_at": datetime.now().isoformat()
+ })
diff --git a/agent/tools/scheduler/scheduler_tool.py b/agent/tools/scheduler/scheduler_tool.py
new file mode 100644
index 0000000..9d961c3
--- /dev/null
+++ b/agent/tools/scheduler/scheduler_tool.py
@@ -0,0 +1,442 @@
+"""
+Scheduler tool for creating and managing scheduled tasks
+"""
+
+import uuid
+from datetime import datetime
+from typing import Any, Dict, Optional
+from croniter import croniter
+
+from agent.tools.base_tool import BaseTool, ToolResult
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+
+
+class SchedulerTool(BaseTool):
+ """
+ Tool for managing scheduled tasks (reminders, notifications, etc.)
+ """
+
+ name: str = "scheduler"
+ description: str = (
+ "创建、查询和管理定时任务。支持固定消息和AI任务两种类型。\n\n"
+ "使用方法:\n"
+ "- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
+ "- 查询:action='list' / action='get', task_id='任务ID'\n"
+ "- 管理:action='delete/enable/disable', task_id='任务ID'\n\n"
+ "调度类型:\n"
+ "- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
+ "- interval: 固定间隔(秒),如3600表示每小时\n"
+ "- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n"
+ "注意:'X秒后'用once+相对时间,'每X秒'用interval"
+ )
+ params: dict = {
+ "type": "object",
+ "properties": {
+ "action": {
+ "type": "string",
+ "enum": ["create", "list", "get", "delete", "enable", "disable"],
+ "description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
+ },
+ "task_id": {
+ "type": "string",
+ "description": "任务ID (用于 get/delete/enable/disable 操作)"
+ },
+ "name": {
+ "type": "string",
+ "description": "任务名称 (用于 create 操作)"
+ },
+ "message": {
+ "type": "string",
+ "description": "固定消息内容 (与ai_task二选一)"
+ },
+ "ai_task": {
+ "type": "string",
+ "description": "AI任务描述 (与message二选一),如'搜索今日新闻'、'查询天气'"
+ },
+ "schedule_type": {
+ "type": "string",
+ "enum": ["cron", "interval", "once"],
+ "description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
+ },
+ "schedule_value": {
+ "type": "string",
+ "description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
+ }
+ },
+ "required": ["action"]
+ }
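+
+    # Illustrative calls (a sketch; only parameters from the schema above are used):
+    #   execute({"action": "create", "name": "晨间新闻", "ai_task": "搜索今日新闻",
+    #            "schedule_type": "cron", "schedule_value": "0 8 * * *"})
+    #   execute({"action": "create", "name": "会议提醒", "message": "该开会了",
+    #            "schedule_type": "once", "schedule_value": "+10m"})
+    #   execute({"action": "list"})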
+
+ def __init__(self, config: dict = None):
+ super().__init__()
+ self.config = config or {}
+
+ # Will be set by agent bridge
+ self.task_store = None
+ self.current_context = None
+
+ def execute(self, params: dict) -> ToolResult:
+ """
+ Execute scheduler operations
+
+ Args:
+ params: Dictionary containing:
+ - action: Operation type (create/list/get/delete/enable/disable)
+ - Other parameters depending on action
+
+ Returns:
+ ToolResult object
+ """
+ # Extract parameters
+ action = params.get("action")
+ kwargs = params
+
+ if not self.task_store:
+ return ToolResult.fail("错误: 定时任务系统未初始化")
+
+ try:
+ if action == "create":
+ result = self._create_task(**kwargs)
+ return ToolResult.success(result)
+ elif action == "list":
+ result = self._list_tasks(**kwargs)
+ return ToolResult.success(result)
+ elif action == "get":
+ result = self._get_task(**kwargs)
+ return ToolResult.success(result)
+ elif action == "delete":
+ result = self._delete_task(**kwargs)
+ return ToolResult.success(result)
+ elif action == "enable":
+ result = self._enable_task(**kwargs)
+ return ToolResult.success(result)
+ elif action == "disable":
+ result = self._disable_task(**kwargs)
+ return ToolResult.success(result)
+ else:
+ return ToolResult.fail(f"未知操作: {action}")
+ except Exception as e:
+ logger.error(f"[SchedulerTool] Error: {e}")
+ return ToolResult.fail(f"操作失败: {str(e)}")
+
+ def _create_task(self, **kwargs) -> str:
+ """Create a new scheduled task"""
+ name = kwargs.get("name")
+ message = kwargs.get("message")
+ ai_task = kwargs.get("ai_task")
+ schedule_type = kwargs.get("schedule_type")
+ schedule_value = kwargs.get("schedule_value")
+
+ # Validate required fields
+ if not name:
+ return "错误: 缺少任务名称 (name)"
+
+ # Check that exactly one of message/ai_task is provided
+ if not message and not ai_task:
+ return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一"
+ if message and ai_task:
+ return "错误: message 和 ai_task 只能提供其中一个"
+
+ if not schedule_type:
+ return "错误: 缺少调度类型 (schedule_type)"
+ if not schedule_value:
+ return "错误: 缺少调度值 (schedule_value)"
+
+ # Validate schedule
+ schedule = self._parse_schedule(schedule_type, schedule_value)
+ if not schedule:
+ return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
+
+ # Get context info for receiver
+ if not self.current_context:
+ return "错误: 无法获取当前对话上下文"
+
+ context = self.current_context
+
+ # Create task
+ task_id = str(uuid.uuid4())[:8]
+
+ # Build action based on message or ai_task
+ if message:
+ action = {
+ "type": "send_message",
+ "content": message,
+ "receiver": context.get("receiver"),
+ "receiver_name": self._get_receiver_name(context),
+ "is_group": context.get("isgroup", False),
+ "channel_type": self.config.get("channel_type", "unknown")
+ }
+ else: # ai_task
+ action = {
+ "type": "agent_task",
+ "task_description": ai_task,
+ "receiver": context.get("receiver"),
+ "receiver_name": self._get_receiver_name(context),
+ "is_group": context.get("isgroup", False),
+ "channel_type": self.config.get("channel_type", "unknown")
+ }
+
+ # 针对钉钉单聊,额外存储 sender_staff_id
+ msg = context.kwargs.get("msg")
+ if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
+ action["dingtalk_sender_staff_id"] = msg.sender_staff_id
+
+ task_data = {
+ "id": task_id,
+ "name": name,
+ "enabled": True,
+ "created_at": datetime.now().isoformat(),
+ "updated_at": datetime.now().isoformat(),
+ "schedule": schedule,
+ "action": action
+ }
+
+ # Calculate initial next_run_at
+ next_run = self._calculate_next_run(task_data)
+ if next_run:
+ task_data["next_run_at"] = next_run.isoformat()
+
+ # Save task
+ self.task_store.add_task(task_data)
+
+ # Format response
+ schedule_desc = self._format_schedule_description(schedule)
+ receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
+
+ if message:
+ content_desc = f"💬 固定消息: {message}"
+ else:
+ content_desc = f"🤖 AI任务: {ai_task}"
+
+ return (
+ f"✅ 定时任务创建成功\n\n"
+ f"📋 任务ID: {task_id}\n"
+ f"📝 名称: {name}\n"
+ f"⏰ 调度: {schedule_desc}\n"
+ f"👤 接收者: {receiver_desc}\n"
+ f"{content_desc}\n"
+ f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
+ )
+
+ def _list_tasks(self, **kwargs) -> str:
+ """List all tasks"""
+ tasks = self.task_store.list_tasks()
+
+ if not tasks:
+ return "📋 暂无定时任务"
+
+ lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
+
+ for task in tasks:
+ status = "✅" if task.get("enabled", True) else "❌"
+ schedule_desc = self._format_schedule_description(task.get("schedule", {}))
+ next_run = task.get("next_run_at")
+ next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
+
+ lines.append(
+ f"{status} [{task['id']}] {task['name']}\n"
+ f" ⏰ {schedule_desc} | 下次: {next_run_str}"
+ )
+
+ return "\n".join(lines)
+
+ def _get_task(self, **kwargs) -> str:
+ """Get task details"""
+ task_id = kwargs.get("task_id")
+ if not task_id:
+ return "错误: 缺少任务ID (task_id)"
+
+ task = self.task_store.get_task(task_id)
+ if not task:
+ return f"错误: 任务 '{task_id}' 不存在"
+
+ status = "启用" if task.get("enabled", True) else "禁用"
+ schedule_desc = self._format_schedule_description(task.get("schedule", {}))
+ action = task.get("action", {})
+ next_run = task.get("next_run_at")
+ next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
+ last_run = task.get("last_run_at")
+ last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
+
+ return (
+ f"📋 任务详情\n\n"
+ f"ID: {task['id']}\n"
+ f"名称: {task['name']}\n"
+ f"状态: {status}\n"
+ f"调度: {schedule_desc}\n"
+ f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
+ f"消息: {action.get('content')}\n"
+ f"下次执行: {next_run_str}\n"
+ f"上次执行: {last_run_str}\n"
+ f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
+ )
+
+ def _delete_task(self, **kwargs) -> str:
+ """Delete a task"""
+ task_id = kwargs.get("task_id")
+ if not task_id:
+ return "错误: 缺少任务ID (task_id)"
+
+ task = self.task_store.get_task(task_id)
+ if not task:
+ return f"错误: 任务 '{task_id}' 不存在"
+
+ self.task_store.delete_task(task_id)
+ return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
+
+ def _enable_task(self, **kwargs) -> str:
+ """Enable a task"""
+ task_id = kwargs.get("task_id")
+ if not task_id:
+ return "错误: 缺少任务ID (task_id)"
+
+ task = self.task_store.get_task(task_id)
+ if not task:
+ return f"错误: 任务 '{task_id}' 不存在"
+
+ self.task_store.enable_task(task_id, True)
+ return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
+
+ def _disable_task(self, **kwargs) -> str:
+ """Disable a task"""
+ task_id = kwargs.get("task_id")
+ if not task_id:
+ return "错误: 缺少任务ID (task_id)"
+
+ task = self.task_store.get_task(task_id)
+ if not task:
+ return f"错误: 任务 '{task_id}' 不存在"
+
+ self.task_store.enable_task(task_id, False)
+ return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
+
+ def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
+ """Parse and validate schedule configuration"""
+ try:
+ if schedule_type == "cron":
+ # Validate cron expression
+ croniter(schedule_value)
+ return {"type": "cron", "expression": schedule_value}
+
+ elif schedule_type == "interval":
+ # Parse interval in seconds
+ seconds = int(schedule_value)
+ if seconds <= 0:
+ return None
+ return {"type": "interval", "seconds": seconds}
+
+ elif schedule_type == "once":
+ # Parse datetime - support both relative and absolute time
+
+ # Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
+ if schedule_value.startswith("+"):
+ import re
+ match = re.match(r'\+(\d+)([smhd])', schedule_value)
+ if match:
+ amount = int(match.group(1))
+ unit = match.group(2)
+
+ from datetime import timedelta
+ now = datetime.now()
+
+ if unit == 's': # seconds
+ target_time = now + timedelta(seconds=amount)
+ elif unit == 'm': # minutes
+ target_time = now + timedelta(minutes=amount)
+ elif unit == 'h': # hours
+ target_time = now + timedelta(hours=amount)
+ elif unit == 'd': # days
+ target_time = now + timedelta(days=amount)
+ else:
+ return None
+
+ return {"type": "once", "run_at": target_time.isoformat()}
+ else:
+ logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
+ return None
+ else:
+ # Absolute time in ISO format
+ datetime.fromisoformat(schedule_value)
+ return {"type": "once", "run_at": schedule_value}
+
+ except Exception as e:
+ logger.error(f"[SchedulerTool] Invalid schedule: {e}")
+ return None
+
+ return None
+
+ def _calculate_next_run(self, task: dict) -> Optional[datetime]:
+ """Calculate next run time for a task"""
+ schedule = task.get("schedule", {})
+ schedule_type = schedule.get("type")
+ now = datetime.now()
+
+ if schedule_type == "cron":
+ expression = schedule.get("expression")
+ cron = croniter(expression, now)
+ return cron.get_next(datetime)
+
+ elif schedule_type == "interval":
+ seconds = schedule.get("seconds", 0)
+ from datetime import timedelta
+ return now + timedelta(seconds=seconds)
+
+ elif schedule_type == "once":
+ run_at_str = schedule.get("run_at")
+ return datetime.fromisoformat(run_at_str)
+
+ return None
+
+ def _format_schedule_description(self, schedule: dict) -> str:
+ """Format schedule as human-readable description"""
+ schedule_type = schedule.get("type")
+
+ if schedule_type == "cron":
+ expr = schedule.get("expression", "")
+ # Try to provide friendly description
+ if expr == "0 9 * * *":
+ return "每天 9:00"
+ elif expr == "0 */1 * * *":
+ return "每小时"
+ elif expr == "*/30 * * * *":
+ return "每30分钟"
+ else:
+ return f"Cron: {expr}"
+
+ elif schedule_type == "interval":
+ seconds = schedule.get("seconds", 0)
+ if seconds >= 86400:
+ days = seconds // 86400
+ return f"每 {days} 天"
+ elif seconds >= 3600:
+ hours = seconds // 3600
+ return f"每 {hours} 小时"
+ elif seconds >= 60:
+ minutes = seconds // 60
+ return f"每 {minutes} 分钟"
+ else:
+ return f"每 {seconds} 秒"
+
+ elif schedule_type == "once":
+ run_at = schedule.get("run_at", "")
+ try:
+ dt = datetime.fromisoformat(run_at)
+ return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
+            except (ValueError, TypeError):
+ return "一次性"
+
+ return "未知"
+
+ def _get_receiver_name(self, context: Context) -> str:
+ """Get receiver name from context"""
+ try:
+ msg = context.get("msg")
+ if msg:
+ if context.get("isgroup"):
+ return msg.other_user_nickname or "群聊"
+ else:
+ return msg.from_user_nickname or "用户"
+ except:
+ pass
+ return "未知"
diff --git a/agent/tools/scheduler/task_store.py b/agent/tools/scheduler/task_store.py
new file mode 100644
index 0000000..55e84a1
--- /dev/null
+++ b/agent/tools/scheduler/task_store.py
@@ -0,0 +1,200 @@
+"""
+Task storage management for scheduler
+"""
+
+import json
+import os
+import threading
+from datetime import datetime
+from typing import Dict, List, Optional
+
+from common.log import logger
+
+
+class TaskStore:
+ """
+ Manages persistent storage of scheduled tasks
+ """
+
+ def __init__(self, store_path: str = None):
+ """
+ Initialize task store
+
+ Args:
+ store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
+ """
+ if store_path is None:
+ # Default to ~/cow/scheduler/tasks.json
+ home = os.path.expanduser("~")
+ store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
+
+ self.store_path = store_path
+ self.lock = threading.Lock()
+ self._ensure_store_dir()
+
+ def _ensure_store_dir(self):
+ """Ensure the storage directory exists"""
+ store_dir = os.path.dirname(self.store_path)
+ os.makedirs(store_dir, exist_ok=True)
+
+ def load_tasks(self) -> Dict[str, dict]:
+ """
+ Load all tasks from storage
+
+ Returns:
+ Dictionary of task_id -> task_data
+ """
+ with self.lock:
+ if not os.path.exists(self.store_path):
+ return {}
+
+ try:
+ with open(self.store_path, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ return data.get("tasks", {})
+ except Exception as e:
+ print(f"Error loading tasks: {e}")
+ return {}
+
+ def save_tasks(self, tasks: Dict[str, dict]):
+ """
+ Save all tasks to storage
+
+ Args:
+ tasks: Dictionary of task_id -> task_data
+ """
+ with self.lock:
+ try:
+ # Create backup
+ if os.path.exists(self.store_path):
+ backup_path = f"{self.store_path}.bak"
+ try:
+ with open(self.store_path, 'r') as src:
+ with open(backup_path, 'w') as dst:
+ dst.write(src.read())
+ except:
+ pass
+
+ # Save tasks
+ data = {
+ "version": 1,
+ "updated_at": datetime.now().isoformat(),
+ "tasks": tasks
+ }
+
+ with open(self.store_path, 'w', encoding='utf-8') as f:
+ json.dump(data, f, ensure_ascii=False, indent=2)
+ except Exception as e:
+ print(f"Error saving tasks: {e}")
+ raise
+
+ def add_task(self, task: dict) -> bool:
+ """
+ Add a new task
+
+ Args:
+ task: Task data dictionary
+
+ Returns:
+ True if successful
+ """
+ tasks = self.load_tasks()
+ task_id = task.get("id")
+
+ if not task_id:
+ raise ValueError("Task must have an 'id' field")
+
+ if task_id in tasks:
+ raise ValueError(f"Task with id '{task_id}' already exists")
+
+ tasks[task_id] = task
+ self.save_tasks(tasks)
+ return True
+
+ def update_task(self, task_id: str, updates: dict) -> bool:
+ """
+ Update an existing task
+
+ Args:
+ task_id: Task ID
+ updates: Dictionary of fields to update
+
+ Returns:
+ True if successful
+ """
+ tasks = self.load_tasks()
+
+ if task_id not in tasks:
+ raise ValueError(f"Task '{task_id}' not found")
+
+ # Update fields
+ tasks[task_id].update(updates)
+ tasks[task_id]["updated_at"] = datetime.now().isoformat()
+
+ self.save_tasks(tasks)
+ return True
+
+ def delete_task(self, task_id: str) -> bool:
+ """
+ Delete a task
+
+ Args:
+ task_id: Task ID
+
+ Returns:
+ True if successful
+ """
+ tasks = self.load_tasks()
+
+ if task_id not in tasks:
+ raise ValueError(f"Task '{task_id}' not found")
+
+ del tasks[task_id]
+ self.save_tasks(tasks)
+ return True
+
+ def get_task(self, task_id: str) -> Optional[dict]:
+ """
+ Get a specific task
+
+ Args:
+ task_id: Task ID
+
+ Returns:
+ Task data or None if not found
+ """
+ tasks = self.load_tasks()
+ return tasks.get(task_id)
+
+ def list_tasks(self, enabled_only: bool = False) -> List[dict]:
+ """
+ List all tasks
+
+ Args:
+ enabled_only: If True, only return enabled tasks
+
+ Returns:
+ List of task dictionaries
+ """
+ tasks = self.load_tasks()
+ task_list = list(tasks.values())
+
+ if enabled_only:
+ task_list = [t for t in task_list if t.get("enabled", True)]
+
+        # Sort by next_run_at (tasks without a next run time sort last)
+        task_list.sort(key=lambda t: t.get("next_run_at") or "9999-12-31T23:59:59")
+
+ return task_list
+
+ def enable_task(self, task_id: str, enabled: bool = True) -> bool:
+ """
+ Enable or disable a task
+
+ Args:
+ task_id: Task ID
+ enabled: True to enable, False to disable
+
+ Returns:
+ True if successful
+ """
+ return self.update_task(task_id, {"enabled": enabled})
diff --git a/agent/tools/send/__init__.py b/agent/tools/send/__init__.py
new file mode 100644
index 0000000..b76702a
--- /dev/null
+++ b/agent/tools/send/__init__.py
@@ -0,0 +1,3 @@
+from .send import Send
+
+__all__ = ['Send']
diff --git a/agent/tools/send/send.py b/agent/tools/send/send.py
new file mode 100644
index 0000000..a778b74
--- /dev/null
+++ b/agent/tools/send/send.py
@@ -0,0 +1,159 @@
+"""
+Send tool - Send files to the user
+"""
+
+import os
+from typing import Dict, Any
+from pathlib import Path
+
+from agent.tools.base_tool import BaseTool, ToolResult
+
+
+class Send(BaseTool):
+ """Tool for sending files to the user"""
+
+ name: str = "send"
+ description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
+
+ params: dict = {
+ "type": "object",
+ "properties": {
+ "path": {
+ "type": "string",
+ "description": "Path to the file to send. Can be absolute path or relative to workspace."
+ },
+ "message": {
+ "type": "string",
+ "description": "Optional message to accompany the file"
+ }
+ },
+ "required": ["path"]
+ }
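+    # Example (sketch): send a workspace file with an optional note
+    #   execute({"path": "reports/weekly.pdf", "message": "本周周报"})
+    #   -> ToolResult.success({"type": "file_to_send", "file_type": "document",
+    #                          "file_name": "weekly.pdf", "mime_type": "application/pdf", ...})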
+
+ def __init__(self, config: dict = None):
+ self.config = config or {}
+ self.cwd = self.config.get("cwd", os.getcwd())
+
+ # Supported file types
+ self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
+ self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
+ self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
+ self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}
+
+ def execute(self, args: Dict[str, Any]) -> ToolResult:
+ """
+ Execute file send operation
+
+ :param args: Contains file path and optional message
+ :return: File metadata for channel to send
+ """
+ path = args.get("path", "").strip()
+ message = args.get("message", "")
+
+ if not path:
+ return ToolResult.fail("Error: path parameter is required")
+
+ # Resolve path
+ absolute_path = self._resolve_path(path)
+
+ # Check if file exists
+ if not os.path.exists(absolute_path):
+ return ToolResult.fail(f"Error: File not found: {path}")
+
+ # Check if readable
+ if not os.access(absolute_path, os.R_OK):
+ return ToolResult.fail(f"Error: File is not readable: {path}")
+
+ # Get file info
+ file_ext = Path(absolute_path).suffix.lower()
+ file_size = os.path.getsize(absolute_path)
+ file_name = Path(absolute_path).name
+
+ # Determine file type
+ if file_ext in self.image_extensions:
+ file_type = "image"
+ mime_type = self._get_image_mime_type(file_ext)
+ elif file_ext in self.video_extensions:
+ file_type = "video"
+ mime_type = self._get_video_mime_type(file_ext)
+ elif file_ext in self.audio_extensions:
+ file_type = "audio"
+ mime_type = self._get_audio_mime_type(file_ext)
+ elif file_ext in self.document_extensions:
+ file_type = "document"
+ mime_type = self._get_document_mime_type(file_ext)
+ else:
+ file_type = "file"
+ mime_type = "application/octet-stream"
+
+ # Return file_to_send metadata
+ result = {
+ "type": "file_to_send",
+ "file_type": file_type,
+ "path": absolute_path,
+ "file_name": file_name,
+ "mime_type": mime_type,
+ "size": file_size,
+ "size_formatted": self._format_size(file_size),
+ "message": message or f"正在发送 {file_name}"
+ }
+
+ return ToolResult.success(result)
+
+ def _resolve_path(self, path: str) -> str:
+ """Resolve path to absolute path"""
+ path = os.path.expanduser(path)
+ if os.path.isabs(path):
+ return path
+ return os.path.abspath(os.path.join(self.cwd, path))
+
+ def _get_image_mime_type(self, ext: str) -> str:
+ """Get MIME type for image"""
+ mime_map = {
+ '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
+ '.png': 'image/png', '.gif': 'image/gif',
+ '.webp': 'image/webp', '.bmp': 'image/bmp',
+ '.svg': 'image/svg+xml', '.ico': 'image/x-icon'
+ }
+ return mime_map.get(ext, 'image/jpeg')
+
+ def _get_video_mime_type(self, ext: str) -> str:
+ """Get MIME type for video"""
+ mime_map = {
+ '.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
+ '.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
+ '.webm': 'video/webm', '.flv': 'video/x-flv'
+ }
+ return mime_map.get(ext, 'video/mp4')
+
+ def _get_audio_mime_type(self, ext: str) -> str:
+ """Get MIME type for audio"""
+ mime_map = {
+ '.mp3': 'audio/mpeg', '.wav': 'audio/wav',
+ '.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
+ '.flac': 'audio/flac', '.aac': 'audio/aac'
+ }
+ return mime_map.get(ext, 'audio/mpeg')
+
+ def _get_document_mime_type(self, ext: str) -> str:
+ """Get MIME type for document"""
+ mime_map = {
+ '.pdf': 'application/pdf',
+ '.doc': 'application/msword',
+ '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+ '.xls': 'application/vnd.ms-excel',
+ '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ '.ppt': 'application/vnd.ms-powerpoint',
+ '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+ '.txt': 'text/plain',
+ '.md': 'text/markdown'
+ }
+ return mime_map.get(ext, 'application/octet-stream')
+
+ def _format_size(self, size_bytes: int) -> str:
+ """Format file size in human-readable format"""
+ for unit in ['B', 'KB', 'MB', 'GB']:
+ if size_bytes < 1024.0:
+ return f"{size_bytes:.1f}{unit}"
+ size_bytes /= 1024.0
+ return f"{size_bytes:.1f}TB"
diff --git a/agent/tools/web_fetch/README.md b/agent/tools/web_fetch/README.md
deleted file mode 100644
index 6fc192f..0000000
--- a/agent/tools/web_fetch/README.md
+++ /dev/null
@@ -1,212 +0,0 @@
-# WebFetch Tool
-
-免费的网页抓取工具,无需 API Key,可直接抓取网页内容并提取可读文本。
-
-## 功能特性
-
-- ✅ **完全免费** - 无需任何 API Key
-- 🌐 **智能提取** - 自动提取网页主要内容
-- 📝 **格式转换** - 支持 HTML → Markdown/Text
-- 🚀 **高性能** - 内置请求重试和超时控制
-- 🎯 **智能降级** - 优先使用 Readability,可降级到基础提取
-
-## 安装依赖
-
-### 基础功能(必需)
-```bash
-pip install requests
-```
-
-### 增强功能(推荐)
-```bash
-# 安装 readability-lxml 以获得更好的内容提取效果
-pip install readability-lxml
-
-# 安装 html2text 以获得更好的 Markdown 转换
-pip install html2text
-```
-
-## 使用方法
-
-### 1. 在代码中使用
-
-```python
-from agent.tools.web_fetch import WebFetch
-
-# 创建工具实例
-tool = WebFetch()
-
-# 抓取网页(默认返回 Markdown 格式)
-result = tool.execute({
- "url": "https://example.com"
-})
-
-# 抓取并转换为纯文本
-result = tool.execute({
- "url": "https://example.com",
- "extract_mode": "text",
- "max_chars": 5000
-})
-
-if result.status == "success":
- data = result.result
- print(f"标题: {data['title']}")
- print(f"内容: {data['text']}")
-```
-
-### 2. 在 Agent 中使用
-
-工具会自动加载到 Agent 的工具列表中:
-
-```python
-from agent.tools import WebFetch
-
-tools = [
- WebFetch(),
- # ... 其他工具
-]
-
-agent = create_agent(tools=tools)
-```
-
-### 3. 通过 Skills 使用
-
-创建一个 skill 文件 `skills/web-fetch/SKILL.md`:
-
-```markdown
----
-name: web-fetch
-emoji: 🌐
-always: true
----
-
-# 网页内容获取
-
-使用 web_fetch 工具获取网页内容。
-
-## 使用场景
-
-- 需要读取某个网页的内容
-- 需要提取文章正文
-- 需要获取网页信息
-
-## 示例
-
-
-用户: 帮我看看 https://example.com 这个网页讲了什么
-助手: 调用 web_fetch 工具,参数:
-  url: https://example.com
-  extract_mode: markdown
-
-
-```
-
-## 参数说明
-
-| 参数 | 类型 | 必需 | 默认值 | 说明 |
-|------|------|------|--------|------|
-| `url` | string | ✅ | - | 要抓取的 URL(http/https) |
-| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown` 或 `text` |
-| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100) |
-
-## 返回结果
-
-```python
-{
- "url": "https://example.com", # 最终 URL(处理重定向后)
- "status": 200, # HTTP 状态码
- "content_type": "text/html", # 内容类型
- "title": "Example Domain", # 页面标题
- "extractor": "readability", # 提取器:readability/basic/raw
- "extract_mode": "markdown", # 提取模式
- "text": "# Example Domain\n\n...", # 提取的文本内容
- "length": 1234, # 文本长度
- "truncated": false, # 是否被截断
- "warning": "..." # 警告信息(如果有)
-}
-```
-
-## 与其他搜索工具的对比
-
-| 工具 | 需要 API Key | 功能 | 成本 |
-|------|-------------|------|------|
-| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 |
-| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 |
-| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 |
-| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 |
-| `google_search` | ✅ 需要 | Google 搜索 API | 付费 |
-
-## 技术细节
-
-### 内容提取策略
-
-1. **Readability 模式**(推荐)
- - 使用 Mozilla 的 Readability 算法
- - 自动识别文章主体内容
- - 过滤广告、导航栏等噪音
-
-2. **Basic 模式**(降级)
- - 简单的 HTML 标签清理
- - 正则表达式提取文本
- - 适用于简单页面
-
-3. **Raw 模式**
- - 用于非 HTML 内容
- - 直接返回原始内容
-
-### 错误处理
-
-工具会自动处理以下情况:
-- ✅ HTTP 重定向(最多 3 次)
-- ✅ 请求超时(默认 30 秒)
-- ✅ 网络错误自动重试
-- ✅ 内容提取失败降级
-
-## 测试
-
-运行测试脚本:
-
-```bash
-cd agent/tools/web_fetch
-python test_web_fetch.py
-```
-
-## 配置选项
-
-在创建工具时可以传入配置:
-
-```python
-tool = WebFetch(config={
- "timeout": 30, # 请求超时时间(秒)
- "max_redirects": 3, # 最大重定向次数
- "user_agent": "..." # 自定义 User-Agent
-})
-```
-
-## 常见问题
-
-### Q: 为什么推荐安装 readability-lxml?
-
-A: readability-lxml 提供更好的内容提取质量,能够:
-- 自动识别文章主体
-- 过滤广告和导航栏
-- 保留文章结构
-
-没有它也能工作,但提取质量会下降。
-
-### Q: 与 clawdbot 的 web_fetch 有什么区别?
-
-A: 本实现参考了 clawdbot 的设计,主要区别:
-- Python 实现(clawdbot 是 TypeScript)
-- 简化了一些高级特性(如 Firecrawl 集成)
-- 保留了核心的免费功能
-- 更容易集成到现有项目
-
-### Q: 可以抓取需要登录的页面吗?
-
-A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。
-
-## 参考
-
-- [Mozilla Readability](https://github.com/mozilla/readability)
-- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
diff --git a/agent/tools/web_fetch/__init__.py b/agent/tools/web_fetch/__init__.py
deleted file mode 100644
index 545f87c..0000000
--- a/agent/tools/web_fetch/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .web_fetch import WebFetch
-
-__all__ = ['WebFetch']
diff --git a/agent/tools/web_fetch/install_deps.sh b/agent/tools/web_fetch/install_deps.sh
deleted file mode 100644
index 3c2a553..0000000
--- a/agent/tools/web_fetch/install_deps.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-
-# WebFetch 工具依赖安装脚本
-
-echo "=================================="
-echo "WebFetch 工具依赖安装"
-echo "=================================="
-echo ""
-
-# 检查 Python 版本
-python_version=$(python3 --version 2>&1 | awk '{print $2}')
-echo "✓ Python 版本: $python_version"
-echo ""
-
-# 安装基础依赖
-echo "📦 安装基础依赖..."
-python3 -m pip install requests
-
-# 检查是否成功
-if [ $? -eq 0 ]; then
- echo "✅ requests 安装成功"
-else
- echo "❌ requests 安装失败"
- exit 1
-fi
-
-echo ""
-
-# 安装推荐依赖
-echo "📦 安装推荐依赖(提升内容提取质量)..."
-python3 -m pip install readability-lxml html2text
-
-# 检查是否成功
-if [ $? -eq 0 ]; then
- echo "✅ readability-lxml 和 html2text 安装成功"
-else
- echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
-fi
-
-echo ""
-echo "=================================="
-echo "安装完成!"
-echo "=================================="
-echo ""
-echo "运行测试:"
-echo " python3 agent/tools/web_fetch/test_web_fetch.py"
-echo ""
diff --git a/agent/tools/web_fetch/web_fetch.py b/agent/tools/web_fetch/web_fetch.py
deleted file mode 100644
index b87b95e..0000000
--- a/agent/tools/web_fetch/web_fetch.py
+++ /dev/null
@@ -1,365 +0,0 @@
-"""
-Web Fetch tool - Fetch and extract readable content from URLs
-Supports HTML to Markdown/Text conversion using Mozilla's Readability
-"""
-
-import os
-import re
-from typing import Dict, Any, Optional
-from urllib.parse import urlparse
-import requests
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
-
-from agent.tools.base_tool import BaseTool, ToolResult
-from common.log import logger
-
-
-class WebFetch(BaseTool):
- """Tool for fetching and extracting readable content from web pages"""
-
- name: str = "web_fetch"
- description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."
-
- params: dict = {
- "type": "object",
- "properties": {
- "url": {
- "type": "string",
- "description": "HTTP or HTTPS URL to fetch"
- },
- "extract_mode": {
- "type": "string",
- "description": "Extraction mode: 'markdown' (default) or 'text'",
- "enum": ["markdown", "text"],
- "default": "markdown"
- },
- "max_chars": {
- "type": "integer",
- "description": "Maximum characters to return (default: 50000)",
- "minimum": 100,
- "default": 50000
- }
- },
- "required": ["url"]
- }
-
- def __init__(self, config: dict = None):
- self.config = config or {}
- self.timeout = self.config.get("timeout", 20)
- self.max_redirects = self.config.get("max_redirects", 3)
- self.user_agent = self.config.get(
- "user_agent",
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
- )
-
- # Setup session with retry strategy
- self.session = self._create_session()
-
- # Check if readability-lxml is available
- self.readability_available = self._check_readability()
-
- def _create_session(self) -> requests.Session:
- """Create a requests session with retry strategy"""
- session = requests.Session()
-
- # Retry strategy - handles failed requests, not redirects
- retry_strategy = Retry(
- total=3,
- backoff_factor=1,
- status_forcelist=[429, 500, 502, 503, 504],
- allowed_methods=["GET", "HEAD"]
- )
-
- # HTTPAdapter handles retries; requests handles redirects via allow_redirects
- adapter = HTTPAdapter(max_retries=retry_strategy)
- session.mount("http://", adapter)
- session.mount("https://", adapter)
-
- # Set max redirects on session
- session.max_redirects = self.max_redirects
-
- return session
-
- def _check_readability(self) -> bool:
- """Check if readability-lxml is available"""
- try:
- from readability import Document
- return True
- except ImportError:
- logger.warning(
- "readability-lxml not installed. Install with: pip install readability-lxml\n"
- "Falling back to basic HTML extraction."
- )
- return False
-
- def execute(self, args: Dict[str, Any]) -> ToolResult:
- """
- Execute web fetch operation
-
- :param args: Contains url, extract_mode, and max_chars parameters
- :return: Extracted content or error message
- """
- url = args.get("url", "").strip()
- extract_mode = args.get("extract_mode", "markdown").lower()
- max_chars = args.get("max_chars", 50000)
-
- if not url:
- return ToolResult.fail("Error: url parameter is required")
-
- # Validate URL
- if not self._is_valid_url(url):
- return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")
-
- # Validate extract_mode
- if extract_mode not in ["markdown", "text"]:
- extract_mode = "markdown"
-
- # Validate max_chars
- if not isinstance(max_chars, int) or max_chars < 100:
- max_chars = 50000
-
- try:
- # Fetch the URL
- response = self._fetch_url(url)
-
- # Extract content
- result = self._extract_content(
- html=response.text,
- url=response.url,
- status_code=response.status_code,
- content_type=response.headers.get("content-type", ""),
- extract_mode=extract_mode,
- max_chars=max_chars
- )
-
- return ToolResult.success(result)
-
- except requests.exceptions.Timeout:
- return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
- except requests.exceptions.TooManyRedirects:
- return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
- except requests.exceptions.RequestException as e:
- return ToolResult.fail(f"Error fetching URL: {str(e)}")
- except Exception as e:
- logger.error(f"Web fetch error: {e}", exc_info=True)
- return ToolResult.fail(f"Error: {str(e)}")
-
- def _is_valid_url(self, url: str) -> bool:
- """Validate URL format"""
- try:
- result = urlparse(url)
- return result.scheme in ["http", "https"] and bool(result.netloc)
- except Exception:
- return False
-
- def _fetch_url(self, url: str) -> requests.Response:
- """
- Fetch URL with proper headers and error handling
-
- :param url: URL to fetch
- :return: Response object
- """
- headers = {
- "User-Agent": self.user_agent,
- "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
- "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
- "Accept-Encoding": "gzip, deflate",
- "Connection": "keep-alive",
- }
-
- # Note: requests library handles redirects automatically
- # The max_redirects is set in the session's adapter (HTTPAdapter)
- response = self.session.get(
- url,
- headers=headers,
- timeout=self.timeout,
- allow_redirects=True
- )
-
- response.raise_for_status()
- return response
-
- def _extract_content(
- self,
- html: str,
- url: str,
- status_code: int,
- content_type: str,
- extract_mode: str,
- max_chars: int
- ) -> Dict[str, Any]:
- """
- Extract readable content from HTML
-
- :param html: HTML content
- :param url: Original URL
- :param status_code: HTTP status code
- :param content_type: Content type header
- :param extract_mode: 'markdown' or 'text'
- :param max_chars: Maximum characters to return
- :return: Extracted content and metadata
- """
- # Check content type
- if "text/html" not in content_type.lower():
- # Non-HTML content
- text = html[:max_chars]
- truncated = len(html) > max_chars
-
- return {
- "url": url,
- "status": status_code,
- "content_type": content_type,
- "extractor": "raw",
- "text": text,
- "length": len(text),
- "truncated": truncated,
- "message": f"Non-HTML content (type: {content_type})"
- }
-
- # Extract readable content from HTML
- if self.readability_available:
- return self._extract_with_readability(
- html, url, status_code, content_type, extract_mode, max_chars
- )
- else:
- return self._extract_basic(
- html, url, status_code, content_type, extract_mode, max_chars
- )
-
- def _extract_with_readability(
- self,
- html: str,
- url: str,
- status_code: int,
- content_type: str,
- extract_mode: str,
- max_chars: int
- ) -> Dict[str, Any]:
- """Extract content using Mozilla's Readability"""
- try:
- from readability import Document
-
- # Parse with Readability
- doc = Document(html)
- title = doc.title()
- content_html = doc.summary()
-
- # Convert to markdown or text
- if extract_mode == "markdown":
- text = self._html_to_markdown(content_html)
- else:
- text = self._html_to_text(content_html)
-
- # Truncate if needed
- truncated = len(text) > max_chars
- if truncated:
- text = text[:max_chars]
-
- return {
- "url": url,
- "status": status_code,
- "content_type": content_type,
- "title": title,
- "extractor": "readability",
- "extract_mode": extract_mode,
- "text": text,
- "length": len(text),
- "truncated": truncated
- }
-
- except Exception as e:
- logger.warning(f"Readability extraction failed: {e}")
- # Fallback to basic extraction
- return self._extract_basic(
- html, url, status_code, content_type, extract_mode, max_chars
- )
-
- def _extract_basic(
- self,
- html: str,
- url: str,
- status_code: int,
- content_type: str,
- extract_mode: str,
- max_chars: int
- ) -> Dict[str, Any]:
- """Basic HTML extraction without Readability"""
- # Extract title
-        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
- title = title_match.group(1).strip() if title_match else "Untitled"
-
- # Remove script and style tags
-        text = re.sub(r'<script.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
-        text = re.sub(r'<style.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
-
- # Remove HTML tags
- text = re.sub(r'<[^>]+>', ' ', text)
-
- # Clean up whitespace
- text = re.sub(r'\s+', ' ', text)
- text = text.strip()
-
- # Truncate if needed
- truncated = len(text) > max_chars
- if truncated:
- text = text[:max_chars]
-
- return {
- "url": url,
- "status": status_code,
- "content_type": content_type,
- "title": title,
- "extractor": "basic",
- "extract_mode": extract_mode,
- "text": text,
- "length": len(text),
- "truncated": truncated,
- "warning": "Using basic extraction. Install readability-lxml for better results."
- }
-
- def _html_to_markdown(self, html: str) -> str:
- """Convert HTML to Markdown (basic implementation)"""
- try:
- # Try to use html2text if available
- import html2text
- h = html2text.HTML2Text()
- h.ignore_links = False
- h.ignore_images = False
- h.body_width = 0 # Don't wrap lines
- return h.handle(html)
- except ImportError:
- # Fallback to basic conversion
- return self._html_to_text(html)
-
- def _html_to_text(self, html: str) -> str:
- """Convert HTML to plain text"""
- # Remove script and style tags
-        text = re.sub(r'<script.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
-        text = re.sub(r'<style.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
-
- # Convert common tags to text equivalents
-        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
-        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
-        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
-        text = re.sub(r'<div[^>]*>', '\n\n', text, flags=re.IGNORECASE)
-        text = re.sub(r'</div>', '\n', text, flags=re.IGNORECASE)
-
- # Remove all other HTML tags
- text = re.sub(r'<[^>]+>', '', text)
-
- # Decode HTML entities
- import html
- text = html.unescape(text)
-
- # Clean up whitespace
- text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
- text = re.sub(r' +', ' ', text)
- text = text.strip()
-
- return text
-
- def close(self):
- """Close the session"""
- if hasattr(self, 'session'):
- self.session.close()
diff --git a/agent/tools/write/write.py b/agent/tools/write/write.py
index 9836564..49e01c8 100644
--- a/agent/tools/write/write.py
+++ b/agent/tools/write/write.py
@@ -14,7 +14,7 @@ class Write(BaseTool):
"""Tool for writing file content"""
name: str = "write"
- description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."
+ description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."
params: dict = {
"type": "object",
diff --git a/app.py b/app.py
index 022fdd7..ad0cfca 100644
--- a/app.py
+++ b/app.py
@@ -59,6 +59,23 @@ def run():
os.environ["WECHATY_LOG"] = "warn"
start_channel(channel_name)
+
+ # 打印系统运行成功信息
+ logger.info("")
+ logger.info("=" * 50)
+ if conf().get("agent", False):
+ logger.info("✅ System started successfully!")
+ logger.info("🐮 Cow Agent is running")
+ logger.info(f" Channel: {channel_name}")
+ logger.info(f" Model: {conf().get('model', 'unknown')}")
+ logger.info(f" Workspace: {conf().get('agent_workspace', '~/cow')}")
+ else:
+ logger.info("✅ System started successfully!")
+ logger.info("🤖 ChatBot is running")
+ logger.info(f" Channel: {channel_name}")
+ logger.info(f" Model: {conf().get('model', 'unknown')}")
+ logger.info("=" * 50)
+ logger.info("")
while True:
time.sleep(1)
diff --git a/bridge/agent_bridge.py b/bridge/agent_bridge.py
index 34535bb..43aefc7 100644
--- a/bridge/agent_bridge.py
+++ b/bridge/agent_bridge.py
@@ -2,10 +2,11 @@
Agent Bridge - Integrates Agent system with existing COW bridge
"""
+import os
from typing import Optional, List
from agent.protocol import Agent, LLMModel, LLMRequest
-from bot.openai_compatible_bot import OpenAICompatibleBot
+from models.openai_compatible_bot import OpenAICompatibleBot
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
@@ -180,6 +181,7 @@ class AgentBridge:
self.agents = {} # session_id -> Agent instance mapping
self.default_agent = None # For backward compatibility (no session_id)
self.agent: Optional[Agent] = None
+ self.scheduler_initialized = False
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
"""
Create the super agent with COW integration
@@ -228,12 +230,7 @@ class AgentBridge:
# Log skill loading details
if agent.skill_manager:
- logger.info(f"[AgentBridge] SkillManager initialized:")
- logger.info(f"[AgentBridge] - Managed dir: {agent.skill_manager.managed_skills_dir}")
- logger.info(f"[AgentBridge] - Workspace dir: {agent.skill_manager.workspace_dir}")
- logger.info(f"[AgentBridge] - Total skills: {len(agent.skill_manager.skills)}")
- for skill_name in agent.skill_manager.skills.keys():
- logger.info(f"[AgentBridge] * {skill_name}")
+ logger.debug(f"[AgentBridge] SkillManager initialized with {len(agent.skill_manager.skills)} skills")
return agent
@@ -268,6 +265,21 @@ class AgentBridge:
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
+ # Migrate API keys from config.json to environment variables (if not already set)
+ self._migrate_config_to_env(workspace_root)
+
+ # Load environment variables from secure .env file location
+ env_file = os.path.expanduser("~/.cow/.env")
+ if os.path.exists(env_file):
+ try:
+ from dotenv import load_dotenv
+ load_dotenv(env_file, override=True)
+ logger.info(f"[AgentBridge] Loaded environment variables from {env_file}")
+ except ImportError:
+ logger.warning("[AgentBridge] python-dotenv not installed, skipping .env file loading")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to load .env file: {e}")
+
# Initialize workspace and create template files
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
@@ -283,14 +295,53 @@ class AgentBridge:
from agent.memory import MemoryManager, MemoryConfig
from agent.tools import MemorySearchTool, MemoryGetTool
- memory_config = MemoryConfig(
- workspace_root=workspace_root,
- embedding_provider="local", # Use local embedding (no API key needed)
- embedding_model="all-MiniLM-L6-v2"
- )
+ # 从 config.json 读取 OpenAI 配置
+ openai_api_key = conf().get("open_ai_api_key", "")
+ openai_api_base = conf().get("open_ai_api_base", "")
- # Create memory manager with the config
- memory_manager = MemoryManager(memory_config)
+ # 尝试初始化 OpenAI embedding provider
+ embedding_provider = None
+ if openai_api_key:
+ try:
+ from agent.memory import create_embedding_provider
+ embedding_provider = create_embedding_provider(
+ provider="openai",
+ model="text-embedding-3-small",
+ api_key=openai_api_key,
+ api_base=openai_api_base or "https://api.openai.com/v1"
+ )
+ logger.info(f"[AgentBridge] OpenAI embedding initialized")
+ except Exception as embed_error:
+ logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}")
+ logger.info(f"[AgentBridge] Using keyword-only search")
+ else:
+ logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search")
+
+ # 创建 memory config
+ memory_config = MemoryConfig(workspace_root=workspace_root)
+
+ # 创建 memory manager
+ memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
+
+ # 初始化时执行一次 sync,确保数据库有数据
+ import asyncio
+ try:
+ # 尝试在当前事件循环中执行
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # 如果事件循环正在运行,创建任务
+ asyncio.create_task(memory_manager.sync())
+ logger.info("[AgentBridge] Memory sync scheduled")
+ else:
+ # 如果没有运行的循环,直接执行
+ loop.run_until_complete(memory_manager.sync())
+ logger.info("[AgentBridge] Memory synced successfully")
+ except RuntimeError:
+ # 没有事件循环,创建新的
+ asyncio.run(memory_manager.sync())
+ logger.info("[AgentBridge] Memory synced successfully")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Memory sync failed: {e}")
# Create memory tools
memory_tools = [
@@ -318,7 +369,15 @@ class AgentBridge:
for tool_name in tool_manager.tool_classes.keys():
try:
- tool = tool_manager.create_tool(tool_name)
+ # Special handling for EnvConfig tool - pass agent_bridge reference
+ if tool_name == "env_config":
+ from agent.tools import EnvConfig
+ tool = EnvConfig({
+ "agent_bridge": self # Pass self reference for hot reload
+ })
+ else:
+ tool = tool_manager.create_tool(tool_name)
+
if tool:
# Apply workspace config to file operation tools
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
@@ -326,12 +385,6 @@ class AgentBridge:
tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
- # Apply API key for bocha_search tool
- elif tool_name == 'bocha_search':
- bocha_api_key = conf().get("bocha_api_key", "")
- if bocha_api_key:
- tool.config = {"bocha_api_key": bocha_api_key}
- tool.api_key = bocha_api_key
tools.append(tool)
logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
except Exception as e:
@@ -342,6 +395,36 @@ class AgentBridge:
tools.extend(memory_tools)
logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")
+ # Initialize scheduler service (once)
+ if not self.scheduler_initialized:
+ try:
+ from agent.tools.scheduler.integration import init_scheduler
+ if init_scheduler(self):
+ self.scheduler_initialized = True
+ logger.info("[AgentBridge] Scheduler service initialized")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to initialize scheduler: {e}")
+
+ # Inject scheduler dependencies into SchedulerTool instances
+ if self.scheduler_initialized:
+ try:
+ from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
+ from agent.tools import SchedulerTool
+
+ task_store = get_task_store()
+ scheduler_service = get_scheduler_service()
+
+ for tool in tools:
+ if isinstance(tool, SchedulerTool):
+ tool.task_store = task_store
+ tool.scheduler_service = scheduler_service
+ if not tool.config:
+ tool.config = {}
+ tool.config["channel_type"] = conf().get("channel_type", "unknown")
+ logger.debug("[AgentBridge] Injected scheduler dependencies into SchedulerTool")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies: {e}")
+
logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")
# Load context files (SOUL.md, USER.md, etc.)
@@ -381,14 +464,19 @@ class AgentBridge:
logger.info("[AgentBridge] System prompt built successfully")
+ # Get cost control parameters from config
+ max_steps = conf().get("agent_max_steps", 20)
+ max_context_tokens = conf().get("agent_max_context_tokens", 50000)
+
# Create agent with configured tools and workspace
agent = self.create_agent(
system_prompt=system_prompt,
tools=tools,
- max_steps=50,
+ max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root, # Pass workspace to agent for skills loading
- enable_skills=True # Enable skills auto-loading
+ enable_skills=True, # Enable skills auto-loading
+ max_context_tokens=max_context_tokens
)
# Attach memory manager to agent if available
@@ -410,6 +498,24 @@ class AgentBridge:
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
+ # Migrate API keys from config.json to environment variables (if not already set)
+ self._migrate_config_to_env(workspace_root)
+
+ # Load environment variables from secure .env file location
+ env_file = os.path.expanduser("~/.cow/.env")
+ if os.path.exists(env_file):
+ try:
+ from dotenv import load_dotenv
+ load_dotenv(env_file, override=True)
+ logger.debug(f"[AgentBridge] Loaded environment variables from {env_file} for session {session_id}")
+ except ImportError:
+ logger.warning(f"[AgentBridge] python-dotenv not installed, skipping .env file loading for session {session_id}")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to load .env file for session {session_id}: {e}")
+
# Initialize workspace
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
@@ -420,23 +526,65 @@ class AgentBridge:
memory_tools = []
try:
- from agent.memory import MemoryManager, MemoryConfig
+ from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
from agent.tools import MemorySearchTool, MemoryGetTool
- memory_config = MemoryConfig(
- workspace_root=workspace_root,
- embedding_provider="local",
- embedding_model="all-MiniLM-L6-v2"
- )
+ # Read the OpenAI credentials from config.json
+ openai_api_key = conf().get("open_ai_api_key", "")
+ openai_api_base = conf().get("open_ai_api_base", "")
+
+ # Try to initialize the OpenAI embedding provider
+ embedding_provider = None
+ if openai_api_key:
+ try:
+ embedding_provider = create_embedding_provider(
+ provider="openai",
+ model="text-embedding-3-small",
+ api_key=openai_api_key,
+ api_base=openai_api_base or "https://api.openai.com/v1"
+ )
+ logger.debug(f"[AgentBridge] OpenAI embedding initialized for session {session_id}")
+ except Exception as embed_error:
+ logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}")
+ logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}")
+ else:
+ logger.debug(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}")
+
+ # Create the memory config
+ memory_config = MemoryConfig(workspace_root=workspace_root)
+
+ # Create the memory manager
+ memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
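+ # embedding_provider may be None here; in that case memory search falls back to keyword-only matching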
+
+ # Run one sync at initialization so the database is populated
+ import asyncio
+ try:
+ # Try to run on the current event loop
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # If the event loop is already running, schedule a task on it
+ asyncio.create_task(memory_manager.sync())
+ logger.debug(f"[AgentBridge] Memory sync scheduled for session {session_id}")
+ else:
+ # If no loop is running, execute synchronously
+ loop.run_until_complete(memory_manager.sync())
+ logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
+ except RuntimeError:
+ # No event loop at all, create a new one
+ asyncio.run(memory_manager.sync())
+ logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
+ except Exception as sync_error:
+ logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}")
- memory_manager = MemoryManager(memory_config)
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
except Exception as e:
- logger.debug(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
+ logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
+ import traceback
+ logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}")
# Load tools
from agent.tools import ToolManager
@@ -458,17 +606,42 @@ class AgentBridge:
tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
- elif tool_name == 'bocha_search':
- bocha_api_key = conf().get("bocha_api_key", "")
- if bocha_api_key:
- tool.config = {"bocha_api_key": bocha_api_key}
- tool.api_key = bocha_api_key
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name} for session {session_id}: {e}")
if memory_tools:
tools.extend(memory_tools)
+
+ # Initialize scheduler service (once, if not already initialized)
+ if not self.scheduler_initialized:
+ try:
+ from agent.tools.scheduler.integration import init_scheduler
+ if init_scheduler(self):
+ self.scheduler_initialized = True
+ logger.debug(f"[AgentBridge] Scheduler service initialized for session {session_id}")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to initialize scheduler for session {session_id}: {e}")
+
+ # Inject scheduler dependencies into SchedulerTool instances
+ if self.scheduler_initialized:
+ try:
+ from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
+ from agent.tools import SchedulerTool
+
+ task_store = get_task_store()
+ scheduler_service = get_scheduler_service()
+
+ for tool in tools:
+ if isinstance(tool, SchedulerTool):
+ tool.task_store = task_store
+ tool.scheduler_service = scheduler_service
+ if not tool.config:
+ tool.config = {}
+ tool.config["channel_type"] = conf().get("channel_type", "unknown")
+ logger.debug(f"[AgentBridge] Injected scheduler dependencies for session {session_id}")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies for session {session_id}: {e}")
# Load context files
context_files = load_context_files(workspace_root)
@@ -478,7 +651,7 @@ class AgentBridge:
try:
from agent.skills import SkillManager
skill_manager = SkillManager(workspace_dir=workspace_root)
- logger.info(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}")
+ logger.debug(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to initialize SkillManager for session {session_id}: {e}")
@@ -543,15 +716,20 @@ class AgentBridge:
if is_first:
mark_conversation_started(workspace_root)
+ # Get cost control parameters from config
+ max_steps = conf().get("agent_max_steps", 20)
+ max_context_tokens = conf().get("agent_max_context_tokens", 50000)
+
# Create agent for this session
agent = self.create_agent(
system_prompt=system_prompt,
tools=tools,
- max_steps=50,
+ max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root,
skill_manager=skill_manager,
- enable_skills=True
+ enable_skills=True,
+ max_context_tokens=max_context_tokens
)
if memory_manager:
@@ -586,12 +764,52 @@ class AgentBridge:
if not agent:
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
- # Use agent's run_stream method
- response = agent.run_stream(
- user_message=query,
- on_event=on_event,
- clear_history=clear_history
- )
+ # Filter tools based on context
+ original_tools = agent.tools
+ filtered_tools = original_tools
+
+ # If this is a scheduled task execution, exclude scheduler tool to prevent recursion
+ if context and context.get("is_scheduled_task"):
+ filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
+ agent.tools = filtered_tools
+ logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
+ else:
+ # Attach context to scheduler tool if present
+ if context and agent.tools:
+ for tool in agent.tools:
+ if tool.name == "scheduler":
+ try:
+ from agent.tools.scheduler.integration import attach_scheduler_to_tool
+ attach_scheduler_to_tool(tool, context)
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
+ break
+
+ try:
+ # Use agent's run_stream method
+ response = agent.run_stream(
+ user_message=query,
+ on_event=on_event,
+ clear_history=clear_history
+ )
+ finally:
+ # Restore original tools
+ if context and context.get("is_scheduled_task"):
+ agent.tools = original_tools
+
+ # Check if there are files to send (from read tool)
+ if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
+ files_to_send = agent.stream_executor.files_to_send
+ if files_to_send:
+ # Send the first file (for now, handle one file at a time)
+ file_info = files_to_send[0]
+ logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")
+
+ # Clear files_to_send for next request
+ agent.stream_executor.files_to_send = []
+
+ # Return file reply based on file type
+ return self._create_file_reply(file_info, response, context)
return Reply(ReplyType.TEXT, response)
@@ -599,6 +817,120 @@ class AgentBridge:
logger.error(f"Agent reply error: {e}")
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
+ def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
+ """
+ Create a reply for sending files
+
+ Args:
+ file_info: File metadata from read tool
+ text_response: Text response from agent
+ context: Context object
+
+ Returns:
+ Reply object for file sending
+ """
+ file_type = file_info.get("file_type", "file")
+ file_path = file_info.get("path")
+
+ # For images, use IMAGE_URL type (channel will handle upload)
+ if file_type == "image":
+ # Convert local path to file:// URL for channel processing
+ file_url = f"file://{file_path}"
+ logger.info(f"[AgentBridge] Sending image: {file_url}")
+ reply = Reply(ReplyType.IMAGE_URL, file_url)
+ # Attach text message if present (for channels that support text+image)
+ if text_response:
+ reply.text_content = text_response # Store accompanying text
+ return reply
+
+ # For documents (PDF, Excel, Word, PPT), use FILE type
+ if file_type == "document":
+ file_url = f"file://{file_path}"
+ logger.info(f"[AgentBridge] Sending document: {file_url}")
+ reply = Reply(ReplyType.FILE, file_url)
+ reply.file_name = file_info.get("file_name", os.path.basename(file_path))
+ return reply
+
+ # For other files (video, audio), we need channel-specific handling
+ # For now, return text with file info
+ # TODO: Implement video/audio sending when channel supports it
+ message = text_response or file_info.get("message", "文件已准备")
+ message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
+ return Reply(ReplyType.TEXT, message)
+
+ def _migrate_config_to_env(self, workspace_root: str):
+ """
+ Migrate API keys from config.json to .env file if not already set
+
+ Args:
+ workspace_root: Workspace directory path (not used, kept for compatibility)
+ """
+ from config import conf
+ import os
+
+ # Mapping from config.json keys to environment variable names
+ key_mapping = {
+ "open_ai_api_key": "OPENAI_API_KEY",
+ "open_ai_api_base": "OPENAI_API_BASE",
+ "gemini_api_key": "GEMINI_API_KEY",
+ "claude_api_key": "CLAUDE_API_KEY",
+ "linkai_api_key": "LINKAI_API_KEY",
+ }
+
+ # Use fixed secure location for .env file
+ env_file = os.path.expanduser("~/.cow/.env")
+
+ # Read existing env vars from .env file
+ existing_env_vars = {}
+ if os.path.exists(env_file):
+ try:
+ with open(env_file, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line and not line.startswith('#') and '=' in line:
+ key, _ = line.split('=', 1)
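+ # Only the key name matters here; the existing value is never read or overwritten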
+ existing_env_vars[key.strip()] = True
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
+
+ # Check which keys need to be migrated
+ keys_to_migrate = {}
+ for config_key, env_key in key_mapping.items():
+ # Skip if already in .env file
+ if env_key in existing_env_vars:
+ continue
+
+ # Get value from config.json
+ value = conf().get(config_key, "")
+ if value and value.strip(): # Only migrate non-empty values
+ keys_to_migrate[env_key] = value.strip()
+
+ # Log summary if there are keys to skip
+ if existing_env_vars:
+ logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
+
+ # Write new keys to .env file
+ if keys_to_migrate:
+ try:
+ # Ensure ~/.cow directory and .env file exist
+ env_dir = os.path.dirname(env_file)
+ if not os.path.exists(env_dir):
+ os.makedirs(env_dir, exist_ok=True)
+ if not os.path.exists(env_file):
+ open(env_file, 'a').close()
+
+ # Append new keys
+ with open(env_file, 'a', encoding='utf-8') as f:
+ f.write('\n# Auto-migrated from config.json\n')
+ for key, value in keys_to_migrate.items():
+ f.write(f'{key}={value}\n')
+ # Also set in current process
+ os.environ[key] = value
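+ # The in-process update makes migrated keys visible immediately, without re-reading the .env file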
+
+ logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
+ except Exception as e:
+ logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
+
def clear_session(self, session_id: str):
"""
Clear a specific session's agent and conversation history
@@ -614,4 +946,43 @@ class AgentBridge:
"""Clear all agent sessions"""
logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
self.agents.clear()
- self.default_agent = None
\ No newline at end of file
+ self.default_agent = None
+
+ def refresh_all_skills(self) -> int:
+ """
+ Refresh skills in all agent instances after environment variable changes.
+ This allows hot-reload of skills without restarting the agent.
+
+ Returns:
+ Number of agent instances refreshed
+ """
+ import os
+ from dotenv import load_dotenv
+
+ # Reload environment variables from the fixed secure location, the same file
+ # that _migrate_config_to_env writes to
+ env_file = os.path.expanduser("~/.cow/.env")
+
+ if os.path.exists(env_file):
+ load_dotenv(env_file, override=True)
+ logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
+
+ refreshed_count = 0
+
+ # Refresh default agent
+ if self.default_agent and hasattr(self.default_agent, 'skill_manager'):
+ self.default_agent.skill_manager.refresh_skills()
+ refreshed_count += 1
+ logger.info("[AgentBridge] Refreshed skills in default agent")
+
+ # Refresh all session agents
+ for session_id, agent in self.agents.items():
+ if hasattr(agent, 'skill_manager'):
+ agent.skill_manager.refresh_skills()
+ refreshed_count += 1
+
+ if refreshed_count > 0:
+ logger.info(f"[AgentBridge] Refreshed skills in {refreshed_count} agent instance(s)")
+
+ return refreshed_count
\ No newline at end of file
diff --git a/bridge/bridge.py b/bridge/bridge.py
index a7b93c4..4c686f9 100644
--- a/bridge/bridge.py
+++ b/bridge/bridge.py
@@ -1,4 +1,4 @@
-from bot.bot_factory import create_bot
+from models.bot_factory import create_bot
from bridge.context import Context
from bridge.reply import Reply
from common import const
diff --git a/channel/chat_channel.py b/channel/chat_channel.py
index 1523f67..af3607d 100644
--- a/channel/chat_channel.py
+++ b/channel/chat_channel.py
@@ -64,15 +64,22 @@ class ChatChannel(Channel):
check_contain(group_name, group_name_keyword_white_list),
]
):
- group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
- session_id = cmsg.actual_user_id
- if any(
- [
- group_name in group_chat_in_one_session,
- "ALL_GROUP" in group_chat_in_one_session,
- ]
- ):
+ # Check global group_shared_session config first
+ group_shared_session = conf().get("group_shared_session", True)
+ if group_shared_session:
+ # All users in the group share the same session
session_id = group_id
+ else:
+ # Check group-specific whitelist (legacy behavior)
+ group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
+ session_id = cmsg.actual_user_id
+ if any(
+ [
+ group_name in group_chat_in_one_session,
+ "ALL_GROUP" in group_chat_in_one_session,
+ ]
+ ):
+ session_id = group_id
else:
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
return None
@@ -166,11 +173,11 @@ class ChatChannel(Channel):
def _handle(self, context: Context):
if context is None or not context.content:
return
- logger.debug("[chat_channel] ready to handle context: {}".format(context))
+ logger.debug("[chat_channel] handling context: {}".format(context))
# reply的构建步骤
reply = self._generate_reply(context)
- logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
+ logger.debug("[chat_channel] decorating reply: {}".format(reply))
# reply的包装步骤
if reply and reply.content:
@@ -188,7 +195,7 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass():
- logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
+ logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
context["channel"] = e_context["channel"]
reply = super().build_reply_content(context.content, context)
@@ -282,7 +289,100 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
- logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
+ logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
+
+ # For text replies, try to extract and send any embedded images
+ if reply.type == ReplyType.TEXT:
+ self._extract_and_send_images(reply, context)
+ # For image replies that carry text content, send the text first, then the image
+ elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
+ # Send the text first
+ text_reply = Reply(ReplyType.TEXT, reply.text_content)
+ self._send(text_reply, context)
+ # Send the image after a short delay
+ time.sleep(0.3)
+ self._send(reply, context)
+ else:
+ self._send(reply, context)
+
+ def _extract_and_send_images(self, reply: Reply, context: Context):
+ """
+ Extract image/video URLs from a text reply and send them separately.
+ Supported formats: [图片: /path/to/image.png], [视频: /path/to/video.mp4],
+ markdown images ![alt](url), and HTML <img> tags.
+ Sends at most 5 media files.
+ """
+ content = reply.content
+ media_items = [] # [(url, type), ...]
+
+ # Extract media URLs in the supported formats with regexes
+ patterns = [
+ (r'\[图片:\s*([^\]]+)\]', 'image'), # [图片: /path/to/image.png]
+ (r'\[视频:\s*([^\]]+)\]', 'video'), # [视频: /path/to/video.mp4]
+ (r'!\[.*?\]\(([^\)]+)\)', 'image'), # ![alt](url) - markdown image, image by default
+ (r'<img[^>]+src=["\']([^"\']+)["\']', 'image'), # <img src="...">
+ (r'