Merge pull request #2648 from zhayujie/feat-cow-agent

feat: cow agent core
Saboteur7 authored on 2026-01-31 13:14:05 +08:00, committed by GitHub
85 changed files with 12520 additions and 372 deletions

.gitignore vendored

@@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__
!plugins/linkai
!plugins/agent
client_config.json
ref/
.cursor/
local/

agent/memory/__init__.py Normal file

@@ -0,0 +1,10 @@
"""
Memory module for AgentMesh
Provides long-term memory capabilities with hybrid search (vector + keyword)
"""
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
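
A minimal end-to-end sketch of these exports, assuming the repository is on the import path; the workspace path is a placeholder. Without an OPENAI_API_KEY the manager degrades gracefully to keyword-only search (see manager.py below).

```python
import asyncio

from agent.memory import MemoryConfig, MemoryManager, set_global_memory_config

# Point the workspace at a scratch directory before any manager is created.
set_global_memory_config(MemoryConfig(workspace_root="/tmp/cow-demo"))

async def main():
    manager = MemoryManager()  # picks up the global config set above
    await manager.add_memory("User prefers concise answers.", scope="shared")
    for result in await manager.search("concise"):
        print(result.path, round(result.score, 3), result.snippet[:60])
    manager.close()

asyncio.run(main())
```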

agent/memory/chunker.py Normal file

@@ -0,0 +1,139 @@
"""
Text chunking utilities for memory
Splits text into chunks with token limits and overlap
"""
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class TextChunk:
"""Represents a text chunk with line numbers"""
text: str
start_line: int
end_line: int
class TextChunker:
"""Chunks text by line count with token estimation"""
def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
"""
Initialize chunker
Args:
max_tokens: Maximum tokens per chunk
overlap_tokens: Overlap tokens between chunks
"""
self.max_tokens = max_tokens
self.overlap_tokens = overlap_tokens
# Rough estimation: ~4 chars per token for English/Chinese mixed
self.chars_per_token = 4
def chunk_text(self, text: str) -> List[TextChunk]:
"""
Chunk text into overlapping segments
Args:
text: Input text to chunk
Returns:
List of TextChunk objects
"""
if not text.strip():
return []
lines = text.split('\n')
chunks = []
max_chars = self.max_tokens * self.chars_per_token
overlap_chars = self.overlap_tokens * self.chars_per_token
current_chunk = []
current_chars = 0
start_line = 1
for i, line in enumerate(lines, start=1):
line_chars = len(line)
# If single line exceeds max, split it
if line_chars > max_chars:
# Save current chunk if exists
if current_chunk:
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=i - 1
))
current_chunk = []
current_chars = 0
# Split long line into multiple chunks
for sub_chunk in self._split_long_line(line, max_chars):
chunks.append(TextChunk(
text=sub_chunk,
start_line=i,
end_line=i
))
start_line = i + 1
continue
# Check if adding this line would exceed limit
if current_chars + line_chars > max_chars and current_chunk:
# Save current chunk
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=i - 1
))
# Start new chunk with overlap
overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
current_chunk = overlap_lines + [line]
current_chars = sum(len(l) for l in current_chunk)
start_line = i - len(overlap_lines)
else:
# Add line to current chunk
current_chunk.append(line)
current_chars += line_chars
# Save last chunk
if current_chunk:
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=len(lines)
))
return chunks
def _split_long_line(self, line: str, max_chars: int) -> List[str]:
"""Split a single long line into multiple chunks"""
chunks = []
for i in range(0, len(line), max_chars):
chunks.append(line[i:i + max_chars])
return chunks
def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
"""Get last few lines that fit within target_chars for overlap"""
overlap = []
chars = 0
for line in reversed(lines):
line_chars = len(line)
if chars + line_chars > target_chars:
break
overlap.insert(0, line)
chars += line_chars
return overlap
def chunk_markdown(self, text: str) -> List[TextChunk]:
"""
Chunk markdown text while respecting structure
(For future enhancement: respect markdown sections)
"""
return self.chunk_text(text)
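
A quick sketch of how the chunker is meant to be driven, on synthetic input; with the ~4 chars/token estimate above, max_tokens=500 caps each chunk at roughly 2000 characters.

```python
from agent.memory.chunker import TextChunker

chunker = TextChunker(max_tokens=500, overlap_tokens=50)
text = "\n".join(f"line {i}: some note text" for i in range(1, 401))

for chunk in chunker.chunk_text(text):
    # Each chunk stays under ~500 * 4 = 2000 chars and overlaps its
    # predecessor by up to ~200 chars (50 tokens * 4 chars).
    print(f"lines {chunk.start_line}-{chunk.end_line}, {len(chunk.text)} chars")
```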

agent/memory/config.py Normal file

@@ -0,0 +1,114 @@
"""
Memory configuration module
Provides global memory configuration with simplified workspace structure
"""
import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
@dataclass
class MemoryConfig:
"""Configuration for memory storage and search"""
# Storage paths (default: ~/cow)
workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))
# Embedding config
embedding_provider: str = "openai" # "openai" | "local"
embedding_model: str = "text-embedding-3-small"
embedding_dim: int = 1536
# Chunking config
chunk_max_tokens: int = 500
chunk_overlap_tokens: int = 50
# Search config
max_results: int = 10
min_score: float = 0.1
# Hybrid search weights
vector_weight: float = 0.7
keyword_weight: float = 0.3
# Memory sources
sources: List[str] = field(default_factory=lambda: ["memory", "session"])
# Sync config
enable_auto_sync: bool = True
sync_on_search: bool = True
def get_workspace(self) -> Path:
"""Get workspace root directory"""
return Path(self.workspace_root)
def get_memory_dir(self) -> Path:
"""Get memory files directory"""
return self.get_workspace() / "memory"
def get_db_path(self) -> Path:
"""Get SQLite database path for long-term memory index"""
index_dir = self.get_memory_dir() / "long-term"
index_dir.mkdir(parents=True, exist_ok=True)
return index_dir / "index.db"
def get_skills_dir(self) -> Path:
"""Get skills directory"""
return self.get_workspace() / "skills"
def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
"""
Get workspace directory for an agent
Args:
agent_name: Optional agent name (not used in current implementation)
Returns:
Path to workspace directory
"""
workspace = self.get_workspace()
# Ensure workspace directory exists
workspace.mkdir(parents=True, exist_ok=True)
return workspace
# Global memory configuration
_global_memory_config: Optional[MemoryConfig] = None
def get_default_memory_config() -> MemoryConfig:
"""
Get the global memory configuration.
If not set, returns a default configuration.
Returns:
MemoryConfig instance
"""
global _global_memory_config
if _global_memory_config is None:
_global_memory_config = MemoryConfig()
return _global_memory_config
def set_global_memory_config(config: MemoryConfig):
"""
Set the global memory configuration.
This should be called before creating any MemoryManager instances.
Args:
config: MemoryConfig instance to use globally
Example:
>>> from agent.memory import MemoryConfig, set_global_memory_config
>>> config = MemoryConfig(
... workspace_root="~/my_agents",
... embedding_provider="openai",
... vector_weight=0.8
... )
>>> set_global_memory_config(config)
"""
global _global_memory_config
_global_memory_config = config
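
A small sketch of the directory layout this config resolves to, using a scratch workspace; note that get_db_path() creates the long-term index directory as a side effect.

```python
from agent.memory.config import MemoryConfig

config = MemoryConfig(workspace_root="/tmp/cow-demo")
print(config.get_workspace())   # /tmp/cow-demo
print(config.get_memory_dir())  # /tmp/cow-demo/memory
print(config.get_skills_dir())  # /tmp/cow-demo/skills
print(config.get_db_path())     # /tmp/cow-demo/memory/long-term/index.db (creates the dir)
```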

agent/memory/embedding.py Normal file

@@ -0,0 +1,175 @@
"""
Embedding providers for memory
Supports OpenAI and local embedding models
"""
from typing import List, Optional
from abc import ABC, abstractmethod
import hashlib
import json
class EmbeddingProvider(ABC):
"""Base class for embedding providers"""
@abstractmethod
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
pass
@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
pass
@property
@abstractmethod
def dimensions(self) -> int:
"""Get embedding dimensions"""
pass
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
"""
Initialize OpenAI embedding provider
Args:
model: Model name (text-embedding-3-small or text-embedding-3-large)
api_key: OpenAI API key
api_base: Optional API base URL
"""
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
# Lazy import to avoid dependency issues
try:
from openai import OpenAI
self.client = OpenAI(api_key=api_key, base_url=self.api_base)
except ImportError:
raise ImportError("OpenAI package not installed. Install with: pip install openai")
# Set dimensions based on model
self._dimensions = 1536 if "small" in model else 3072
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
response = self.client.embeddings.create(
input=text,
model=self.model
)
return response.data[0].embedding
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
if not texts:
return []
response = self.client.embeddings.create(
input=texts,
model=self.model
)
return [item.embedding for item in response.data]
@property
def dimensions(self) -> int:
return self._dimensions
class LocalEmbeddingProvider(EmbeddingProvider):
"""Local embedding provider using sentence-transformers"""
def __init__(self, model: str = "all-MiniLM-L6-v2"):
"""
Initialize local embedding provider
Args:
model: Model name from sentence-transformers
"""
self.model_name = model
try:
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model)
self._dimensions = self.model.get_sentence_embedding_dimension()
except ImportError:
raise ImportError(
"sentence-transformers not installed. "
"Install with: pip install sentence-transformers"
)
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
embedding = self.model.encode(text, convert_to_numpy=True)
return embedding.tolist()
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
if not texts:
return []
embeddings = self.model.encode(texts, convert_to_numpy=True)
return embeddings.tolist()
@property
def dimensions(self) -> int:
return self._dimensions
class EmbeddingCache:
"""Cache for embeddings to avoid recomputation"""
def __init__(self):
self.cache = {}
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
"""Get cached embedding"""
key = self._compute_key(text, provider, model)
return self.cache.get(key)
def put(self, text: str, provider: str, model: str, embedding: List[float]):
"""Cache embedding"""
key = self._compute_key(text, provider, model)
self.cache[key] = embedding
@staticmethod
def _compute_key(text: str, provider: str, model: str) -> str:
"""Compute cache key"""
content = f"{provider}:{model}:{text}"
return hashlib.md5(content.encode('utf-8')).hexdigest()
def clear(self):
"""Clear cache"""
self.cache.clear()
def create_embedding_provider(
provider: str = "openai",
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None
) -> EmbeddingProvider:
"""
Factory function to create embedding provider
Args:
provider: Provider name ("openai" or "local")
model: Model name (provider-specific)
api_key: API key for remote providers
api_base: API base URL for remote providers
Returns:
EmbeddingProvider instance
"""
if provider == "openai":
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
elif provider == "local":
model = model or "all-MiniLM-L6-v2"
return LocalEmbeddingProvider(model=model)
else:
raise ValueError(f"Unknown embedding provider: {provider}")

agent/memory/manager.py Normal file

@@ -0,0 +1,613 @@
"""
Memory manager for AgentMesh
Provides high-level interface for memory operations
"""
import os
import time
from typing import List, Optional, Dict, Any
from pathlib import Path
import hashlib
from datetime import datetime, timedelta
from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
class MemoryManager:
"""
Memory manager with hybrid search capabilities
Provides long-term memory for agents with vector and keyword search
"""
def __init__(
self,
config: Optional[MemoryConfig] = None,
embedding_provider: Optional[EmbeddingProvider] = None,
llm_model: Optional[Any] = None
):
"""
Initialize memory manager
Args:
config: Memory configuration (uses global config if not provided)
embedding_provider: Custom embedding provider (optional)
llm_model: LLM model for summarization (optional)
"""
self.config = config or get_default_memory_config()
# Initialize storage
db_path = self.config.get_db_path()
self.storage = MemoryStorage(db_path)
# Initialize chunker
self.chunker = TextChunker(
max_tokens=self.config.chunk_max_tokens,
overlap_tokens=self.config.chunk_overlap_tokens
)
# Initialize embedding provider (optional)
self.embedding_provider = None
if embedding_provider:
self.embedding_provider = embedding_provider
else:
# Try to create embedding provider, but allow failure
try:
# Get API key from environment or config
api_key = os.environ.get('OPENAI_API_KEY')
api_base = os.environ.get('OPENAI_API_BASE')
self.embedding_provider = create_embedding_provider(
provider=self.config.embedding_provider,
model=self.config.embedding_model,
api_key=api_key,
api_base=api_base
)
except Exception as e:
# Embedding provider failed, but that's OK
# We can still use keyword search and file operations
print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
print(f" Memory will work with keyword search only (no semantic search)")
# Initialize memory flush manager
workspace_dir = self.config.get_workspace()
self.flush_manager = MemoryFlushManager(
workspace_dir=workspace_dir,
llm_model=llm_model
)
# Ensure workspace directories exist
self._init_workspace()
self._dirty = False
def _init_workspace(self):
"""Initialize workspace directories"""
memory_dir = self.config.get_memory_dir()
memory_dir.mkdir(parents=True, exist_ok=True)
# Create default memory files
workspace_dir = self.config.get_workspace()
create_memory_files_if_needed(workspace_dir)
async def search(
self,
query: str,
user_id: Optional[str] = None,
max_results: Optional[int] = None,
min_score: Optional[float] = None,
include_shared: bool = True
) -> List[SearchResult]:
"""
Search memory with hybrid search (vector + keyword)
Args:
query: Search query
user_id: User ID for scoped search
max_results: Maximum results to return
min_score: Minimum score threshold
include_shared: Include shared memories
Returns:
List of search results sorted by relevance
"""
max_results = max_results or self.config.max_results
min_score = min_score or self.config.min_score
# Determine scopes
scopes = []
if include_shared:
scopes.append("shared")
if user_id:
scopes.append("user")
if not scopes:
return []
# Sync if needed
if self.config.sync_on_search and self._dirty:
await self.sync()
# Perform vector search (if embedding provider available)
vector_results = []
if self.embedding_provider:
query_embedding = self.embedding_provider.embed(query)
vector_results = self.storage.search_vector(
query_embedding=query_embedding,
user_id=user_id,
scopes=scopes,
limit=max_results * 2 # Get more candidates for merging
)
# Perform keyword search
keyword_results = self.storage.search_keyword(
query=query,
user_id=user_id,
scopes=scopes,
limit=max_results * 2
)
# Merge results
merged = self._merge_results(
vector_results,
keyword_results,
self.config.vector_weight,
self.config.keyword_weight
)
# Filter by min score and limit
filtered = [r for r in merged if r.score >= min_score]
return filtered[:max_results]
async def add_memory(
self,
content: str,
user_id: Optional[str] = None,
scope: str = "shared",
source: str = "memory",
path: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
):
"""
Add new memory content
Args:
content: Memory content
user_id: User ID for user-scoped memory
scope: Memory scope ("shared", "user", "session")
source: Memory source ("memory" or "session")
path: File path (auto-generated if not provided)
metadata: Additional metadata
"""
if not content.strip():
return
# Generate path if not provided
if not path:
content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
if user_id and scope == "user":
path = f"memory/users/{user_id}/memory_{content_hash}.md"
else:
path = f"memory/shared/memory_{content_hash}.md"
# Chunk content
chunks = self.chunker.chunk_text(content)
# Generate embeddings (if provider available)
texts = [chunk.text for chunk in chunks]
if self.embedding_provider:
embeddings = self.embedding_provider.embed_batch(texts)
else:
# No embeddings, just use None
embeddings = [None] * len(texts)
# Create memory chunks
memory_chunks = []
for chunk, embedding in zip(chunks, embeddings):
chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
chunk_hash = MemoryStorage.compute_hash(chunk.text)
memory_chunks.append(MemoryChunk(
id=chunk_id,
user_id=user_id,
scope=scope,
source=source,
path=path,
start_line=chunk.start_line,
end_line=chunk.end_line,
text=chunk.text,
embedding=embedding,
hash=chunk_hash,
metadata=metadata
))
# Save to storage
self.storage.save_chunks_batch(memory_chunks)
# Update file metadata
file_hash = MemoryStorage.compute_hash(content)
self.storage.update_file_metadata(
path=path,
source=source,
file_hash=file_hash,
mtime=int(time.time()), # current time
size=len(content)
)
async def sync(self, force: bool = False):
"""
Synchronize memory from files
Args:
force: Force full reindex
"""
memory_dir = self.config.get_memory_dir()
workspace_dir = self.config.get_workspace()
# Scan MEMORY.md (workspace root)
memory_file = Path(workspace_dir) / "MEMORY.md"
if memory_file.exists():
await self._sync_file(memory_file, "memory", "shared", None)
# Scan memory directory (including daily summaries)
if memory_dir.exists():
for file_path in memory_dir.rglob("*.md"):
# Determine scope and user_id from path
rel_path = file_path.relative_to(workspace_dir)
parts = rel_path.parts
# Check if it's in daily summary directory
if "daily" in parts:
# Daily summary files
if "users" in parts or len(parts) > 3:
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
user_idx = parts.index("daily") + 1
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user"
else:
# Shared daily summary: memory/daily/2024-01-29.md
user_id = None
scope = "shared"
elif "users" in parts:
# User-scoped memory
user_idx = parts.index("users") + 1
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user"
else:
# Shared memory
user_id = None
scope = "shared"
await self._sync_file(file_path, "memory", scope, user_id)
self._dirty = False
async def _sync_file(
self,
file_path: Path,
source: str,
scope: str,
user_id: Optional[str]
):
"""Sync a single file"""
# Compute file hash
content = file_path.read_text(encoding='utf-8')
file_hash = MemoryStorage.compute_hash(content)
# Get relative path
workspace_dir = self.config.get_workspace()
rel_path = str(file_path.relative_to(workspace_dir))
# Check if file changed
stored_hash = self.storage.get_file_hash(rel_path)
if stored_hash == file_hash:
return # No changes
# Delete old chunks
self.storage.delete_by_path(rel_path)
# Chunk and embed
chunks = self.chunker.chunk_text(content)
if not chunks:
return
texts = [chunk.text for chunk in chunks]
if self.embedding_provider:
embeddings = self.embedding_provider.embed_batch(texts)
else:
embeddings = [None] * len(texts)
# Create memory chunks
memory_chunks = []
for chunk, embedding in zip(chunks, embeddings):
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
chunk_hash = MemoryStorage.compute_hash(chunk.text)
memory_chunks.append(MemoryChunk(
id=chunk_id,
user_id=user_id,
scope=scope,
source=source,
path=rel_path,
start_line=chunk.start_line,
end_line=chunk.end_line,
text=chunk.text,
embedding=embedding,
hash=chunk_hash,
metadata=None
))
# Save
self.storage.save_chunks_batch(memory_chunks)
# Update file metadata
stat = file_path.stat()
self.storage.update_file_metadata(
path=rel_path,
source=source,
file_hash=file_hash,
mtime=int(stat.st_mtime),
size=stat.st_size
)
def should_flush_memory(
self,
current_tokens: int,
context_window: int = 128000,
reserve_tokens: int = 20000,
soft_threshold: int = 4000
) -> bool:
"""
Check if memory flush should be triggered
Args:
current_tokens: Current session token count
context_window: Model's context window size (default: 128K)
reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
soft_threshold: Trigger N tokens before threshold (default: 4K)
Returns:
True if memory flush should run
"""
return self.flush_manager.should_flush(
current_tokens=current_tokens,
context_window=context_window,
reserve_tokens=reserve_tokens,
soft_threshold=soft_threshold
)
async def execute_memory_flush(
self,
agent_executor,
current_tokens: int,
user_id: Optional[str] = None,
**executor_kwargs
) -> bool:
"""
Execute memory flush before compaction
This runs a silent agent turn to write durable memories to disk.
Similar to clawdbot's pre-compaction memory flush.
Args:
agent_executor: Async function to execute agent with prompt
current_tokens: Current session token count
user_id: Optional user ID
**executor_kwargs: Additional kwargs for agent executor
Returns:
True if flush completed successfully
Example:
>>> async def run_agent(prompt, system_prompt, silent=False):
... # Your agent execution logic
... pass
>>>
>>> if manager.should_flush_memory(current_tokens=100000):
... await manager.execute_memory_flush(
... agent_executor=run_agent,
... current_tokens=100000
... )
"""
success = await self.flush_manager.execute_flush(
agent_executor=agent_executor,
current_tokens=current_tokens,
user_id=user_id,
**executor_kwargs
)
if success:
# Mark dirty so next search will sync the new memories
self._dirty = True
return success
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
"""
Build natural memory guidance for agent system prompt
Following clawdbot's approach:
1. Load MEMORY.md as bootstrap context (blends into background)
2. Load daily files on-demand via memory_search tool
3. Agent should NOT proactively mention memories unless user asks
Args:
lang: Language for guidance ("en" or "zh")
include_context: Whether to include bootstrap memory context (default: True)
MEMORY.md is loaded as background context (like clawdbot)
Daily files are accessed via memory_search tool
Returns:
Memory guidance text (and optionally context) for system prompt
"""
today_file = self.flush_manager.get_today_memory_file().name
if lang == "zh":
guidance = f"""## 记忆系统
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
- 长期信息 → MEMORY.md
- 当天笔记 → memory/{today_file}
- 静默存储,仅在明确要求时确认
**使用原则**: 自然使用记忆,就像你本来就知道。不要主动提起或列举记忆,除非用户明确询问。"""
else:
guidance = f"""## Memory System
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
- Durable info → MEMORY.md
- Daily notes → memory/{today_file}
- Store silently; confirm only when explicitly requested
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
if include_context:
# Load bootstrap context (MEMORY.md only, like clawdbot)
bootstrap_context = self.load_bootstrap_memories()
if bootstrap_context:
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
return guidance
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
"""
Load bootstrap memory files for session start
Following clawdbot's design:
- Only loads MEMORY.md from workspace root (long-term curated memory)
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
- User-specific MEMORY.md is also loaded if user_id provided
Returns memory content WITHOUT obvious headers so it blends naturally
into the context as background knowledge.
Args:
user_id: Optional user ID for user-specific memories
Returns:
Memory content to inject into system prompt (blends naturally as background context)
"""
workspace_dir = self.config.get_workspace()
memory_dir = self.config.get_memory_dir()
sections = []
# 1. Load MEMORY.md from workspace root (long-term curated memory)
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
memory_file = Path(workspace_dir) / "MEMORY.md"
if memory_file.exists():
try:
content = memory_file.read_text(encoding='utf-8').strip()
if content:
sections.append(content)
except Exception as e:
print(f"Warning: Failed to read MEMORY.md: {e}")
# 2. Load user-specific MEMORY.md if user_id provided
if user_id:
user_memory_dir = memory_dir / "users" / user_id
user_memory_file = user_memory_dir / "MEMORY.md"
if user_memory_file.exists():
try:
content = user_memory_file.read_text(encoding='utf-8').strip()
if content:
sections.append(content)
except Exception as e:
print(f"Warning: Failed to read user memory: {e}")
if not sections:
return ""
# Join sections without obvious headers - let memories blend naturally
# This makes the agent feel like it "just knows" rather than "checking memory files"
return "\n\n".join(sections)
def get_status(self) -> Dict[str, Any]:
"""Get memory status"""
stats = self.storage.get_stats()
return {
'chunks': stats['chunks'],
'files': stats['files'],
'workspace': str(self.config.get_workspace()),
'dirty': self._dirty,
'embedding_enabled': self.embedding_provider is not None,
'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
}
def mark_dirty(self):
"""Mark memory as dirty (needs sync)"""
self._dirty = True
def close(self):
"""Close memory manager and release resources"""
self.storage.close()
# Helper methods
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
"""Generate unique chunk ID"""
content = f"{path}:{start_line}:{end_line}"
return hashlib.md5(content.encode('utf-8')).hexdigest()
def _merge_results(
self,
vector_results: List[SearchResult],
keyword_results: List[SearchResult],
vector_weight: float,
keyword_weight: float
) -> List[SearchResult]:
"""Merge vector and keyword search results"""
# Create a map by (path, start_line, end_line)
merged_map = {}
for result in vector_results:
key = (result.path, result.start_line, result.end_line)
merged_map[key] = {
'result': result,
'vector_score': result.score,
'keyword_score': 0.0
}
for result in keyword_results:
key = (result.path, result.start_line, result.end_line)
if key in merged_map:
merged_map[key]['keyword_score'] = result.score
else:
merged_map[key] = {
'result': result,
'vector_score': 0.0,
'keyword_score': result.score
}
# Calculate combined scores
merged_results = []
for entry in merged_map.values():
combined_score = (
vector_weight * entry['vector_score'] +
keyword_weight * entry['keyword_score']
)
result = entry['result']
merged_results.append(SearchResult(
path=result.path,
start_line=result.start_line,
end_line=result.end_line,
score=combined_score,
snippet=result.snippet,
source=result.source,
user_id=result.user_id
))
# Sort by score
merged_results.sort(key=lambda r: r.score, reverse=True)
return merged_results
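
The merge step above is a plain weighted sum over the two score channels; a worked example with the MemoryConfig defaults (0.7 vector, 0.3 keyword):

```python
vector_weight, keyword_weight = 0.7, 0.3  # MemoryConfig defaults

# Chunk found by both searches:
print(round(vector_weight * 0.82 + keyword_weight * 0.55, 3))  # 0.739

# Chunk found only by keyword search (vector_score stays 0.0):
print(round(vector_weight * 0.0 + keyword_weight * 0.9, 3))    # 0.27
```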

agent/memory/storage.py Normal file

@@ -0,0 +1,568 @@
"""
Storage layer for memory using SQLite + FTS5
Provides vector and keyword search capabilities
"""
import sqlite3
import json
import hashlib
from typing import List, Dict, Optional, Any
from pathlib import Path
from dataclasses import dataclass
@dataclass
class MemoryChunk:
"""Represents a memory chunk with text and embedding"""
id: str
user_id: Optional[str]
scope: str # "shared" | "user" | "session"
source: str # "memory" | "session"
path: str
start_line: int
end_line: int
text: str
embedding: Optional[List[float]]
hash: str
metadata: Optional[Dict[str, Any]] = None
@dataclass
class SearchResult:
"""Search result with score and snippet"""
path: str
start_line: int
end_line: int
score: float
snippet: str
source: str
user_id: Optional[str] = None
class MemoryStorage:
"""SQLite-based storage with FTS5 for keyword search"""
def __init__(self, db_path: Path):
self.db_path = db_path
self.conn: Optional[sqlite3.Connection] = None
self._init_db()
def _init_db(self):
"""Initialize database with schema"""
try:
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Check database integrity
try:
result = self.conn.execute("PRAGMA integrity_check").fetchone()
if result[0] != 'ok':
print(f"⚠️ Database integrity check failed: {result[0]}")
print(f" Recreating database...")
self.conn.close()
self.conn = None
# Remove corrupted database
self.db_path.unlink(missing_ok=True)
# Remove WAL files
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
# Reconnect to create new database
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
except sqlite3.DatabaseError:
# Database is corrupted, recreate it
print(f"⚠️ Database is corrupted, recreating...")
if self.conn:
self.conn.close()
self.conn = None
self.db_path.unlink(missing_ok=True)
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Enable WAL mode for better concurrency
self.conn.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to avoid "database is locked" errors
self.conn.execute("PRAGMA busy_timeout=5000")
except Exception as e:
print(f"⚠️ Unexpected error during database initialization: {e}")
raise
# Create chunks table with embeddings
self.conn.execute("""
CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
user_id TEXT,
scope TEXT NOT NULL DEFAULT 'shared',
source TEXT NOT NULL DEFAULT 'memory',
path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
text TEXT NOT NULL,
embedding TEXT,
hash TEXT NOT NULL,
metadata TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
# Create indexes
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user
ON chunks(user_id)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_scope
ON chunks(scope)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_hash
ON chunks(path, hash)
""")
# Create FTS5 virtual table for keyword search
# Use default unicode61 tokenizer (stable and compatible)
# For CJK support, we'll use LIKE queries as fallback
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid'
)
""")
# Create triggers to keep FTS in sync
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
DELETE FROM chunks_fts WHERE rowid = old.rowid;
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts SET text = new.text, id = new.id,
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
WHERE rowid = new.rowid;
END
""")
# Create files metadata table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS files (
path TEXT PRIMARY KEY,
source TEXT NOT NULL DEFAULT 'memory',
hash TEXT NOT NULL,
mtime INTEGER NOT NULL,
size INTEGER NOT NULL,
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
self.conn.commit()
def save_chunk(self, chunk: MemoryChunk):
"""Save a memory chunk"""
self.conn.execute("""
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
""", (
chunk.id,
chunk.user_id,
chunk.scope,
chunk.source,
chunk.path,
chunk.start_line,
chunk.end_line,
chunk.text,
json.dumps(chunk.embedding) if chunk.embedding else None,
chunk.hash,
json.dumps(chunk.metadata) if chunk.metadata else None
))
self.conn.commit()
def save_chunks_batch(self, chunks: List[MemoryChunk]):
"""Save multiple chunks in a batch"""
self.conn.executemany("""
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
""", [
(
c.id, c.user_id, c.scope, c.source, c.path,
c.start_line, c.end_line, c.text,
json.dumps(c.embedding) if c.embedding else None,
c.hash,
json.dumps(c.metadata) if c.metadata else None
)
for c in chunks
])
self.conn.commit()
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
"""Get a chunk by ID"""
row = self.conn.execute("""
SELECT * FROM chunks WHERE id = ?
""", (chunk_id,)).fetchone()
if not row:
return None
return self._row_to_chunk(row)
def search_vector(
self,
query_embedding: List[float],
user_id: Optional[str] = None,
scopes: List[str] = None,
limit: int = 10
) -> List[SearchResult]:
"""
Vector similarity search using in-memory cosine similarity
(sqlite-vec can be added later for better performance)
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
# Build query
scope_placeholders = ','.join('?' * len(scopes))
params = list(scopes) # copy so we don't mutate the caller's list
if user_id:
query = f"""
SELECT * FROM chunks
WHERE scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
AND embedding IS NOT NULL
"""
params.append(user_id)
else:
query = f"""
SELECT * FROM chunks
WHERE scope IN ({scope_placeholders})
AND embedding IS NOT NULL
"""
rows = self.conn.execute(query, params).fetchall()
# Calculate cosine similarity
results = []
for row in rows:
embedding = json.loads(row['embedding'])
similarity = self._cosine_similarity(query_embedding, embedding)
if similarity > 0:
results.append((similarity, row))
# Sort by similarity and limit
results.sort(key=lambda x: x[0], reverse=True)
results = results[:limit]
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=score,
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for score, row in results
]
def search_keyword(
self,
query: str,
user_id: Optional[str] = None,
scopes: List[str] = None,
limit: int = 10
) -> List[SearchResult]:
"""
Keyword search using FTS5 + LIKE fallback
Strategy:
1. Try FTS5 search first (good for English and word-based languages)
2. If no results and query contains CJK characters, use LIKE search
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
# Try FTS5 search first
fts_results = self._search_fts5(query, user_id, scopes, limit)
if fts_results:
return fts_results
# Fallback to LIKE search for CJK characters
if MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit)
return []
def _search_fts5(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""FTS5 full-text search"""
fts_query = self._build_fts_query(query)
if not fts_query:
return []
scope_placeholders = ','.join('?' * len(scopes))
params = [fts_query] + scopes
if user_id:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
ORDER BY rank
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
ORDER BY rank
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=self._bm25_rank_to_score(row['rank']),
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def _search_like(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""LIKE-based search for CJK characters"""
import re
# Extract CJK words (2+ characters)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
if not cjk_words:
return []
scope_placeholders = ','.join('?' * len(scopes))
# Build LIKE conditions for each word
like_conditions = []
params = []
for word in cjk_words:
like_conditions.append("text LIKE ?")
params.append(f'%{word}%')
where_clause = ' OR '.join(like_conditions)
params.extend(scopes)
if user_id:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=0.5, # Fixed score for LIKE search
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def delete_by_path(self, path: str):
"""Delete all chunks from a file"""
self.conn.execute("""
DELETE FROM chunks WHERE path = ?
""", (path,))
self.conn.commit()
def get_file_hash(self, path: str) -> Optional[str]:
"""Get stored file hash"""
row = self.conn.execute("""
SELECT hash FROM files WHERE path = ?
""", (path,)).fetchone()
return row['hash'] if row else None
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
"""Update file metadata"""
self.conn.execute("""
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
""", (path, source, file_hash, mtime, size))
self.conn.commit()
def get_stats(self) -> Dict[str, int]:
"""Get storage statistics"""
chunks_count = self.conn.execute("""
SELECT COUNT(*) as cnt FROM chunks
""").fetchone()['cnt']
files_count = self.conn.execute("""
SELECT COUNT(*) as cnt FROM files
""").fetchone()['cnt']
return {
'chunks': chunks_count,
'files': files_count
}
def close(self):
"""Close database connection"""
if self.conn:
try:
self.conn.commit() # Ensure all changes are committed
self.conn.close()
self.conn = None # Mark as closed
except Exception as e:
print(f"⚠️ Error closing database connection: {e}")
def __del__(self):
"""Destructor to ensure connection is closed"""
try:
self.close()
except:
pass # Ignore errors during cleanup
# Helper methods
def _row_to_chunk(self, row) -> MemoryChunk:
"""Convert database row to MemoryChunk"""
return MemoryChunk(
id=row['id'],
user_id=row['user_id'],
scope=row['scope'],
source=row['source'],
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
text=row['text'],
embedding=json.loads(row['embedding']) if row['embedding'] else None,
hash=row['hash'],
metadata=json.loads(row['metadata']) if row['metadata'] else None
)
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors"""
if len(vec1) != len(vec2):
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
@staticmethod
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
import re
return bool(re.search(r'[\u4e00-\u9fff]', text))
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""
Build FTS5 query from raw text
Works best for English and word-based languages.
For CJK characters, LIKE search will be used as fallback.
"""
import re
# Extract words (primarily English words and numbers)
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
if not tokens:
return None
# Quote tokens for exact matching
quoted = [f'"{t}"' for t in tokens]
# Use OR for more flexible matching
return ' OR '.join(quoted)
@staticmethod
def _bm25_rank_to_score(rank: float) -> float:
"""Convert BM25 rank to 0-1 score"""
normalized = max(0, rank) if rank is not None else 999
return 1 / (1 + normalized)
@staticmethod
def _truncate_text(text: str, max_chars: int) -> str:
"""Truncate text to max characters"""
if len(text) <= max_chars:
return text
return text[:max_chars] + "..."
@staticmethod
def compute_hash(content: str) -> str:
"""Compute SHA256 hash of content"""
return hashlib.sha256(content.encode('utf-8')).hexdigest()
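
A minimal sketch exercising the storage layer on its own, assuming a writable scratch path; with no embedding stored, only the FTS5 keyword path is hit.

```python
from pathlib import Path

from agent.memory.storage import MemoryChunk, MemoryStorage

storage = MemoryStorage(Path("/tmp/cow-demo-index.db"))
text = "User prefers dark mode"
storage.save_chunk(MemoryChunk(
    id="demo-1", user_id=None, scope="shared", source="memory",
    path="memory/demo.md", start_line=1, end_line=1,
    text=text, embedding=None, hash=MemoryStorage.compute_hash(text),
))
for r in storage.search_keyword("dark mode"):
    print(r.path, round(r.score, 3), r.snippet)
storage.close()
```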

agent/memory/summarizer.py Normal file

@@ -0,0 +1,236 @@
"""
Memory flush manager
Triggers memory flush before context compaction (similar to clawdbot)
"""
from typing import Optional, Callable, Any
from pathlib import Path
from datetime import datetime
class MemoryFlushManager:
"""
Manages memory flush operations before context compaction
Similar to clawdbot's memory flush mechanism:
- Triggers when context approaches token limit
- Runs a silent agent turn to write memories to disk
- Uses memory/YYYY-MM-DD.md for daily notes
- Uses MEMORY.md (workspace root) for long-term curated memories
"""
def __init__(
self,
workspace_dir: Path,
llm_model: Optional[Any] = None
):
"""
Initialize memory flush manager
Args:
workspace_dir: Workspace directory
llm_model: LLM model for agent execution (optional)
"""
self.workspace_dir = workspace_dir
self.llm_model = llm_model
self.memory_dir = workspace_dir / "memory"
self.memory_dir.mkdir(parents=True, exist_ok=True)
# Tracking
self.last_flush_token_count: Optional[int] = None
self.last_flush_timestamp: Optional[datetime] = None
def should_flush(
self,
current_tokens: int,
context_window: int,
reserve_tokens: int = 20000,
soft_threshold: int = 4000
) -> bool:
"""
Determine if memory flush should be triggered
Similar to clawdbot's shouldRunMemoryFlush logic:
threshold = contextWindow - reserveTokens - softThreshold
Args:
current_tokens: Current session token count
context_window: Model's context window size
reserve_tokens: Reserve tokens for compaction overhead
soft_threshold: Trigger flush N tokens before threshold
Returns:
True if flush should run
"""
if current_tokens <= 0:
return False
threshold = max(0, context_window - reserve_tokens - soft_threshold)
if threshold <= 0:
return False
# Check if we've crossed the threshold
if current_tokens < threshold:
return False
# Avoid duplicate flush in same compaction cycle
if self.last_flush_token_count is not None:
if current_tokens <= self.last_flush_token_count + soft_threshold:
return False
return True
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
Get today's memory file path: memory/YYYY-MM-DD.md
Args:
user_id: Optional user ID for user-specific memory
Returns:
Path to today's memory file
"""
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = self.memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir / f"{today}.md"
else:
return self.memory_dir / f"{today}.md"
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
Get main memory file path: MEMORY.md (workspace root)
Args:
user_id: Optional user ID for user-specific memory
Returns:
Path to main memory file
"""
if user_id:
user_dir = self.memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir / "MEMORY.md"
else:
# Return workspace root MEMORY.md
return self.workspace_dir / "MEMORY.md"
def create_flush_prompt(self) -> str:
"""
Create prompt for memory flush turn
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
"""
today = datetime.now().strftime("%Y-%m-%d")
return (
f"Pre-compaction memory flush. "
f"Store durable memories now (use memory/{today}.md for daily notes; "
f"create memory/ if needed). "
f"If nothing to store, reply with NO_REPLY."
)
def create_flush_system_prompt(self) -> str:
"""
Create system prompt for memory flush turn
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
"""
return (
"Pre-compaction memory flush turn. "
"The session is near auto-compaction; capture durable memories to disk. "
"You may reply, but usually NO_REPLY is correct."
)
async def execute_flush(
self,
agent_executor: Callable,
current_tokens: int,
user_id: Optional[str] = None,
**executor_kwargs
) -> bool:
"""
Execute memory flush by running a silent agent turn
Args:
agent_executor: Function to execute agent with prompt
current_tokens: Current token count
user_id: Optional user ID
**executor_kwargs: Additional kwargs for agent executor
Returns:
True if flush completed successfully
"""
try:
# Create flush prompts
prompt = self.create_flush_prompt()
system_prompt = self.create_flush_system_prompt()
# Execute agent turn (silent, no user-visible reply expected)
await agent_executor(
prompt=prompt,
system_prompt=system_prompt,
silent=True, # NO_REPLY expected
**executor_kwargs
)
# Track flush
self.last_flush_token_count = current_tokens
self.last_flush_timestamp = datetime.now()
return True
except Exception as e:
print(f"Memory flush failed: {e}")
return False
def get_status(self) -> dict:
"""Get memory flush status"""
return {
'last_flush_tokens': self.last_flush_token_count,
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
'today_file': str(self.get_today_memory_file()),
'main_file': str(self.get_main_memory_file())
}
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
"""
Create default memory files if they don't exist
Args:
workspace_dir: Workspace directory
user_id: Optional user ID for user-specific files
"""
memory_dir = workspace_dir / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
# Create main MEMORY.md in workspace root
if user_id:
user_dir = memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
main_memory = user_dir / "MEMORY.md"
else:
main_memory = workspace_dir / "MEMORY.md"
if not main_memory.exists():
# Create empty file or with minimal structure (no obvious "Memory" header)
# Following clawdbot's approach: memories should blend naturally into context
main_memory.write_text("")
# Create today's memory file
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = memory_dir / "users" / user_id
today_memory = user_dir / f"{today}.md"
else:
today_memory = memory_dir / f"{today}.md"
if not today_memory.exists():
today_memory.write_text(
f"# Daily Memory: {today}\n\n"
f"Day-to-day notes and running context.\n\n"
)
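
The flush trigger in should_flush() is simple arithmetic; a worked example with the defaults used throughout this module:

```python
context_window, reserve_tokens, soft_threshold = 128_000, 20_000, 4_000
threshold = max(0, context_window - reserve_tokens - soft_threshold)

print(threshold)             # 104000
print(100_000 >= threshold)  # False -> no flush yet
print(105_000 >= threshold)  # True  -> flush before compaction
```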

agent/prompt/__init__.py Normal file

@@ -0,0 +1,13 @@
"""
Agent Prompt Module - system prompt construction module
"""
from .builder import PromptBuilder, build_agent_system_prompt
from .workspace import ensure_workspace, load_context_files
__all__ = [
'PromptBuilder',
'build_agent_system_prompt',
'ensure_workspace',
'load_context_files',
]

agent/prompt/builder.py Normal file

@@ -0,0 +1,445 @@
"""
System Prompt Builder
Modular, Chinese-language prompt construction modeled on clawdbot's system-prompt.ts
"""
import os
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from common.log import logger
@dataclass
class ContextFile:
"""上下文文件"""
path: str
content: str
class PromptBuilder:
"""提示词构建器"""
def __init__(self, workspace_dir: str, language: str = "zh"):
"""
Initialize the prompt builder
Args:
workspace_dir: workspace directory
language: language ("zh" or "en")
"""
self.workspace_dir = workspace_dir
self.language = language
def build(
self,
base_persona: Optional[str] = None,
user_identity: Optional[Dict[str, str]] = None,
tools: Optional[List[Any]] = None,
context_files: Optional[List[ContextFile]] = None,
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
**kwargs
) -> str:
"""
Build the complete system prompt
Args:
base_persona: base persona description (overridden by SOUL.md in context_files)
user_identity: user identity info
tools: list of tools
context_files: list of context files (SOUL.md, USER.md, README.md, etc.)
skill_manager: skill manager
memory_manager: memory manager
runtime_info: runtime info
**kwargs: extra arguments
Returns:
The complete system prompt
"""
return build_agent_system_prompt(
workspace_dir=self.workspace_dir,
language=self.language,
base_persona=base_persona,
user_identity=user_identity,
tools=tools,
context_files=context_files,
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
**kwargs
)
def build_agent_system_prompt(
workspace_dir: str,
language: str = "zh",
base_persona: Optional[str] = None,
user_identity: Optional[Dict[str, str]] = None,
tools: Optional[List[Any]] = None,
context_files: Optional[List[ContextFile]] = None,
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
**kwargs
) -> str:
"""
Build the agent system prompt (condensed, Chinese-language)
Included sections:
1. Base identity
2. Tooling
3. Skills system
4. Memory system
5. User identity
6. Workspace
7. Project context files
8. Runtime info
Args:
workspace_dir: workspace directory
language: language ("zh" or "en")
base_persona: base persona description
user_identity: user identity info
tools: list of tools
context_files: list of context files
skill_manager: skill manager
memory_manager: memory manager
runtime_info: runtime info
**kwargs: extra arguments
Returns:
The complete system prompt
"""
sections = []
# 1. Base identity
sections.extend(_build_identity_section(base_persona, language))
# 2. Tooling
if tools:
sections.extend(_build_tooling_section(tools, language))
# 3. Skills system
if skill_manager:
sections.extend(_build_skills_section(skill_manager, tools, language))
# 4. Memory system
if memory_manager:
sections.extend(_build_memory_section(memory_manager, tools, language))
# 5. User identity
if user_identity:
sections.extend(_build_user_identity_section(user_identity, language))
# 6. Workspace
sections.extend(_build_workspace_section(workspace_dir, language))
# 7. Project context files (SOUL.md, USER.md, etc.)
if context_files:
sections.extend(_build_context_files_section(context_files, language))
# 8. Runtime info (if any)
if runtime_info:
sections.extend(_build_runtime_section(runtime_info, language))
return "\n".join(sections)
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
"""构建基础身份section - 不再需要身份由SOUL.md定义"""
# 不再生成基础身份section完全由SOUL.md定义
return []
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
"""构建工具说明section"""
lines = [
"## 工具系统",
"",
"你可以使用以下工具来完成任务。工具名称是大小写敏感的,请严格按照列表中的名称调用。",
"",
"### 可用工具",
"",
]
# Tool categories and ordering
tool_categories = {
"文件操作": ["read", "write", "edit", "ls", "grep", "find"],
"命令执行": ["bash", "terminal"],
"网络搜索": ["web_search", "web_fetch", "browser"],
"记忆系统": ["memory_search", "memory_get"],
"其他": []
}
# Build the tool name -> description map
tool_map = {}
tool_descriptions = {
"read": "读取文件内容",
"write": "创建或覆盖文件",
"edit": "精确编辑文件内容",
"ls": "列出目录内容",
"grep": "在文件中搜索内容",
"find": "按照模式查找文件",
"bash": "执行shell命令",
"terminal": "管理后台进程",
"web_search": "网络搜索(使用搜索引擎)",
"web_fetch": "获取URL内容",
"browser": "控制浏览器",
"memory_search": "搜索记忆文件",
"memory_get": "获取记忆文件内容",
"calculator": "计算器",
"current_time": "获取当前时间",
}
for tool in tools:
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
tool_desc = tool.description if hasattr(tool, 'description') else tool_descriptions.get(tool_name, "")
tool_map[tool_name] = tool_desc
# Add tools grouped by category
for category, tool_names in tool_categories.items():
category_tools = [(name, tool_map.get(name, "")) for name in tool_names if name in tool_map]
if category_tools:
lines.append(f"**{category}**:")
for name, desc in category_tools:
if desc:
lines.append(f"- `{name}`: {desc}")
else:
lines.append(f"- `{name}`")
del tool_map[name] # remove tools already listed
lines.append("")
# Add remaining uncategorized tools
if tool_map:
lines.append("**其他工具**:")
for name, desc in sorted(tool_map.items()):
if desc:
lines.append(f"- `{name}`: {desc}")
else:
lines.append(f"- `{name}`")
lines.append("")
# Tool usage guidance
lines.extend([
"### 工具调用风格",
"",
"**默认规则**: 对于常规、低风险的工具调用,无需叙述,直接调用即可。",
"",
"**需要叙述的情况**:",
"- 多步骤、复杂的任务",
"- 敏感操作(如删除文件)",
"- 用户明确要求解释过程",
"",
"**完成后**: 工具调用完成后,给用户一个简短、自然的确认或回复,不要直接结束对话。",
"",
])
return lines
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
"""构建技能系统section"""
if not skill_manager:
return []
# Resolve the read tool's name
read_tool_name = "read"
if tools:
for tool in tools:
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
if tool_name.lower() == "read":
read_tool_name = tool_name
break
lines = [
"## 技能系统",
"",
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
"",
f"- 如果恰好有一个技能明确适用:使用 `{read_tool_name}` 工具读取其 <location> 路径下的 SKILL.md 文件,然后遵循它。",
"- 如果多个技能都适用:选择最具体的一个,然后读取并遵循。",
"- 如果没有明确适用的:不要读取任何 SKILL.md。",
"",
"**约束**: 永远不要一次性读取多个技能;只在选择后再读取。",
"",
]
# Append the skill list (fetched via skill_manager)
try:
skills_prompt = skill_manager.build_skills_prompt()
if skills_prompt:
lines.append(skills_prompt.strip())
lines.append("")
except Exception as e:
logger.warning(f"Failed to build skills prompt: {e}")
return lines
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
"""构建记忆系统section"""
if not memory_manager:
return []
# Check whether memory tools are available
has_memory_tools = False
if tools:
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
if not has_memory_tools:
return []
lines = [
"## 记忆系统",
"",
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
"",
"1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
"2. 然后使用 `memory_get` 只拉取需要的行",
"3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
"",
"**记忆文件结构**:",
"- `MEMORY.md`: 长期记忆,包含重要的背景信息",
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
"",
"**使用原则**:",
"- 自然使用记忆,就像你本来就知道",
"- 不要主动提起或列举记忆,除非用户明确询问",
"",
]
return lines
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
"""构建用户身份section"""
if not user_identity:
return []
lines = [
"## 用户身份",
"",
]
if user_identity.get("name"):
lines.append(f"**用户姓名**: {user_identity['name']}")
if user_identity.get("nickname"):
lines.append(f"**称呼**: {user_identity['nickname']}")
if user_identity.get("timezone"):
lines.append(f"**时区**: {user_identity['timezone']}")
if user_identity.get("notes"):
lines.append(f"**备注**: {user_identity['notes']}")
lines.append("")
return lines
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
"""构建文档路径section - 已移除,不再需要"""
# 不再生成文档section
return []
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
"""构建工作空间section"""
lines = [
"## 工作空间",
"",
f"你的工作目录是: `{workspace_dir}`",
"",
"除非用户明确指示,否则将此目录视为文件操作的全局工作空间。",
"",
"**重要说明 - 文件已自动加载**:",
"",
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**",
"",
"- ✅ `SOUL.md`: 已加载 - Agent的人格设定",
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
"- ✅ `AGENTS.md`: 已加载 - 工作空间使用指南",
"",
"**首次对话**:",
"",
"如果这是你与用户的首次对话,并且你的人格设定和用户信息还是空白或初始状态:",
"",
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
"2. **简短打招呼后,分点询问三个核心问题**",
" - 你希望我叫什么名字?",
" - 你希望我怎么称呼你?",
" - 你希望我们是什么样的交流风格?(这里需要举例,如:专业严谨、轻松幽默、温暖友好等)",
"3. **语言风格**:温暖但不过度诗意,带点科技感,保持清晰",
"4. **问题格式**:用分点或换行,让问题清晰易读;前两个问题不需要额外说明,只有交流风格需要举例",
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 SOUL.md",
"",
"**重要**: ",
"- 在所有对话中,无需提及技术细节(如 SOUL.md、USER.md 等文件名,工具名称,配置等),除非用户明确询问。用自然表达如「我已记住」而非「已更新 SOUL.md」",
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
"- 保持简洁,避免过度抒情",
"",
]
return lines
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
"""构建项目上下文文件section"""
if not context_files:
return []
# Check whether SOUL.md is present
has_soul = any(
f.path.lower().endswith('soul.md') or 'soul.md' in f.path.lower()
for f in context_files
)
lines = [
"# 项目上下文",
"",
"以下项目上下文文件已被加载:",
"",
]
if has_soul:
lines.append("如果存在 `SOUL.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
lines.append("")
# Append each file's content
for file in context_files:
lines.append(f"## {file.path}")
lines.append("")
lines.append(file.content)
lines.append("")
return lines
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
"""构建运行时信息section"""
if not runtime_info:
return []
# Only include if there's actual runtime info to display
runtime_parts = []
if runtime_info.get("model"):
runtime_parts.append(f"模型={runtime_info['model']}")
if runtime_info.get("workspace"):
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
# Only add channel if it's not the default "web"
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
runtime_parts.append(f"渠道={runtime_info['channel']}")
if not runtime_parts:
return []
lines = [
"## 运行时信息",
"",
"运行时: " + " | ".join(runtime_parts),
""
]
return lines
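
A minimal sketch of calling the builder, assuming the repository (including common.log) is importable; the workspace path and SOUL.md content are placeholders.

```python
from agent.prompt.builder import ContextFile, build_agent_system_prompt

prompt = build_agent_system_prompt(
    workspace_dir="/tmp/cow-demo",
    language="zh",
    context_files=[ContextFile(path="SOUL.md", content="名字: 小助手\n性格: 友好")],
    runtime_info={"model": "gpt-4o"},
)
print(prompt[:300])
```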

agent/prompt/workspace.py Normal file

@@ -0,0 +1,314 @@
"""
Workspace Management
Initializes the workspace, creates template files, and loads context files
"""
import os
from typing import List, Optional, Dict
from dataclasses import dataclass
from common.log import logger
from .builder import ContextFile
# Default filename constants
DEFAULT_SOUL_FILENAME = "SOUL.md"
DEFAULT_USER_FILENAME = "USER.md"
DEFAULT_AGENTS_FILENAME = "AGENTS.md"
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
@dataclass
class WorkspaceFiles:
"""工作空间文件路径"""
soul_path: str
user_path: str
agents_path: str
memory_path: str
memory_dir: str
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
"""
确保工作空间存在,并创建必要的模板文件
Args:
workspace_dir: 工作空间目录路径
create_templates: 是否创建模板文件(首次运行时)
Returns:
WorkspaceFiles对象包含所有文件路径
"""
# Ensure the directory exists
os.makedirs(workspace_dir, exist_ok=True)
# Define file paths
soul_path = os.path.join(workspace_dir, DEFAULT_SOUL_FILENAME)
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
agents_path = os.path.join(workspace_dir, DEFAULT_AGENTS_FILENAME)
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME)  # MEMORY.md lives in the workspace root
memory_dir = os.path.join(workspace_dir, "memory")  # subdirectory for daily memories
# Create the memory subdirectory
os.makedirs(memory_dir, exist_ok=True)
# Create the template files if needed
if create_templates:
_create_template_if_missing(soul_path, _get_soul_template())
_create_template_if_missing(user_path, _get_user_template())
_create_template_if_missing(agents_path, _get_agents_template())
_create_template_if_missing(memory_path, _get_memory_template())
logger.info(f"[Workspace] Initialized workspace at: {workspace_dir}")
return WorkspaceFiles(
soul_path=soul_path,
user_path=user_path,
agents_path=agents_path,
memory_path=memory_path,
memory_dir=memory_dir
)
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
"""
加载工作空间的上下文文件
Args:
workspace_dir: 工作空间目录
files_to_load: 要加载的文件列表相对路径如果为None则加载所有标准文件
Returns:
ContextFile对象列表
"""
if files_to_load is None:
# Files loaded by default (in priority order)
files_to_load = [
DEFAULT_SOUL_FILENAME,
DEFAULT_USER_FILENAME,
DEFAULT_AGENTS_FILENAME,
]
context_files = []
for filename in files_to_load:
filepath = os.path.join(workspace_dir, filename)
if not os.path.exists(filepath):
continue
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read().strip()
# Skip empty files and files that only contain template placeholders
if not content or _is_template_placeholder(content):
continue
context_files.append(ContextFile(
path=filename,
content=content
))
logger.debug(f"[Workspace] Loaded context file: {filename}")
except Exception as e:
logger.warning(f"[Workspace] Failed to load {filename}: {e}")
return context_files
def _create_template_if_missing(filepath: str, template_content: str):
"""如果文件不存在,创建模板文件"""
if not os.path.exists(filepath):
try:
with open(filepath, 'w', encoding='utf-8') as f:
f.write(template_content)
logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
except Exception as e:
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
def _is_template_placeholder(content: str) -> bool:
"""检查内容是否为模板占位符"""
# 常见的占位符模式
placeholders = [
"*(填写",
"*(在首次对话时填写",
"*(可选)",
"*(根据需要添加",
]
lines = content.split('\n')
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
# No real content (only headings and placeholder lines)
if len(non_empty_lines) <= 3:
for placeholder in placeholders:
if any(placeholder in line for line in non_empty_lines):
return True
return False
# ============= Template content =============
def _get_soul_template() -> str:
"""Agent人格设定模板"""
return """# SOUL.md - 我是谁?
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
## 基本信息
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
- **角色**: *(AI助理、智能管家、技术顾问等)*
- **性格**: *(友好、专业、幽默、严谨等)*
## 交流风格
*(描述你如何与用户交流:)*
- 使用什么样的语言风格?(正式/轻松/幽默)
- 回复长度偏好?(简洁/详细)
- 是否使用表情符号?
## 核心能力
*(你擅长什么?)*
- 文件管理和代码编辑
- 网络搜索和信息查询
- 记忆管理和上下文理解
- 任务规划和执行
## 行为准则
*(你遵循的基本原则:)*
1. 始终在执行破坏性操作前确认
2. 优先使用工具而不是猜测
3. 主动记录重要信息到记忆文件
4. 定期整理和总结对话内容
---
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
"""
def _get_user_template() -> str:
"""用户身份信息模板"""
return """# USER.md - 用户基本信息
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
## 基本信息
- **姓名**: *(在首次对话时询问)*
- **称呼**: *(用户希望被如何称呼)*
- **职业**: *(可选)*
- **时区**: *(例如: Asia/Shanghai)*
## 联系方式
- **微信**:
- **邮箱**:
- **其他**:
## 重要日期
- **生日**:
- **纪念日**:
---
**注意**: 这个文件存放静态的身份信息
"""
def _get_agents_template() -> str:
"""工作空间指南模板"""
return """# AGENTS.md - 工作空间指南
这个文件夹是你的家。好好对待它。
## 系统自动加载
以下文件在每次会话启动时**已经自动加载**到系统提示词中,你无需再次读取:
- ✅ `SOUL.md` - 你的人格设定(已加载)
- ✅ `USER.md` - 用户信息(已加载)
- ✅ `AGENTS.md` - 本文件(已加载)
## 按需读取
以下文件**不会自动加载**,需要时使用相应工具读取:
- 📝 `memory/YYYY-MM-DD.md` - 每日记忆(用 memory_search 检索)
- 🧠 `MEMORY.md` - 长期记忆(用 memory_search 检索)
## 记忆系统
你每次会话都是全新的。这些文件是你的连续性:
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
- 原始的对话日志
- 记录当天发生的事情
- 如果 `memory/` 目录不存在,创建它
### 🧠 长期记忆:`MEMORY.md`
- 你精选的记忆,就像人类的长期记忆
- **仅在主会话中加载**(与用户的直接聊天)
- **不要在共享上下文中加载**(群聊、与其他人的会话)
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
- 记录重要事件、想法、决定、观点、经验教训
- 这是你精选的记忆 - 精华,而不是原始日志
- 用 `edit` 工具追加新的记忆内容
### 📝 写下来 - 不要"记在心里"
- **记忆是有限的** - 如果你想记住某事,写入文件
- "记在心里"不会在会话重启后保留,文件才会
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
- 当你学到教训 → 更新 AGENTS.md 或相关技能
- 当你犯错 → 记录下来,这样未来的你不会重复
- **文字 > 大脑** 📝
### 存储规则
当用户分享信息时,根据类型选择存储位置:
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
**重要**:
- 爱好(唱歌、篮球等)→ MEMORY.md不是 USER.md
- 近期计划(下周要做什么)→ MEMORY.md不是 USER.md
- USER.md 只存放不会变的基本信息
## 安全
- 永远不要泄露私人数据
- 不要在未经询问的情况下运行破坏性命令
- 当有疑问时,先问
## 工具使用
技能提供你的工具。当你需要一个时,查看它的 `SKILL.md`。
## 让它成为你的
这只是一个起点。随着你弄清楚什么有效,添加你自己的约定、风格和规则。
"""
def _get_memory_template() -> str:
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
return """# MEMORY.md - 长期记忆
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
---
"""

20
agent/protocol/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
from .agent import Agent
from .agent_stream import AgentStreamExecutor
from .task import Task, TaskType, TaskStatus
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
from .models import LLMModel, LLMRequest, ModelFactory
__all__ = [
'Agent',
'AgentStreamExecutor',
'Task',
'TaskType',
'TaskStatus',
'AgentResult',
'AgentAction',
'AgentActionType',
'ToolResult',
'LLMModel',
'LLMRequest',
'ModelFactory'
]

371
agent/protocol/agent.py Normal file
View File

@@ -0,0 +1,371 @@
import json
import time
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.agent_stream import AgentStreamExecutor
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
from agent.tools.base_tool import BaseTool, ToolStage
class Agent:
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
context_reserve_tokens=None, memory_manager=None, name: str = None,
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
"""
Initialize the Agent with system prompt, model, description.
:param system_prompt: The system prompt for the agent.
:param description: A description of the agent.
:param model: An instance of LLMModel to be used by the agent.
:param tools: Optional list of tools for the agent to use.
:param output_mode: Control how execution progress is displayed:
"print" for console output or "logger" for using logger
:param max_steps: Maximum number of steps the agent can take (default: 100)
:param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
:param memory_manager: Optional MemoryManager instance for memory operations
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
:param workspace_dir: Optional workspace directory for workspace-specific skills
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
:param enable_skills: Whether to enable skills support (default: True)
"""
self.name = name or "Agent"
self.system_prompt = system_prompt
self.model: LLMModel = model # Instance of LLMModel
self.description = description
self.tools: list = []
self.max_steps = max_steps # max tool-call steps, default 100
self.max_context_tokens = max_context_tokens # max tokens in context
self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests
self.captured_actions = [] # Initialize captured actions list
self.output_mode = output_mode
self.last_usage = None # Store last API response usage info
self.messages = [] # Unified message history for stream mode
self.memory_manager = memory_manager # Memory manager for auto memory flush
self.workspace_dir = workspace_dir # Workspace directory
self.enable_skills = enable_skills # Skills enabled flag
# Initialize skill manager
self.skill_manager = None
if enable_skills:
if skill_manager:
self.skill_manager = skill_manager
else:
# Auto-create skill manager
try:
from agent.skills import SkillManager
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
except Exception as e:
logger.warning(f"Failed to initialize SkillManager: {e}")
if tools:
for tool in tools:
self.add_tool(tool)
def add_tool(self, tool: BaseTool):
"""
Add a tool to the agent.
:param tool: The tool instance to add
"""
# Bind the agent's model to the tool before registering it
tool.model = self.model
self.tools.append(tool)
def get_skills_prompt(self, skill_filter=None) -> str:
"""
Get the skills prompt to append to system prompt.
:param skill_filter: Optional list of skill names to include
:return: Formatted skills prompt or empty string
"""
if not self.skill_manager:
return ""
try:
return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
except Exception as e:
logger.warning(f"Failed to build skills prompt: {e}")
return ""
def get_full_system_prompt(self, skill_filter=None) -> str:
"""
Get the full system prompt including skills.
:param skill_filter: Optional list of skill names to include
:return: Complete system prompt with skills appended
"""
base_prompt = self.system_prompt
skills_prompt = self.get_skills_prompt(skill_filter=skill_filter)
if skills_prompt:
return base_prompt + "\n" + skills_prompt
return base_prompt
def refresh_skills(self):
"""Refresh the loaded skills."""
if self.skill_manager:
self.skill_manager.refresh_skills()
logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
def list_skills(self):
"""
List all loaded skills.
:return: List of skill entries or empty list
"""
if not self.skill_manager:
return []
return self.skill_manager.list_skills()
def _get_model_context_window(self) -> int:
"""
Get the model's context window size in tokens.
Auto-detect based on model name.
Model context windows:
- Claude 3.5/3.7 Sonnet: 200K tokens
- Claude 3 Opus: 200K tokens
- GPT-4 Turbo/128K: 128K tokens
- GPT-4: 8K-32K tokens
- GPT-3.5: 16K tokens
- DeepSeek: 64K tokens
:return: Context window size in tokens
"""
if self.model and hasattr(self.model, 'model'):
model_name = self.model.model.lower()
# Claude models - 200K context
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
return 200000
# GPT-4 models
elif 'gpt-4' in model_name:
if 'turbo' in model_name or '128k' in model_name:
return 128000
elif '32k' in model_name:
return 32000
else:
return 8000
# GPT-3.5
elif 'gpt-3.5' in model_name:
if '16k' in model_name:
return 16000
else:
return 4000
# DeepSeek
elif 'deepseek' in model_name:
return 64000
# Gemini models
elif 'gemini' in model_name:
if '2.0' in model_name or 'exp' in model_name:
return 2000000 # Gemini 2.0: 2M tokens
else:
return 1000000 # Gemini 1.5: 1M tokens
# Default conservative value
return 128000
def _get_context_reserve_tokens(self) -> int:
"""
Get the number of tokens to reserve for new requests.
This prevents context overflow by keeping a buffer.
:return: Number of tokens to reserve
"""
if self.context_reserve_tokens is not None:
return self.context_reserve_tokens
# Reserve ~10% of context window, with min 10K and max 200K
context_window = self._get_model_context_window()
reserve = int(context_window * 0.1)
return max(10000, min(200000, reserve))
def _estimate_message_tokens(self, message: dict) -> int:
"""
Estimate token count for a message using chars/4 heuristic.
This is a conservative estimate (tends to overestimate).
:param message: Message dict with 'role' and 'content'
:return: Estimated token count
"""
content = message.get('content', '')
if isinstance(content, str):
return max(1, len(content) // 4)
elif isinstance(content, list):
# Handle multi-part content (text + images)
total_chars = 0
for part in content:
if isinstance(part, dict) and part.get('type') == 'text':
total_chars += len(part.get('text', ''))
elif isinstance(part, dict) and part.get('type') == 'image':
# Estimate images as ~1200 tokens
total_chars += 4800
return max(1, total_chars // 4)
return 1
def _find_tool(self, tool_name: str):
"""Find and return a tool with the specified name"""
for tool in self.tools:
if tool.name == tool_name:
# Only pre-process stage tools can be actively called
if tool.stage == ToolStage.PRE_PROCESS:
tool.model = self.model
tool.context = self # Set tool context
return tool
else:
# If it's a post-process tool, return None to prevent direct calling
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
return None
return None
# output function based on mode
def output(self, message="", end="\n"):
if self.output_mode == "print":
print(message, end=end)
elif message:
logger.info(message)
def _execute_post_process_tools(self):
"""Execute all post-process stage tools"""
# Get all post-process stage tools
post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
# Execute each tool
for tool in post_process_tools:
# Set tool context
tool.context = self
# Record start time for execution timing
start_time = time.time()
# Execute tool (with empty parameters, tool will extract needed info from context)
result = tool.execute({})
# Calculate execution time
execution_time = time.time() - start_time
# Capture tool use for tracking
self.capture_tool_use(
tool_name=tool.name,
input_params={}, # Post-process tools typically don't take parameters
output=result.result,
status=result.status,
error_message=str(result.result) if result.status == "error" else None,
execution_time=execution_time
)
# Log result
if result.status == "success":
# Print tool execution result in the desired format
self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
else:
# Print failure in print mode
self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
execution_time=0.0):
"""
Capture a tool use action.
:param thought: thought content
:param tool_name: Name of the tool used
:param input_params: Parameters passed to the tool
:param output: Output from the tool
:param status: Status of the tool execution
:param error_message: Error message if the tool execution failed
:param execution_time: Time taken to execute the tool
"""
tool_result = ToolResult(
tool_name=tool_name,
input_params=input_params,
output=output,
status=status,
error_message=error_message,
execution_time=execution_time
)
action = AgentAction(
agent_id=self.id if hasattr(self, 'id') else str(id(self)),
agent_name=self.name,
action_type=AgentActionType.TOOL_USE,
tool_result=tool_result,
thought=thought
)
self.captured_actions.append(action)
return action
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
"""
Execute single agent task with streaming (based on tool-call)
This method supports:
- Streaming output
- Multi-turn reasoning based on tool-call
- Event callbacks
- Persistent conversation history across calls
Args:
user_message: User message
on_event: Event callback function callback(event: dict)
event = {"type": str, "timestamp": float, "data": dict}
clear_history: If True, clear conversation history before this call (default: False)
skill_filter: Optional list of skill names to include in this run
Returns:
Final response text
Example:
# Multi-turn conversation with memory
response1 = agent.run_stream("My name is Alice")
response2 = agent.run_stream("What's my name?") # Will remember Alice
# Single-turn without memory
response = agent.run_stream("Hello", clear_history=True)
"""
# Clear history if requested
if clear_history:
self.messages = []
# Get model to use
if not self.model:
raise ValueError("No model available for agent")
# Get full system prompt with skills
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
# Create stream executor with agent's message history
executor = AgentStreamExecutor(
agent=self,
model=self.model,
system_prompt=full_system_prompt,
tools=self.tools,
max_turns=self.max_steps,
on_event=on_event,
messages=self.messages # Pass agent's message history
)
# Execute
response = executor.run_stream(user_message)
# Update agent's message history from executor
self.messages = executor.messages
# Execute all post-process tools
self._execute_post_process_tools()
return response
def clear_history(self):
"""Clear conversation history and captured actions"""
self.messages = []
self.captured_actions = []
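# A hedged usage sketch: a minimal on_event callback for run_stream(). The
# event shapes match the _emit_event payloads in agent_stream.py; the model
# object is assumed to be any working LLMModel subclass.
def _demo_print_stream(event: dict) -> None:
    # Print text deltas as they arrive; ignore all other event types.
    if event.get("type") == "message_update":
        print(event["data"].get("delta", ""), end="", flush=True)

# Example wiring (names illustrative):
#   agent = Agent(system_prompt="You are helpful.", model=my_model)
#   reply = agent.run_stream("Hello", on_event=_demo_print_stream)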

478
agent/protocol/agent_stream.py Normal file
View File

@@ -0,0 +1,478 @@
"""
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
Provides streaming output, event system, and complete tool-call loop
"""
import json
import time
from typing import List, Dict, Any, Optional, Callable
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult
class AgentStreamExecutor:
"""
Agent Stream Executor
Handles multi-turn reasoning loop based on tool-call:
1. LLM generates response (may include tool calls)
2. Execute tools
3. Return results to LLM
4. Repeat until no more tool calls
"""
def __init__(
self,
agent, # Agent instance
model: LLMModel,
system_prompt: str,
tools: List[BaseTool],
max_turns: int = 50,
on_event: Optional[Callable] = None,
messages: Optional[List[Dict]] = None
):
"""
Initialize stream executor
Args:
agent: Agent instance (for accessing context)
model: LLM model
system_prompt: System prompt
tools: List of available tools
max_turns: Maximum number of turns
on_event: Event callback function
messages: Optional existing message history (for persistent conversations)
"""
self.agent = agent
self.model = model
self.system_prompt = system_prompt
# Convert tools list to dict
self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
self.max_turns = max_turns
self.on_event = on_event
# Message history - use provided messages or create new list
self.messages = messages if messages is not None else []
def _emit_event(self, event_type: str, data: dict = None):
"""Emit event"""
if self.on_event:
try:
self.on_event({
"type": event_type,
"timestamp": time.time(),
"data": data or {}
})
except Exception as e:
logger.error(f"Event callback error: {e}")
def run_stream(self, user_message: str) -> str:
"""
Execute streaming reasoning loop
Args:
user_message: User message
Returns:
Final response text
"""
# Log user message
logger.info(f"\n{'='*50}")
logger.info(f"👤 用户: {user_message}")
logger.info(f"{'='*50}")
# Add user message (Claude format - use content blocks for consistency)
self.messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": user_message
}
]
})
self._emit_event("agent_start")
final_response = ""
turn = 0
try:
while turn < self.max_turns:
turn += 1
logger.info(f"\n{'='*50}{turn}{'='*50}")
self._emit_event("turn_start", {"turn": turn})
# Check if memory flush is needed (before calling LLM)
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
usage = self.agent.last_usage
if usage and 'input_tokens' in usage:
current_tokens = usage.get('input_tokens', 0)
context_window = self.agent._get_model_context_window()
# Use configured reserve_tokens or calculate based on context window
reserve_tokens = self.agent._get_context_reserve_tokens()
# Use a soft threshold so the flush triggers slightly before the hard limit
soft_threshold = 10000  # trigger the flush 10K tokens before the limit
if self.agent.memory_manager.should_flush_memory(
current_tokens=current_tokens,
context_window=context_window,
reserve_tokens=reserve_tokens,
soft_threshold=soft_threshold
):
self._emit_event("memory_flush_start", {
"current_tokens": current_tokens,
"threshold": context_window - reserve_tokens - soft_threshold
})
# TODO: Execute memory flush in background
# This would require async support
logger.info(f"Memory flush recommended at {current_tokens} tokens")
# Call LLM
assistant_msg, tool_calls = self._call_llm_stream()
final_response = assistant_msg
# No tool calls, end loop
if not tool_calls:
if assistant_msg:
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
logger.info(f"✅ 完成 (无工具调用)")
self._emit_event("turn_end", {
"turn": turn,
"has_tool_calls": False
})
break
# Log tool calls in compact format
tool_names = [tc['name'] for tc in tool_calls]
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
# Execute tools
tool_results = []
tool_result_blocks = []
for tool_call in tool_calls:
result = self._execute_tool(tool_call)
tool_results.append(result)
# Log tool result in compact format
status_emoji = "" if result.get("status") == "success" else ""
result_data = result.get('result', '')
# Format result string with proper Chinese character support
if isinstance(result_data, (dict, list)):
result_str = json.dumps(result_data, ensure_ascii=False)
else:
result_str = str(result_data)
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
# Build tool result block (Claude format)
# Content should be a string representation of the result
result_content = json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
tool_result_blocks.append({
"type": "tool_result",
"tool_use_id": tool_call["id"],
"content": result_content
})
# Add tool results to message history as user message (Claude format)
self.messages.append({
"role": "user",
"content": tool_result_blocks
})
self._emit_event("turn_end", {
"turn": turn,
"has_tool_calls": True,
"tool_count": len(tool_calls)
})
if turn >= self.max_turns:
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
except Exception as e:
logger.error(f"❌ Agent执行错误: {e}")
self._emit_event("error", {"error": str(e)})
raise
finally:
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
self._emit_event("agent_end", {"final_response": final_response})
return final_response
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
"""
Call LLM with streaming
Returns:
(response_text, tool_calls)
"""
# Trim messages if needed (using agent's context management)
self._trim_messages()
# Prepare messages
messages = self._prepare_messages()
# Debug: log message structure
logger.debug(f"Sending {len(messages)} messages to LLM")
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")
content = msg.get("content", "")
if isinstance(content, list):
content_types = [c.get("type") for c in content if isinstance(c, dict)]
logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
else:
logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")
# Prepare tool definitions (OpenAI/Claude format)
tools_schema = None
if self.tools:
tools_schema = []
for tool in self.tools.values():
tools_schema.append({
"name": tool.name,
"description": tool.description,
"input_schema": tool.params # Claude uses input_schema
})
# Create request
request = LLMRequest(
messages=messages,
temperature=0,
stream=True,
tools=tools_schema,
system=self.system_prompt # Pass system prompt separately for Claude API
)
self._emit_event("message_start", {"role": "assistant"})
# Streaming response
full_content = ""
tool_calls_buffer = {} # {index: {id, name, arguments}}
try:
stream = self.model.call_stream(request)
for chunk in stream:
# Check for errors
if isinstance(chunk, dict) and chunk.get("error"):
error_msg = chunk.get("message", "Unknown error")
status_code = chunk.get("status_code", "N/A")
logger.error(f"API Error: {error_msg} (Status: {status_code})")
logger.error(f"Full error chunk: {chunk}")
raise Exception(f"{error_msg} (Status: {status_code})")
# Parse chunk
if isinstance(chunk, dict) and "choices" in chunk:
choice = chunk["choices"][0]
delta = choice.get("delta", {})
# Handle text content
if "content" in delta and delta["content"]:
content_delta = delta["content"]
full_content += content_delta
self._emit_event("message_update", {"delta": content_delta})
# Handle tool calls
if "tool_calls" in delta:
for tc_delta in delta["tool_calls"]:
index = tc_delta.get("index", 0)
if index not in tool_calls_buffer:
tool_calls_buffer[index] = {
"id": "",
"name": "",
"arguments": ""
}
if "id" in tc_delta:
tool_calls_buffer[index]["id"] = tc_delta["id"]
if "function" in tc_delta:
func = tc_delta["function"]
if "name" in func:
tool_calls_buffer[index]["name"] = func["name"]
if "arguments" in func:
tool_calls_buffer[index]["arguments"] += func["arguments"]
except Exception as e:
logger.error(f"LLM call error: {e}")
raise
# Parse tool calls
tool_calls = []
for idx in sorted(tool_calls_buffer.keys()):
tc = tool_calls_buffer[idx]
try:
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
except json.JSONDecodeError as e:
logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
arguments = {}
tool_calls.append({
"id": tc["id"],
"name": tc["name"],
"arguments": arguments
})
# Add assistant message to history (Claude format uses content blocks)
assistant_msg = {"role": "assistant", "content": []}
# Add text content block if present
if full_content:
assistant_msg["content"].append({
"type": "text",
"text": full_content
})
# Add tool_use blocks if present
if tool_calls:
for tc in tool_calls:
assistant_msg["content"].append({
"type": "tool_use",
"id": tc["id"],
"name": tc["name"],
"input": tc["arguments"]
})
# Only append if content is not empty
if assistant_msg["content"]:
self.messages.append(assistant_msg)
self._emit_event("message_end", {
"content": full_content,
"tool_calls": tool_calls
})
return full_content, tool_calls
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
"""
Execute tool
Args:
tool_call: {"id": str, "name": str, "arguments": dict}
Returns:
Tool execution result
"""
tool_name = tool_call["name"]
tool_id = tool_call["id"]
arguments = tool_call["arguments"]
self._emit_event("tool_execution_start", {
"tool_call_id": tool_id,
"tool_name": tool_name,
"arguments": arguments
})
try:
tool = self.tools.get(tool_name)
if not tool:
raise ValueError(f"Tool '{tool_name}' not found")
# Set tool context
tool.model = self.model
tool.context = self.agent
# Execute tool
start_time = time.time()
result: ToolResult = tool.execute_tool(arguments)
execution_time = time.time() - start_time
result_dict = {
"status": result.status,
"result": result.result,
"execution_time": execution_time
}
# Auto-refresh skills after skill creation
if tool_name == "bash" and result.status == "success":
command = arguments.get("command", "")
if "init_skill.py" in command and self.agent.skill_manager:
logger.info("🔄 Detected skill creation, refreshing skills...")
self.agent.refresh_skills()
logger.info(f"✅ Skills refreshed! Now have {len(self.agent.skill_manager.skills)} skills")
self._emit_event("tool_execution_end", {
"tool_call_id": tool_id,
"tool_name": tool_name,
**result_dict
})
return result_dict
except Exception as e:
logger.error(f"Tool execution error: {e}")
error_result = {
"status": "error",
"result": str(e),
"execution_time": 0
}
self._emit_event("tool_execution_end", {
"tool_call_id": tool_id,
"tool_name": tool_name,
**error_result
})
return error_result
def _trim_messages(self):
"""
Trim message history to stay within context limits.
Uses agent's context management configuration.
"""
if not self.messages or not self.agent:
return
# Get context window and reserve tokens from agent
context_window = self.agent._get_model_context_window()
reserve_tokens = self.agent._get_context_reserve_tokens()
max_tokens = context_window - reserve_tokens
# Estimate current tokens
current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
# Add system prompt tokens
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
current_tokens += system_tokens
# If under limit, no need to trim
if current_tokens <= max_tokens:
return
# Keep messages from newest, accumulating tokens
available_tokens = max_tokens - system_tokens
kept_messages = []
accumulated_tokens = 0
for msg in reversed(self.messages):
msg_tokens = self.agent._estimate_message_tokens(msg)
if accumulated_tokens + msg_tokens <= available_tokens:
kept_messages.insert(0, msg)
accumulated_tokens += msg_tokens
else:
break
old_count = len(self.messages)
self.messages = kept_messages
new_count = len(self.messages)
if old_count > new_count:
logger.info(
f"Context trimmed: {old_count} -> {new_count} messages "
f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
f"limit: {max_tokens})"
)
def _prepare_messages(self) -> List[Dict[str, Any]]:
"""
Prepare messages to send to LLM
Note: For Claude API, system prompt should be passed separately via system parameter,
not as a message. The AgentLLMModel will handle this.
"""
# Don't add system message here - it will be handled separately by the LLM adapter
return self.messages
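# A hedged worked example of the trimming budget computed in _trim_messages();
# the 200K-token window is illustrative of a Claude-class model.
def _demo_trim_budget() -> int:
    context_window = 200_000
    reserve = max(10_000, min(200_000, int(context_window * 0.1)))  # 20_000
    max_tokens = context_window - reserve                           # 180_000
    system_tokens = 700 // 4  # a ~700-char system prompt under the chars/4 heuristic
    return max_tokens - system_tokens  # remaining token budget for message history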

27
agent/protocol/context.py Normal file
View File

@@ -0,0 +1,27 @@
class TeamContext:
def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
"""
Initialize the TeamContext with a name, description, rules, and a list of agents.
:param name: The name of the group context.
:param description: A description of the group context.
:param rule: The rules governing the group context.
:param agents: A list of agents in the context.
:param max_steps: Maximum number of steps allowed for the team (default: 100).
"""
self.name = name
self.description = description
self.rule = rule
self.agents = agents
self.user_task = "" # For backward compatibility
self.task = None # Will be a Task instance
self.model = None # Will be an instance of LLMModel
self.task_short_name = None # Store the task directory name
# List of agents that have been executed
self.agent_outputs: list = []
self.current_steps = 0
self.max_steps = max_steps
class AgentOutput:
def __init__(self, agent_name: str, output: str):
self.agent_name = agent_name
self.output = output

57
agent/protocol/models.py Normal file
View File

@@ -0,0 +1,57 @@
"""
Models module for agent system.
Provides basic model classes needed by tools and bridge integration.
"""
from typing import Any, Dict, List, Optional
class LLMRequest:
"""Request model for LLM operations"""
def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
temperature: float = 0.7, max_tokens: Optional[int] = None,
stream: bool = False, tools: Optional[List] = None, **kwargs):
self.messages = messages or []
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
self.tools = tools
# Allow extra attributes
for key, value in kwargs.items():
setattr(self, key, value)
class LLMModel:
"""Base class for LLM models"""
def __init__(self, model: str = None, **kwargs):
self.model = model
self.config = kwargs
def call(self, request: LLMRequest):
"""
Call the model with a request.
This is a placeholder implementation.
"""
raise NotImplementedError("LLMModel.call not implemented in this context")
def call_stream(self, request: LLMRequest):
"""
Call the model with streaming.
This is a placeholder implementation.
"""
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
class ModelFactory:
"""Factory for creating model instances"""
@staticmethod
def create_model(model_type: str, **kwargs):
"""
Create a model instance based on type.
This is a placeholder implementation.
"""
raise NotImplementedError("ModelFactory.create_model not implemented in this context")

96
agent/protocol/result.py Normal file
View File

@@ -0,0 +1,96 @@
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional
from agent.protocol.task import Task, TaskStatus
class AgentActionType(Enum):
"""Enum representing different types of agent actions."""
TOOL_USE = "tool_use"
THINKING = "thinking"
FINAL_ANSWER = "final_answer"
@dataclass
class ToolResult:
"""
Represents the result of a tool use.
Attributes:
tool_name: Name of the tool used
input_params: Parameters passed to the tool
output: Output from the tool
status: Status of the tool execution (success/error)
error_message: Error message if the tool execution failed
execution_time: Time taken to execute the tool
"""
tool_name: str
input_params: Dict[str, Any]
output: Any
status: str
error_message: Optional[str] = None
execution_time: float = 0.0
@dataclass
class AgentAction:
"""
Represents an action taken by an agent.
Attributes:
id: Unique identifier for the action
agent_id: ID of the agent that performed the action
agent_name: Name of the agent that performed the action
action_type: Type of action (tool use, thinking, final answer)
content: Content of the action (thought content, final answer content)
tool_result: Tool use details if action_type is TOOL_USE
timestamp: When the action was performed
"""
agent_id: str
agent_name: str
action_type: AgentActionType
id: str = field(default_factory=lambda: str(uuid.uuid4()))
content: str = ""
tool_result: Optional[ToolResult] = None
thought: Optional[str] = None
timestamp: float = field(default_factory=time.time)
@dataclass
class AgentResult:
"""
Represents the result of an agent's execution.
Attributes:
final_answer: The final answer provided by the agent
step_count: Number of steps taken by the agent
status: Status of the execution (success/error)
error_message: Error message if execution failed
"""
final_answer: str
step_count: int
status: str = "success"
error_message: Optional[str] = None
@classmethod
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
"""Create a successful result"""
return cls(final_answer=final_answer, step_count=step_count)
@classmethod
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
"""Create an error result"""
return cls(
final_answer=f"Error: {error_message}",
step_count=step_count,
status="error",
error_message=error_message
)
@property
def is_error(self) -> bool:
"""Check if the result represents an error"""
return self.status == "error"

95
agent/protocol/task.py Normal file
View File

@@ -0,0 +1,95 @@
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List
class TaskType(Enum):
"""Enum representing different types of tasks."""
TEXT = "text"
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
FILE = "file"
MIXED = "mixed"
class TaskStatus(Enum):
"""Enum representing the status of a task."""
INIT = "init" # Initial state
PROCESSING = "processing" # In progress
COMPLETED = "completed" # Completed
FAILED = "failed" # Failed
@dataclass
class Task:
"""
Represents a task to be processed by an agent.
Attributes:
id: Unique identifier for the task
content: The primary text content of the task
type: Type of the task
status: Current status of the task
created_at: Timestamp when the task was created
updated_at: Timestamp when the task was last updated
metadata: Additional metadata for the task
images: List of image URLs or base64 encoded images
videos: List of video URLs
audios: List of audio URLs or base64 encoded audios
files: List of file URLs or paths
"""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
content: str = ""
type: TaskType = TaskType.TEXT
status: TaskStatus = TaskStatus.INIT
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
metadata: Dict[str, Any] = field(default_factory=dict)
# Media content
images: List[str] = field(default_factory=list)
videos: List[str] = field(default_factory=list)
audios: List[str] = field(default_factory=list)
files: List[str] = field(default_factory=list)
def __init__(self, content: str = "", **kwargs):
"""
Initialize a Task with content and optional keyword arguments.
Args:
content: The text content of the task
**kwargs: Additional attributes to set
"""
self.id = kwargs.get('id', str(uuid.uuid4()))
self.content = content
self.type = kwargs.get('type', TaskType.TEXT)
self.status = kwargs.get('status', TaskStatus.INIT)
self.created_at = kwargs.get('created_at', time.time())
self.updated_at = kwargs.get('updated_at', time.time())
self.metadata = kwargs.get('metadata', {})
self.images = kwargs.get('images', [])
self.videos = kwargs.get('videos', [])
self.audios = kwargs.get('audios', [])
self.files = kwargs.get('files', [])
def get_text(self) -> str:
"""
Get the text content of the task.
Returns:
The text content
"""
return self.content
def update_status(self, status: TaskStatus) -> None:
"""
Update the status of the task.
Args:
status: The new status
"""
self.status = status
self.updated_at = time.time()
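# A hedged usage sketch: create a mixed-media task and advance its status.
# The image URL is a made-up example.
def _demo_task() -> Task:
    task = Task("Describe this photo", type=TaskType.IMAGE,
                images=["https://example.com/a.png"])
    task.update_status(TaskStatus.PROCESSING)
    assert task.status == TaskStatus.PROCESSING
    assert task.updated_at >= task.created_at
    return task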

29
agent/skills/__init__.py Normal file
View File

@@ -0,0 +1,29 @@
"""
Skills module for agent system.
This module provides the framework for loading, managing, and executing skills.
Skills are markdown files with frontmatter that provide specialized instructions
for specific tasks.
"""
from agent.skills.types import (
Skill,
SkillEntry,
SkillMetadata,
SkillInstallSpec,
LoadSkillsResult,
)
from agent.skills.loader import SkillLoader
from agent.skills.manager import SkillManager
from agent.skills.formatter import format_skills_for_prompt
__all__ = [
"Skill",
"SkillEntry",
"SkillMetadata",
"SkillInstallSpec",
"LoadSkillsResult",
"SkillLoader",
"SkillManager",
"format_skills_for_prompt",
]

211
agent/skills/config.py Normal file
View File

@@ -0,0 +1,211 @@
"""
Configuration support for skills.
"""
import os
import platform
from typing import Dict, Optional, List
from agent.skills.types import SkillEntry
def resolve_runtime_platform() -> str:
"""Get the current runtime platform."""
return platform.system().lower()
def has_binary(bin_name: str) -> bool:
"""
Check if a binary is available in PATH.
:param bin_name: Binary name to check
:return: True if binary is available
"""
import shutil
return shutil.which(bin_name) is not None
def has_any_binary(bin_names: List[str]) -> bool:
"""
Check if any of the given binaries is available.
:param bin_names: List of binary names to check
:return: True if at least one binary is available
"""
return any(has_binary(bin_name) for bin_name in bin_names)
def has_env_var(env_name: str) -> bool:
"""
Check if an environment variable is set.
:param env_name: Environment variable name
:return: True if environment variable is set
"""
return env_name in os.environ and bool(os.environ[env_name].strip())
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
"""
Get skill-specific configuration.
:param config: Global configuration dictionary
:param skill_name: Name of the skill
:return: Skill configuration or None
"""
if not config:
return None
skills_config = config.get('skills', {})
if not isinstance(skills_config, dict):
return None
entries = skills_config.get('entries', {})
if not isinstance(entries, dict):
return None
return entries.get(skill_name)
def should_include_skill(
entry: SkillEntry,
config: Optional[Dict] = None,
current_platform: Optional[str] = None,
lenient: bool = True,
) -> bool:
"""
Determine if a skill should be included based on requirements.
Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode:
- In lenient mode (default): Only check explicit disable and platform, ignore missing requirements
- In strict mode: Check all requirements (binary, env vars, config)
:param entry: SkillEntry to check
:param config: Configuration dictionary
:param current_platform: Current platform (default: auto-detect)
:param lenient: If True, ignore missing requirements and load all skills (default: True)
:return: True if skill should be included
"""
metadata = entry.metadata
skill_name = entry.skill.name
skill_config = get_skill_config(config, skill_name)
# Always check if skill is explicitly disabled in config
if skill_config and skill_config.get('enabled') is False:
return False
if not metadata:
return True
# Always check platform requirements (can't work on wrong platform)
if metadata.os:
platform_name = current_platform or resolve_runtime_platform()
# Map common platform names
platform_map = {
'darwin': 'darwin',
'linux': 'linux',
'windows': 'win32',
}
normalized_platform = platform_map.get(platform_name, platform_name)
if normalized_platform not in metadata.os:
return False
# If skill has 'always: true', include it regardless of other requirements
if metadata.always:
return True
# In lenient mode, skip requirement checks and load all skills
# Skills will fail gracefully at runtime if requirements are missing
if lenient:
return True
# Strict mode: Check all requirements
if metadata.requires:
# Check required binaries (all must be present)
required_bins = metadata.requires.get('bins', [])
if required_bins:
if not all(has_binary(bin_name) for bin_name in required_bins):
return False
# Check anyBins (at least one must be present)
any_bins = metadata.requires.get('anyBins', [])
if any_bins:
if not has_any_binary(any_bins):
return False
# Check environment variables (with config fallback)
required_env = metadata.requires.get('env', [])
if required_env:
for env_name in required_env:
# Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv)
if has_env_var(env_name):
continue
if skill_config:
# Check skill config env dict
skill_env = skill_config.get('env', {})
if isinstance(skill_env, dict) and env_name in skill_env:
continue
# Check skill config apiKey (if this is the primaryEnv)
if metadata.primary_env == env_name and skill_config.get('apiKey'):
continue
# Requirement not satisfied
return False
# Check config paths
required_config = metadata.requires.get('config', [])
if required_config and config:
for config_path in required_config:
if not is_config_path_truthy(config, config_path):
return False
return True
def is_config_path_truthy(config: Dict, path: str) -> bool:
"""
Check if a config path resolves to a truthy value.
:param config: Configuration dictionary
:param path: Dot-separated path (e.g., 'skills.enabled')
:return: True if path resolves to truthy value
"""
parts = path.split('.')
current = config
for part in parts:
if not isinstance(current, dict):
return False
current = current.get(part)
if current is None:
return False
# Check if value is truthy
if isinstance(current, bool):
return current
if isinstance(current, (int, float)):
return current != 0
if isinstance(current, str):
return bool(current.strip())
return bool(current)
def resolve_config_path(config: Dict, path: str):
"""
Resolve a dot-separated config path to its value.
:param config: Configuration dictionary
:param path: Dot-separated path
:return: Value at path or None
"""
parts = path.split('.')
current = config
for part in parts:
if not isinstance(current, dict):
return None
current = current.get(part)
if current is None:
return None
return current
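# A hedged example of the config shape these helpers consume; skill names,
# keys, and values below are illustrative, not a documented schema.
_EXAMPLE_CONFIG = {
    "skills": {
        "entries": {
            "web-search": {"enabled": True, "env": {"SEARCH_API_KEY": "sk-..."}},
            "screenshot": {"enabled": False},  # an explicit disable always wins
        }
    }
}
# get_skill_config(_EXAMPLE_CONFIG, "screenshot") -> {"enabled": False}
# is_config_path_truthy(_EXAMPLE_CONFIG, "skills.entries.web-search.enabled") -> True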

62
agent/skills/formatter.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Skill formatter for generating prompts from skills.
"""
from typing import List
from agent.skills.types import Skill, SkillEntry
def format_skills_for_prompt(skills: List[Skill]) -> str:
"""
Format skills for inclusion in a system prompt.
Uses XML format per Agent Skills standard.
Skills with disable_model_invocation=True are excluded.
:param skills: List of skills to format
:return: Formatted prompt text
"""
# Filter out skills that should not be invoked by the model
visible_skills = [s for s in skills if not s.disable_model_invocation]
if not visible_skills:
return ""
lines = [
"\n\nThe following skills provide specialized instructions for specific tasks.",
"Use the read tool to load a skill's file when the task matches its description.",
"",
"<available_skills>",
]
for skill in visible_skills:
lines.append(" <skill>")
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
lines.append(" </skill>")
lines.append("</available_skills>")
return "\n".join(lines)
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
"""
Format skill entries for inclusion in a system prompt.
:param entries: List of skill entries to format
:return: Formatted prompt text
"""
skills = [entry.skill for entry in entries]
return format_skills_for_prompt(skills)
def _escape_xml(text: str) -> str:
"""Escape XML special characters."""
return (text
.replace('&', '&amp;')
.replace('<', '&lt;')
.replace('>', '&gt;')
.replace('"', '&quot;')
.replace("'", '&apos;'))

159
agent/skills/frontmatter.py Normal file
View File

@@ -0,0 +1,159 @@
"""
Frontmatter parsing for skills.
"""
import re
import json
from typing import Dict, Any, Optional, List
from agent.skills.types import SkillMetadata, SkillInstallSpec
def parse_frontmatter(content: str) -> Dict[str, Any]:
"""
Parse YAML-style frontmatter from markdown content.
Returns a dictionary of frontmatter fields.
"""
frontmatter = {}
# Match frontmatter block between --- markers
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
if not match:
return frontmatter
frontmatter_text = match.group(1)
# Simple YAML-like parsing (supports key: value format)
for line in frontmatter_text.split('\n'):
line = line.strip()
if not line or line.startswith('#'):
continue
if ':' in line:
key, value = line.split(':', 1)
key = key.strip()
value = value.strip()
# Try to parse as JSON if it looks like JSON
if value.startswith('{') or value.startswith('['):
try:
value = json.loads(value)
except json.JSONDecodeError:
pass
# Parse boolean values
elif value.lower() in ('true', 'false'):
value = value.lower() == 'true'
# Parse numbers
elif value.isdigit():
value = int(value)
frontmatter[key] = value
return frontmatter
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
"""
Parse skill metadata from frontmatter.
Looks for 'metadata' field containing JSON with skill configuration.
"""
metadata_raw = frontmatter.get('metadata')
if not metadata_raw:
return None
# If it's a string, try to parse as JSON
if isinstance(metadata_raw, str):
try:
metadata_raw = json.loads(metadata_raw)
except json.JSONDecodeError:
return None
if not isinstance(metadata_raw, dict):
return None
# Support both 'moltbot' and 'cow' keys for compatibility
meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow')
if not meta_obj or not isinstance(meta_obj, dict):
return None
# Parse install specs
install_specs = []
install_raw = meta_obj.get('install', [])
if isinstance(install_raw, list):
for spec_raw in install_raw:
if not isinstance(spec_raw, dict):
continue
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
if not kind:
continue
spec = SkillInstallSpec(
kind=kind,
id=spec_raw.get('id'),
label=spec_raw.get('label'),
bins=_normalize_string_list(spec_raw.get('bins')),
os=_normalize_string_list(spec_raw.get('os')),
formula=spec_raw.get('formula'),
package=spec_raw.get('package'),
module=spec_raw.get('module'),
url=spec_raw.get('url'),
archive=spec_raw.get('archive'),
extract=spec_raw.get('extract', False),
strip_components=spec_raw.get('stripComponents'),
target_dir=spec_raw.get('targetDir'),
)
install_specs.append(spec)
# Parse requires
requires = {}
requires_raw = meta_obj.get('requires', {})
if isinstance(requires_raw, dict):
for key, value in requires_raw.items():
requires[key] = _normalize_string_list(value)
return SkillMetadata(
always=meta_obj.get('always', False),
skill_key=meta_obj.get('skillKey'),
primary_env=meta_obj.get('primaryEnv'),
emoji=meta_obj.get('emoji'),
homepage=meta_obj.get('homepage'),
os=_normalize_string_list(meta_obj.get('os')),
requires=requires,
install=install_specs,
)
def _normalize_string_list(value: Any) -> List[str]:
"""Normalize a value to a list of strings."""
if not value:
return []
if isinstance(value, list):
return [str(v).strip() for v in value if v]
if isinstance(value, str):
return [v.strip() for v in value.split(',') if v.strip()]
return []
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
"""Parse a boolean value from frontmatter."""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes', 'on')
return default
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
"""Get a frontmatter value as a string."""
value = frontmatter.get(key)
return str(value) if value is not None else None
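# A hedged worked example of the parsers above on a minimal SKILL.md;
# the skill name and env var are made-up values.
_DEMO_SKILL_MD = """---
name: weather
description: Look up current weather
metadata: {"cow": {"requires": {"env": ["WEATHER_API_KEY"]}}}
---
Instructions body...
"""
# parse_frontmatter(_DEMO_SKILL_MD) -> {"name": "weather", "description": ...,
# "metadata": {"cow": ...}}, and parse_metadata(...) on that dict yields a
# SkillMetadata with requires == {"env": ["WEATHER_API_KEY"]}.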

242
agent/skills/loader.py Normal file
View File

@@ -0,0 +1,242 @@
"""
Skill loader for discovering and loading skills from directories.
"""
import os
from pathlib import Path
from typing import List, Optional, Dict
from common.log import logger
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
class SkillLoader:
"""Loads skills from various directories."""
def __init__(self, workspace_dir: Optional[str] = None):
"""
Initialize the skill loader.
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
"""
self.workspace_dir = workspace_dir
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
"""
Load skills from a directory.
Discovery rules:
- Direct .md files in the root directory
- Recursive SKILL.md files under subdirectories
:param dir_path: Directory path to scan
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
:return: LoadSkillsResult with skills and diagnostics
"""
skills = []
diagnostics = []
if not os.path.exists(dir_path):
diagnostics.append(f"Directory does not exist: {dir_path}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
if not os.path.isdir(dir_path):
diagnostics.append(f"Path is not a directory: {dir_path}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
# Load skills from root-level .md files and subdirectories
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
return result
def _load_skills_recursive(
self,
dir_path: str,
source: str,
include_root_files: bool = False
) -> LoadSkillsResult:
"""
Recursively load skills from a directory.
:param dir_path: Directory to scan
:param source: Source identifier
:param include_root_files: Whether to include root-level .md files
:return: LoadSkillsResult
"""
skills = []
diagnostics = []
try:
entries = os.listdir(dir_path)
except Exception as e:
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
for entry in entries:
# Skip hidden files and directories
if entry.startswith('.'):
continue
# Skip common non-skill directories
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
continue
full_path = os.path.join(dir_path, entry)
# Handle directories
if os.path.isdir(full_path):
# Recursively scan subdirectories
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
skills.extend(sub_result.skills)
diagnostics.extend(sub_result.diagnostics)
continue
# Handle files
if not os.path.isfile(full_path):
continue
# Check if this is a skill file
is_root_md = include_root_files and entry.endswith('.md')
is_skill_md = not include_root_files and entry == 'SKILL.md'
if not (is_root_md or is_skill_md):
continue
# Load the skill
skill_result = self._load_skill_from_file(full_path, source)
if skill_result.skills:
skills.extend(skill_result.skills)
diagnostics.extend(skill_result.diagnostics)
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
"""
Load a single skill from a markdown file.
:param file_path: Path to the skill markdown file
:param source: Source identifier
:return: LoadSkillsResult
"""
diagnostics = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception as e:
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
# Parse frontmatter
frontmatter = parse_frontmatter(content)
# Get skill name and description
skill_dir = os.path.dirname(file_path)
parent_dir_name = os.path.basename(skill_dir)
name = frontmatter.get('name', parent_dir_name)
description = frontmatter.get('description', '')
if not description or not description.strip():
diagnostics.append(f"Skill {name} has no description: {file_path}")
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
# Parse disable-model-invocation flag
disable_model_invocation = parse_boolean_value(
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
default=False
)
# Create skill object
skill = Skill(
name=name,
description=description,
file_path=file_path,
base_dir=skill_dir,
source=source,
content=content,
disable_model_invocation=disable_model_invocation,
frontmatter=frontmatter,
)
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
def load_all_skills(
self,
managed_dir: Optional[str] = None,
workspace_skills_dir: Optional[str] = None,
extra_dirs: Optional[List[str]] = None,
) -> Dict[str, SkillEntry]:
"""
Load skills from all configured locations with precedence.
Precedence (lowest to highest):
1. Extra directories
2. Managed skills directory
3. Workspace skills directory
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
:param extra_dirs: Additional directories to load skills from
:return: Dictionary mapping skill name to SkillEntry
"""
skill_map: Dict[str, SkillEntry] = {}
all_diagnostics = []
# Load from extra directories (lowest precedence)
if extra_dirs:
for extra_dir in extra_dirs:
if not os.path.exists(extra_dir):
continue
result = self.load_skills_from_dir(extra_dir, source='extra')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Load from managed directory
if managed_dir and os.path.exists(managed_dir):
result = self.load_skills_from_dir(managed_dir, source='managed')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Load from workspace directory (highest precedence)
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Log diagnostics
if all_diagnostics:
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
for diag in all_diagnostics[:5]: # Log first 5
logger.debug(f" - {diag}")
logger.info(f"Loaded {len(skill_map)} skills from all sources")
return skill_map
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
"""
Create a SkillEntry from a Skill with parsed metadata.
:param skill: The skill to create an entry for
:return: SkillEntry with metadata
"""
metadata = parse_metadata(skill.frontmatter)
# Parse user-invocable flag
user_invocable = parse_boolean_value(
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
default=True
)
return SkillEntry(
skill=skill,
metadata=metadata,
user_invocable=user_invocable,
)
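# A hedged usage sketch: discover skills, with workspace entries overriding
# managed ones of the same name (paths below are illustrative).
def _demo_load_skills() -> None:
    loader = SkillLoader(workspace_dir="./workspace")
    entries = loader.load_all_skills(
        managed_dir=os.path.expanduser("~/.cow/skills"),
        workspace_skills_dir="./workspace/skills",
    )
    for name, entry in entries.items():
        print(name, entry.skill.source)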

214
agent/skills/manager.py Normal file
View File

@@ -0,0 +1,214 @@
"""
Skill manager for managing skill lifecycle and operations.
"""
import os
from typing import Dict, List, Optional
from pathlib import Path
from common.log import logger
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
from agent.skills.loader import SkillLoader
from agent.skills.formatter import format_skill_entries_for_prompt
class SkillManager:
"""Manages skills for an agent."""
def __init__(
self,
workspace_dir: Optional[str] = None,
managed_skills_dir: Optional[str] = None,
extra_dirs: Optional[List[str]] = None,
config: Optional[Dict] = None,
):
"""
Initialize the skill manager.
:param workspace_dir: Agent workspace directory
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
:param extra_dirs: Additional skill directories
:param config: Configuration dictionary
"""
self.workspace_dir = workspace_dir
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
self.extra_dirs = extra_dirs or []
self.config = config or {}
self.loader = SkillLoader(workspace_dir=workspace_dir)
self.skills: Dict[str, SkillEntry] = {}
# Load skills on initialization
self.refresh_skills()
def _get_default_managed_dir(self) -> str:
"""Get the default managed skills directory."""
        # Use the project root skills directory as the default
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
return os.path.join(project_root, 'skills')
def refresh_skills(self):
"""Reload all skills from configured directories."""
workspace_skills_dir = None
if self.workspace_dir:
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
self.skills = self.loader.load_all_skills(
managed_dir=self.managed_skills_dir,
workspace_skills_dir=workspace_skills_dir,
extra_dirs=self.extra_dirs,
)
logger.info(f"SkillManager: Loaded {len(self.skills)} skills")
def get_skill(self, name: str) -> Optional[SkillEntry]:
"""
Get a skill by name.
:param name: Skill name
:return: SkillEntry or None if not found
"""
return self.skills.get(name)
def list_skills(self) -> List[SkillEntry]:
"""
Get all loaded skills.
:return: List of all skill entries
"""
return list(self.skills.values())
def filter_skills(
self,
skill_filter: Optional[List[str]] = None,
include_disabled: bool = False,
        check_requirements: bool = False,
        lenient: bool = True,
) -> List[SkillEntry]:
"""
Filter skills based on criteria.
By default (lenient=True), all skills are loaded regardless of missing requirements.
Skills will fail gracefully at runtime if requirements are not met.
:param skill_filter: List of skill names to include (None = all)
:param include_disabled: Whether to include skills with disable_model_invocation=True
:param check_requirements: Whether to check skill requirements (default: False)
:param lenient: If True, ignore missing requirements (default: True)
:return: Filtered list of skill entries
"""
from agent.skills.config import should_include_skill
entries = list(self.skills.values())
        # Check requirements (platform, explicit disable, etc.).
        # In lenient mode only platform support and explicit disables are enforced;
        # other missing requirements surface as runtime failures instead.
        entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)]
# Apply skill filter
if skill_filter is not None:
normalized = [name.strip() for name in skill_filter if name.strip()]
if normalized:
entries = [e for e in entries if e.skill.name in normalized]
# Filter out disabled skills unless explicitly requested
if not include_disabled:
entries = [e for e in entries if not e.skill.disable_model_invocation]
return entries
def build_skills_prompt(
self,
skill_filter: Optional[List[str]] = None,
) -> str:
"""
Build a formatted prompt containing available skills.
:param skill_filter: Optional list of skill names to include
:return: Formatted skills prompt
"""
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
return format_skill_entries_for_prompt(entries)
def build_skill_snapshot(
self,
skill_filter: Optional[List[str]] = None,
version: Optional[int] = None,
) -> SkillSnapshot:
"""
Build a snapshot of skills for a specific run.
:param skill_filter: Optional list of skill names to include
:param version: Optional version number for the snapshot
:return: SkillSnapshot
"""
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
prompt = format_skill_entries_for_prompt(entries)
skills_info = []
resolved_skills = []
for entry in entries:
skills_info.append({
'name': entry.skill.name,
'primary_env': entry.metadata.primary_env if entry.metadata else None,
})
resolved_skills.append(entry.skill)
return SkillSnapshot(
prompt=prompt,
skills=skills_info,
resolved_skills=resolved_skills,
version=version,
)
def sync_skills_to_workspace(self, target_workspace_dir: str):
"""
Sync all loaded skills to a target workspace directory.
This is useful for sandbox environments where skills need to be copied.
:param target_workspace_dir: Target workspace directory
"""
import shutil
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
# Remove existing skills directory
if os.path.exists(target_skills_dir):
shutil.rmtree(target_skills_dir)
# Create new skills directory
os.makedirs(target_skills_dir, exist_ok=True)
# Copy each skill
for entry in self.skills.values():
skill_name = entry.skill.name
source_dir = entry.skill.base_dir
target_dir = os.path.join(target_skills_dir, skill_name)
try:
shutil.copytree(source_dir, target_dir)
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
except Exception as e:
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
"""
Get a skill by its skill key (which may differ from name).
:param skill_key: Skill key to look up
:return: SkillEntry or None
"""
for entry in self.skills.values():
if entry.metadata and entry.metadata.skill_key == skill_key:
return entry
if entry.skill.name == skill_key:
return entry
return None

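A minimal usage sketch of the manager (the workspace path and skill name are hypothetical; skills load eagerly in `__init__`):

```python
from agent.skills.manager import SkillManager

manager = SkillManager(workspace_dir="/tmp/agent-ws")
entry = manager.get_skill("web-research")  # hypothetical skill name, None if absent
prompt = manager.build_skills_prompt(skill_filter=["web-research"])
snapshot = manager.build_skill_snapshot(version=1)
print(len(snapshot.resolved_skills), "skills in snapshot")
```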
74
agent/skills/types.py Normal file
View File

@@ -0,0 +1,74 @@
"""
Type definitions for skills system.
"""
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
@dataclass
class SkillInstallSpec:
"""Specification for installing skill dependencies."""
kind: str # brew, pip, npm, download, etc.
id: Optional[str] = None
label: Optional[str] = None
bins: List[str] = field(default_factory=list)
os: List[str] = field(default_factory=list)
formula: Optional[str] = None # for brew
package: Optional[str] = None # for pip/npm
module: Optional[str] = None
url: Optional[str] = None # for download
archive: Optional[str] = None
extract: bool = False
strip_components: Optional[int] = None
target_dir: Optional[str] = None
@dataclass
class SkillMetadata:
"""Metadata for a skill from frontmatter."""
always: bool = False # Always include this skill
skill_key: Optional[str] = None # Override skill key
primary_env: Optional[str] = None # Primary environment variable
emoji: Optional[str] = None
homepage: Optional[str] = None
os: List[str] = field(default_factory=list) # Supported OS platforms
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
install: List[SkillInstallSpec] = field(default_factory=list)
@dataclass
class Skill:
"""Represents a skill loaded from a markdown file."""
name: str
description: str
file_path: str
base_dir: str
source: str # managed, workspace, bundled, etc.
content: str # Full markdown content
disable_model_invocation: bool = False
frontmatter: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SkillEntry:
"""A skill with parsed metadata."""
skill: Skill
metadata: Optional[SkillMetadata] = None
user_invocable: bool = True # Can users invoke this skill directly
@dataclass
class LoadSkillsResult:
"""Result of loading skills from a directory."""
skills: List[Skill]
diagnostics: List[str] = field(default_factory=list)
@dataclass
class SkillSnapshot:
"""Snapshot of skills for a specific run."""
prompt: str # Formatted prompt text
skills: List[Dict[str, str]] # List of skill info (name, primary_env)
resolved_skills: List[Skill] = field(default_factory=list)
version: Optional[int] = None

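To make the dataclass relationships concrete, a small sketch constructing the types by hand (all field values are hypothetical):

```python
from agent.skills.types import Skill, SkillEntry, SkillMetadata

skill = Skill(
    name="summarize",
    description="Summarize a document",
    file_path="/tmp/skills/summarize/SKILL.md",
    base_dir="/tmp/skills/summarize",
    source="workspace",
    content="# summarize\n...",
)
entry = SkillEntry(skill=skill, metadata=SkillMetadata(always=True))
assert entry.user_invocable  # defaults to True
```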
105
agent/tools/__init__.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Tools module for Agent.
"""
# Import base tool
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager
# Import basic tools (no external dependencies)
from agent.tools.calculator.calculator import Calculator
from agent.tools.current_time.current_time import CurrentTime
# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
from agent.tools.grep.grep import Grep
from agent.tools.find.find import Find
from agent.tools.ls.ls import Ls
# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
# Import web tools
from agent.tools.web_fetch.web_fetch import WebFetch
# Import tools with optional dependencies
def _import_optional_tools():
"""Import tools that have optional dependencies"""
tools = {}
# Google Search (requires requests)
try:
from agent.tools.google_search.google_search import GoogleSearch
tools['GoogleSearch'] = GoogleSearch
except ImportError:
pass
# File Save (may have dependencies)
try:
from agent.tools.file_save.file_save import FileSave
tools['FileSave'] = FileSave
except ImportError:
pass
# Terminal (basic, should work)
try:
from agent.tools.terminal.terminal import Terminal
tools['Terminal'] = Terminal
except ImportError:
pass
return tools
# Load optional tools
_optional_tools = _import_optional_tools()
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
# Delayed import for BrowserTool
def _import_browser_tool():
try:
from agent.tools.browser.browser_tool import BrowserTool
return BrowserTool
except ImportError:
# Return a placeholder class that will prompt the user to install dependencies when instantiated
class BrowserToolPlaceholder:
def __init__(self, *args, **kwargs):
raise ImportError(
"The 'browser-use' package is required to use BrowserTool. "
"Please install it with 'pip install browser-use>=0.1.40'."
)
return BrowserToolPlaceholder
# Dynamically set BrowserTool
BrowserTool = _import_browser_tool()
# Export all tools (including optional ones that might be None)
__all__ = [
'BaseTool',
'ToolManager',
'Calculator',
'CurrentTime',
'Read',
'Write',
'Edit',
'Bash',
'Grep',
'Find',
'Ls',
'MemorySearchTool',
'MemoryGetTool',
'WebFetch',
# Optional tools (may be None if dependencies not available)
'GoogleSearch',
'FileSave',
'Terminal',
'BrowserTool'
]
"""
Tools module for Agent.
"""

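Because the optional names above may be bound to None when their dependencies are missing, callers should filter before use. A sketch (each tool class defines a `name` class attribute per the BaseTool contract):

```python
import agent.tools as tools

optional = [tools.GoogleSearch, tools.FileSave, tools.Terminal]
for cls in optional:
    if cls is not None:
        print(f"optional tool available: {cls.name}")
```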
99
agent/tools/base_tool.py Normal file
View File

@@ -0,0 +1,99 @@
from enum import Enum
from typing import Any, Optional
from common.log import logger
class ToolStage(Enum):
"""Enum representing tool decision stages"""
PRE_PROCESS = "pre_process" # Tools that need to be actively selected by the agent
POST_PROCESS = "post_process" # Tools that automatically execute after final_answer
class ToolResult:
"""Tool execution result"""
def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
self.status = status
self.result = result
self.ext_data = ext_data
@staticmethod
def success(result, ext_data: Any = None):
return ToolResult(status="success", result=result, ext_data=ext_data)
@staticmethod
def fail(result, ext_data: Any = None):
return ToolResult(status="error", result=result, ext_data=ext_data)
class BaseTool:
"""Base class for all tools."""
# Default decision stage is pre-process
stage = ToolStage.PRE_PROCESS
# Class attributes must be inherited
name: str = "base_tool"
description: str = "Base tool"
params: dict = {} # Store JSON Schema
model: Optional[Any] = None # LLM model instance, type depends on bot implementation
@classmethod
def get_json_schema(cls) -> dict:
"""Get the standard description of the tool"""
return {
"name": cls.name,
"description": cls.description,
"parameters": cls.params
}
    def execute_tool(self, params: dict) -> ToolResult:
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            return ToolResult.fail(str(e))
def execute(self, params: dict) -> ToolResult:
"""Specific logic to be implemented by subclasses"""
raise NotImplementedError
@classmethod
def _parse_schema(cls) -> dict:
"""Convert JSON Schema to Pydantic fields"""
fields = {}
for name, prop in cls.params["properties"].items():
# Convert JSON Schema types to Python types
type_map = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"array": list,
"object": dict
}
fields[name] = (
type_map[prop["type"]],
prop.get("default", ...)
)
return fields
def should_auto_execute(self, context) -> bool:
"""
Determine if this tool should be automatically executed based on context.
:param context: The agent context
:return: True if the tool should be executed, False otherwise
"""
# Only tools in post-process stage will be automatically executed
return self.stage == ToolStage.POST_PROCESS
def close(self):
"""
Close any resources used by the tool.
This method should be overridden by tools that need to clean up resources
such as browser connections, file handles, etc.
By default, this method does nothing.
"""
pass

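A minimal subclass sketch showing the contract (the echo tool is hypothetical, not part of this PR):

```python
from agent.tools.base_tool import BaseTool, ToolResult

class Echo(BaseTool):
    """Hypothetical example tool that returns its input unchanged."""
    name: str = "echo"
    description: str = "Echo back the provided text."
    params: dict = {
        "type": "object",
        "properties": {
            "text": {"type": "string", "description": "Text to echo"}
        },
        "required": ["text"],
    }

    def execute(self, params: dict) -> ToolResult:
        return ToolResult.success({"echo": params.get("text", "")})

result = Echo().execute_tool({"text": "hi"})
assert result.status == "success" and result.result == {"echo": "hi"}
```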
3
agent/tools/bash/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .bash import Bash
__all__ = ['Bash']

186
agent/tools/bash/bash.py Normal file
View File

@@ -0,0 +1,186 @@
"""
Bash tool - Execute bash commands
"""
import os
import subprocess
import tempfile
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
class Bash(BaseTool):
"""Tool for executing bash commands"""
name: str = "bash"
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
IMPORTANT SAFETY GUIDELINES:
- You can freely create, modify, and delete files within the current workspace
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
- When in doubt, describe the command's purpose and ask for permission before executing"""
params: dict = {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Bash command to execute"
},
"timeout": {
"type": "integer",
"description": "Timeout in seconds (optional, default: 30)"
}
},
"required": ["command"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# Ensure working directory exists
if not os.path.exists(self.cwd):
os.makedirs(self.cwd, exist_ok=True)
self.default_timeout = self.config.get("timeout", 30)
# Enable safety mode by default (can be disabled in config)
self.safety_mode = self.config.get("safety_mode", True)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute a bash command
:param args: Dictionary containing the command and optional timeout
:return: Command output or error
"""
command = args.get("command", "").strip()
timeout = args.get("timeout", self.default_timeout)
if not command:
return ToolResult.fail("Error: command parameter is required")
# Optional safety check - only warn about extremely dangerous commands
if self.safety_mode:
warning = self._get_safety_warning(command)
if warning:
return ToolResult.fail(
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
try:
# Execute command
result = subprocess.run(
command,
shell=True,
cwd=self.cwd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
timeout=timeout
)
# Combine stdout and stderr
output = result.stdout
if result.stderr:
output += "\n" + result.stderr
# Check if we need to save full output to temp file
temp_file_path = None
total_bytes = len(output.encode('utf-8'))
if total_bytes > DEFAULT_MAX_BYTES:
# Save full output to temp file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f:
f.write(output)
temp_file_path = f.name
# Apply tail truncation
truncation = truncate_tail(output)
output_text = truncation.content or "(no output)"
# Build result
details = {}
if truncation.truncated:
details["truncation"] = truncation.to_dict()
if temp_file_path:
details["full_output_path"] = temp_file_path
            # Build a truncation notice when output was cut
            if truncation.truncated:
                start_line = truncation.total_lines - truncation.output_lines + 1
                end_line = truncation.total_lines
                full_output_hint = f" Full output: {temp_file_path}" if temp_file_path else ""
                if truncation.last_line_partial:
                    # Edge case: the last line alone exceeds the byte limit
                    last_line = output.split('\n')[-1] if output else ""
                    last_line_size = format_size(len(last_line.encode('utf-8')))
                    output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}).{full_output_hint}]"
                elif truncation.truncated_by == "lines":
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}.{full_output_hint}]"
                else:
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit).{full_output_hint}]"
# Check exit code
if result.returncode != 0:
output_text += f"\n\nCommand exited with code {result.returncode}"
return ToolResult.fail({
"output": output_text,
"exit_code": result.returncode,
"details": details if details else None
})
return ToolResult.success({
"output": output_text,
"exit_code": result.returncode,
"details": details if details else None
})
except subprocess.TimeoutExpired:
return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
except Exception as e:
return ToolResult.fail(f"Error executing command: {str(e)}")
def _get_safety_warning(self, command: str) -> str:
"""
Get safety warning for potentially dangerous commands
Only warns about extremely dangerous system-level operations
:param command: Command to check
:return: Warning message if dangerous, empty string if safe
"""
cmd_lower = command.lower().strip()
# Only block extremely dangerous system operations
dangerous_patterns = [
# System shutdown/reboot
("shutdown", "This command will shut down the system"),
("reboot", "This command will reboot the system"),
("halt", "This command will halt the system"),
("poweroff", "This command will power off the system"),
# Critical system modifications
("rm -rf /", "This command will delete the entire filesystem"),
("rm -rf /*", "This command will delete the entire filesystem"),
("dd if=/dev/zero", "This command can destroy disk data"),
("mkfs", "This command will format a filesystem, destroying all data"),
("fdisk", "This command modifies disk partitions"),
# User/system management (only if targeting system users)
("userdel root", "This command will delete the root user"),
("passwd root", "This command will change the root password"),
]
for pattern, warning in dangerous_patterns:
if pattern in cmd_lower:
return warning
# Check for recursive deletion outside workspace
if "rm" in cmd_lower and "-rf" in cmd_lower:
# Allow deletion within current workspace
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
# Check if targeting system directories
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
if any(sysdir in cmd_lower for sysdir in system_dirs):
return "This command will recursively delete system directories"
return "" # No warning needed

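A quick invocation sketch (the temp-directory path is hypothetical):

```python
from agent.tools.bash.bash import Bash

bash = Bash(config={"cwd": "/tmp", "timeout": 10})
result = bash.execute({"command": "echo hello"})
print(result.status)  # "success"
assert "hello" in result.result["output"]  # short output, so no truncation notice
```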
View File

@@ -0,0 +1,18 @@
def copy(self):
"""
Special copy method for browser tool to avoid recreating browser instance.
:return: A new instance with shared browser reference but unique model
"""
new_tool = self.__class__()
# Copy essential attributes
new_tool.model = self.model
new_tool.context = getattr(self, 'context', None)
new_tool.config = getattr(self, 'config', None)
# Share the browser instance instead of creating a new one
if hasattr(self, 'browser'):
new_tool.browser = self.browser
return new_tool

58
agent/tools/calculator/calculator.py Normal file
View File

@@ -0,0 +1,58 @@
import math
from agent.tools.base_tool import BaseTool, ToolResult
class Calculator(BaseTool):
name: str = "calculator"
description: str = "A tool to perform basic mathematical calculations."
params: dict = {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
"Ensure your input is a valid Python expression, it will be evaluated directly."
}
},
"required": ["expression"]
}
config: dict = {}
def execute(self, args: dict) -> ToolResult:
try:
# Get the expression
expression = args["expression"]
# Create a safe local environment containing only basic math functions
safe_locals = {
"abs": abs,
"round": round,
"max": max,
"min": min,
"pow": pow,
"sqrt": math.sqrt,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"pi": math.pi,
"e": math.e,
"log": math.log,
"log10": math.log10,
"exp": math.exp,
"floor": math.floor,
"ceil": math.ceil
}
# Safely evaluate the expression
result = eval(expression, {"__builtins__": {}}, safe_locals)
return ToolResult.success({
"result": result,
"expression": expression
})
        except Exception as e:
            return ToolResult.fail({
                "error": str(e),
                "expression": args.get("expression", "")
            })

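An example of the restricted-eval behavior (only the whitelisted math names are available):

```python
from agent.tools.calculator.calculator import Calculator

calc = Calculator()
ok = calc.execute({"expression": "sqrt(16) + 2 ** 3"})
print(ok.result)  # {'result': 12.0, 'expression': 'sqrt(16) + 2 ** 3'}

blocked = calc.execute({"expression": "__import__('os')"})
print(blocked.status)  # "error": builtins are stripped from the eval environment
```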
75
agent/tools/current_time/current_time.py Normal file
View File

@@ -0,0 +1,75 @@
import datetime
import time
from agent.tools.base_tool import BaseTool, ToolResult
class CurrentTime(BaseTool):
name: str = "time"
description: str = "A tool to get current date and time information."
params: dict = {
"type": "object",
"properties": {
"format": {
"type": "string",
"description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'."
},
"timezone": {
"type": "string",
"description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'."
}
},
"required": []
}
config: dict = {}
def execute(self, args: dict) -> ToolResult:
try:
# Get the format and timezone parameters, with defaults
time_format = args.get("format", "human").lower()
timezone = args.get("timezone", "local").lower()
# Get current time
current_time = datetime.datetime.now()
            # Handle timezone if specified (utcnow() is deprecated in Python 3.12+)
            if timezone == "utc":
                current_time = datetime.datetime.now(datetime.timezone.utc)
            # Format the time according to the specified format
            if time_format == "iso":
                # ISO 8601 format
                formatted_time = current_time.isoformat()
            elif time_format == "unix":
                # Unix timestamp (seconds since epoch)
                formatted_time = current_time.timestamp()
else:
# Human-readable format
formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
# Prepare additional time components for the response
year = current_time.year
month = current_time.month
day = current_time.day
hour = current_time.hour
minute = current_time.minute
second = current_time.second
weekday = current_time.strftime("%A") # Full weekday name
result = {
"current_time": formatted_time,
"components": {
"year": year,
"month": month,
"day": day,
"hour": hour,
"minute": minute,
"second": second,
"weekday": weekday
},
"format": time_format,
"timezone": timezone
}
return ToolResult.success(result=result)
except Exception as e:
return ToolResult.fail(result=str(e))

3
agent/tools/edit/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .edit import Edit
__all__ = ['Edit']

173
agent/tools/edit/edit.py Normal file
View File

@@ -0,0 +1,173 @@
"""
Edit tool - Precise file editing
Edit files through exact text replacement
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string
)
class Edit(BaseTool):
"""Tool for precise file editing"""
name: str = "edit"
description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to edit (relative or absolute)"
},
"oldText": {
"type": "string",
"description": "Exact text to find and replace (must match exactly)"
},
"newText": {
"type": "string",
"description": "New text to replace the old text with"
}
},
"required": ["path", "oldText", "newText"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self.memory_manager = self.config.get("memory_manager", None)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file edit operation
:param args: Contains file path, old text and new text
:return: Operation result
"""
path = args.get("path", "").strip()
old_text = args.get("oldText", "")
new_text = args.get("newText", "")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
# Check if file exists
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: File not found: {path}")
# Check if readable/writable
if not os.access(absolute_path, os.R_OK | os.W_OK):
return ToolResult.fail(f"Error: File is not readable/writable: {path}")
try:
# Read file
with open(absolute_path, 'r', encoding='utf-8') as f:
raw_content = f.read()
# Remove BOM (LLM won't include invisible BOM in oldText)
bom, content = strip_bom(raw_content)
# Detect original line ending
original_ending = detect_line_ending(content)
# Normalize to LF
normalized_content = normalize_to_lf(content)
normalized_old_text = normalize_to_lf(old_text)
normalized_new_text = normalize_to_lf(new_text)
# Use fuzzy matching to find old text (try exact match first, then fuzzy match)
match_result = fuzzy_find_text(normalized_content, normalized_old_text)
if not match_result.found:
return ToolResult.fail(
f"Error: Could not find the exact text in {path}. "
"The old text must match exactly including all whitespace and newlines."
)
# Calculate occurrence count (use fuzzy normalized content for consistency)
fuzzy_content = normalize_for_fuzzy_match(normalized_content)
fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
occurrences = fuzzy_content.count(fuzzy_old_text)
if occurrences > 1:
return ToolResult.fail(
f"Error: Found {occurrences} occurrences of the text in {path}. "
"The text must be unique. Please provide more context to make it unique."
)
# Execute replacement (use matched text position)
base_content = match_result.content_for_replacement
new_content = (
base_content[:match_result.index] +
normalized_new_text +
base_content[match_result.index + match_result.match_length:]
)
# Verify replacement actually changed content
if base_content == new_content:
return ToolResult.fail(
f"Error: No changes made to {path}. "
"The replacement produced identical content. "
"This might indicate an issue with special characters or the text not existing as expected."
)
# Restore original line endings
final_content = bom + restore_line_endings(new_content, original_ending)
# Write file
with open(absolute_path, 'w', encoding='utf-8') as f:
f.write(final_content)
# Generate diff
diff_result = generate_diff_string(base_content, new_content)
result = {
"message": f"Successfully replaced text in {path}",
"path": path,
"diff": diff_result['diff'],
"first_changed_line": diff_result['first_changed_line']
}
# Notify memory manager if file is in memory directory
if self.memory_manager and "memory/" in path:
try:
self.memory_manager.mark_dirty()
                except Exception:
                    # Don't fail the edit if the memory notification fails
                    pass
return ToolResult.success(result)
except UnicodeDecodeError:
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
except PermissionError:
return ToolResult.fail(f"Error: Permission denied accessing {path}")
except Exception as e:
return ToolResult.fail(f"Error editing file: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

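A hypothetical round trip in a temp directory, showing the exact-match replacement flow:

```python
import pathlib
from agent.tools.edit.edit import Edit

pathlib.Path("/tmp/demo.txt").write_text("greeting = 'hello'\n", encoding="utf-8")

edit = Edit(config={"cwd": "/tmp"})
result = edit.execute({
    "path": "demo.txt",
    "oldText": "greeting = 'hello'",
    "newText": "greeting = 'hi'",
})
print(result.status)          # "success"
print(result.result["diff"])  # unified diff of the change
```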
3
agent/tools/file_save/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .file_save import FileSave
__all__ = ['FileSave']

770
agent/tools/file_save/file_save.py Normal file
View File

@@ -0,0 +1,770 @@
import os
import time
import re
import json
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
from agent.models import LLMRequest
from common.log import logger
class FileSave(BaseTool):
"""Tool for saving content to files in the workspace directory."""
name = "file_save"
description = "Save the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."
# Set as post-process stage tool
stage = ToolStage.POST_PROCESS
params = {
"type": "object",
"properties": {
"file_name": {
"type": "string",
"description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
},
"file_type": {
"type": "string",
"description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
},
"extract_code": {
"type": "boolean",
"description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
}
},
"required": [] # No required fields, as everything can be extracted from context
}
def __init__(self):
self.context = None
self.config = {}
self.workspace_dir = Path("workspace")
def execute(self, params: Dict[str, Any]) -> ToolResult:
"""
Save content to a file in the workspace directory.
:param params: The parameters for the file output operation.
:return: Result of the operation.
"""
# Extract content from context
if not hasattr(self, 'context') or not self.context:
return ToolResult.fail("Error: No context available to extract content from.")
content = self._extract_content_from_context()
# If no content could be extracted, return error
if not content:
return ToolResult.fail("Error: Couldn't extract content from context.")
# Use model to determine file parameters
try:
task_dir = self._get_task_dir_from_context()
file_name, file_type, extract_code = self._get_file_params_from_model(content)
except Exception as e:
logger.error(f"Error determining file parameters: {str(e)}")
# Fall back to manual parameter extraction
task_dir = params.get("task_dir") or self._get_task_id_from_context() or f"task_{int(time.time())}"
file_name = params.get("file_name") or self._infer_file_name(content)
file_type = params.get("file_type") or self._infer_file_type(content)
extract_code = params.get("extract_code", False)
# Get team_name from context
team_name = self._get_team_name_from_context() or "default_team"
# Create directory structure
task_dir_path = self.workspace_dir / team_name / task_dir
task_dir_path.mkdir(parents=True, exist_ok=True)
if extract_code:
# Save the complete content as markdown
md_file_name = f"{file_name}.md"
md_file_path = task_dir_path / md_file_name
# Write content to file
with open(md_file_path, 'w', encoding='utf-8') as f:
f.write(content)
return self._handle_multiple_code_blocks(content)
# Ensure file_name has the correct extension
if file_type and not file_name.endswith(f".{file_type}"):
file_name = f"{file_name}.{file_type}"
# Create the full file path
file_path = task_dir_path / file_name
# Get absolute path for storage in team_context
abs_file_path = file_path.absolute()
try:
# Write content to file
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
# Update the current agent's final_answer to include file information
if hasattr(self.context, 'team_context'):
# Store with absolute path in team_context
self.context.team_context.agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"
return ToolResult.success({
"status": "success",
"file_path": str(file_path) # Return relative path in result
})
except Exception as e:
return ToolResult.fail(f"Error saving file: {str(e)}")
def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
"""
Handle content with multiple code blocks, extracting and saving each as a separate file.
:param content: The content containing multiple code blocks
:return: Result of the operation
"""
# Extract code blocks with context (including potential file name information)
code_blocks_with_context = self._extract_code_blocks_with_context(content)
if not code_blocks_with_context:
return ToolResult.fail("No code blocks found in the content.")
# Get task directory and team name
task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
team_name = self._get_team_name_from_context() or "default_team"
# Create directory structure
task_dir_path = self.workspace_dir / team_name / task_dir
task_dir_path.mkdir(parents=True, exist_ok=True)
saved_files = []
for block_with_context in code_blocks_with_context:
try:
# Use model to determine file name for this code block
block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)
# Clean the code block (remove md code markers)
clean_code = self._clean_code_block(block_with_context)
# Ensure file_name has the correct extension
if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
block_file_name = f"{block_file_name}.{block_file_type}"
# Create the full file path (no subdirectories)
file_path = task_dir_path / block_file_name
# Get absolute path for storage in team_context
abs_file_path = file_path.absolute()
# Write content to file
with open(file_path, 'w', encoding='utf-8') as f:
f.write(clean_code)
saved_files.append({
"file_path": str(file_path),
"abs_file_path": str(abs_file_path), # Store absolute path for internal use
"file_name": block_file_name,
"size": len(clean_code),
"status": "success",
"type": "code"
})
except Exception as e:
logger.error(f"Error saving code block: {str(e)}")
# Continue with the next block even if this one fails
if not saved_files:
return ToolResult.fail("Failed to save any code blocks.")
# Update the current agent's final_answer to include files information
if hasattr(self, 'context') and self.context:
# If the agent has a final_answer attribute, append the files info to it
if hasattr(self.context, 'team_context'):
# Use relative paths for display
display_info = f"\n\nSaved files to {task_dir_path}:\n" + "\n".join(
[f"- {f['file_path']}" for f in saved_files])
# Check if we need to append the info
if not self.context.team_context.agent_outputs[-1].output.endswith(display_info):
# Store with absolute paths in team_context
abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join(
[f"- {f['abs_file_path']}" for f in saved_files])
self.context.team_context.agent_outputs[-1].output += abs_info
result = {
"status": "success",
"files": [{"file_path": f["file_path"]} for f in saved_files]
}
return ToolResult.success(result)
def _extract_code_blocks_with_context(self, content: str) -> list:
"""
Extract code blocks from content, including context lines before the block.
:param content: The content to extract code blocks from
:return: List of code blocks with context
"""
# Check if content starts with <!DOCTYPE or <html - likely a full HTML file
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
return [content] # Return the entire content as a single block
# Split content into lines
lines = content.split('\n')
blocks = []
in_code_block = False
current_block = []
context_lines = []
# Check if there are any code block markers in the content
if not re.search(r'```\w+', content):
# If no code block markers and content looks like code, return the entire content
if self._is_likely_code(content):
return [content]
for line in lines:
if line.strip().startswith('```'):
if in_code_block:
# End of code block
current_block.append(line)
# Only add blocks that have a language specified
if re.search(r'```\w+', current_block[0]):
# Combine context with code block
blocks.append('\n'.join(context_lines + current_block))
current_block = []
context_lines = []
in_code_block = False
else:
# Start of code block - check if it has a language specified
if re.search(r'```\w+', line) and not re.search(r'```language=\s*$', line):
# Start of code block with language
in_code_block = True
current_block = [line]
# Keep only the last few context lines
context_lines = context_lines[-5:] if context_lines else []
elif in_code_block:
current_block.append(line)
else:
# Store context lines when not in a code block
context_lines.append(line)
return blocks
def _get_filename_for_code_block(self, block_with_context: str) -> Tuple[str, str]:
"""
Determine the file name for a code block.
:param block_with_context: The code block with context lines
:return: Tuple of (file_name, file_type)
"""
# Define common code file extensions
COMMON_CODE_EXTENSIONS = {
'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php',
'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml',
'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua'
}
# Split the block into lines to examine only the context around code block markers
lines = block_with_context.split('\n')
# Find the code block start marker line index
start_marker_idx = -1
for i, line in enumerate(lines):
if line.strip().startswith('```') and not line.strip() == '```':
start_marker_idx = i
break
if start_marker_idx == -1:
# No code block marker found
return "", ""
# Extract the language from the code block marker
code_marker = lines[start_marker_idx].strip()
language = ""
if len(code_marker) > 3:
language = code_marker[3:].strip().split('=')[0].strip()
# Define the context range (5 lines before and 2 after the marker)
context_start = max(0, start_marker_idx - 5)
context_end = min(len(lines), start_marker_idx + 3)
# Extract only the relevant context lines
context_lines = lines[context_start:context_end]
# First, check for explicit file headers like "## filename.ext"
for line in context_lines:
# Match patterns like "## filename.ext" or "# filename.ext"
header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line)
if header_match:
file_name = header_match.group(1)
file_type = os.path.splitext(file_name)[1].lstrip('.')
if file_type in COMMON_CODE_EXTENSIONS:
return os.path.splitext(file_name)[0], file_type
# Simple patterns to match explicit file names in the context
file_patterns = [
# Match explicit file names in headers or text
r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?',
# Match language=filename.ext in code markers
r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)',
# Match standalone filenames with extensions
r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b',
# Match file paths in comments
r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)'
]
# Check each context line for file name patterns
for line in context_lines:
line = line.strip()
for pattern in file_patterns:
matches = re.findall(pattern, line)
if matches:
for match in matches:
if isinstance(match, tuple):
# If the match is a tuple (filename, extension)
file_name = match[0]
file_type = match[1]
# Verify it's not a code reference like Direction.DOWN
if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
return os.path.splitext(file_name)[0], file_type
else:
# If the match is a string (full filename)
file_name = match
file_type = os.path.splitext(file_name)[1].lstrip('.')
# Verify it's not a code reference
if file_type in COMMON_CODE_EXTENSIONS and not any(
keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
return os.path.splitext(file_name)[0], file_type
# If no explicit file name found, use LLM to infer from code content
# Extract the code content
code_content = block_with_context
# Get the first 20 lines of code for LLM analysis
code_lines = code_content.split('\n')
code_preview = '\n'.join(code_lines[:20])
# Get the model to use
model_to_use = None
if hasattr(self, 'context') and self.context:
if hasattr(self.context, 'model') and self.context.model:
model_to_use = self.context.model
elif hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'model') and self.context.team_context.model:
model_to_use = self.context.team_context.model
# If no model is available in context, use the tool's model
if not model_to_use and hasattr(self, 'model') and self.model:
model_to_use = self.model
if model_to_use:
# Prepare a prompt for the model
prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension.
The file name should be descriptive but concise, using snake_case (lowercase with underscores).
The file type should be a standard file extension (e.g., py, js, html, css, java).
Code preview (first 20 lines):
{code_preview}
Return your answer in JSON format with these fields:
- file_name: The suggested file name (without extension)
- file_type: The suggested file extension
JSON response:"""
# Create a request to the model
request = LLMRequest(
messages=[{"role": "user", "content": prompt}],
temperature=0,
json_format=True
)
try:
response = model_to_use.call(request)
if not response.is_error:
# Clean the JSON response
json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"])
result = json.loads(json_content)
file_name = result.get("file_name", "")
file_type = result.get("file_type", "")
if file_name and file_type:
return file_name, file_type
except Exception as e:
logger.error(f"Error using model to determine file name: {str(e)}")
# If we still don't have a file name, use the language as file type
if language and language in COMMON_CODE_EXTENSIONS:
timestamp = int(time.time())
return f"code_{timestamp}", language
# If all else fails, return empty strings
return "", ""
def _clean_json_response(self, text: str) -> str:
"""
Clean JSON response from LLM by removing markdown code block markers.
:param text: The text containing JSON possibly wrapped in markdown code blocks
:return: Clean JSON string
"""
# Remove markdown code block markers if present
if text.startswith("```json"):
text = text[7:]
elif text.startswith("```"):
# Find the first newline to skip the language identifier line
first_newline = text.find('\n')
if first_newline != -1:
text = text[first_newline + 1:]
if text.endswith("```"):
text = text[:-3]
return text.strip()
def _clean_code_block(self, block_with_context: str) -> str:
"""
Clean a code block by removing markdown code markers and context lines.
:param block_with_context: Code block with context lines
:return: Clean code ready for execution
"""
# Check if this is a full HTML or XML document
if block_with_context.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
return block_with_context
# Find the code block
code_block_match = re.search(r'```(?:\w+)?(?:[:=][^\n]+)?\n([\s\S]*?)\n```', block_with_context)
if code_block_match:
return code_block_match.group(1)
# If no match found, try to extract anything between ``` markers
lines = block_with_context.split('\n')
start_idx = None
end_idx = None
for i, line in enumerate(lines):
if line.strip().startswith('```'):
if start_idx is None:
start_idx = i
else:
end_idx = i
break
if start_idx is not None and end_idx is not None:
# Extract the code between the markers, excluding the markers themselves
code_lines = lines[start_idx + 1:end_idx]
return '\n'.join(code_lines)
# If all else fails, return the original content
return block_with_context
def _get_file_params_from_model(self, content, model=None):
"""
Use LLM to determine if the content is code and suggest appropriate file parameters.
Args:
content: The content to analyze
model: Optional model to use for the analysis
Returns:
tuple: (file_name, file_type, extract_code) for backward compatibility
"""
if model is None:
model = self.model
if not model:
# Default fallback if no model is available
return "output", "txt", False
        # Only the first 500 characters of the content are included, to keep
        # the prompt within token limits.
        preview = content[:500] + ("..." if len(content) > 500 else "")
        prompt = f"""
Analyze the following content and determine:
1. Is this primarily code implementation (where most of the content consists of code blocks)?
2. What would be an appropriate filename and file extension?

Content to analyze: ```
{preview} ```

Respond in JSON format only with the following structure:
{{
    "is_code": true/false,
    "filename": "suggested_filename",
    "extension": "appropriate_extension"
}}
The filename should be descriptive English words without an extension; the extension should omit the dot (e.g., "md", "py", "js").
"""
try:
# Create a request to the model
request = LLMRequest(
messages=[{"role": "user", "content": prompt}],
temperature=0.1,
json_format=True
)
# Call the model using the standard interface
response = model.call(request)
if response.is_error:
logger.warning(f"Error from model: {response.error_message}")
raise Exception(f"Model error: {response.error_message}")
# Extract JSON from response
result = response.data["choices"][0]["message"]["content"]
# Clean the JSON response
result = self._clean_json_response(result)
# Parse the JSON
params = json.loads(result)
# For backward compatibility, return tuple format
file_name = params.get("filename", "output")
# Remove dot from extension if present
file_type = params.get("extension", "md").lstrip(".")
extract_code = params.get("is_code", False)
return file_name, file_type, extract_code
except Exception as e:
logger.warning(f"Error getting file parameters from model: {e}")
# Default fallback
return "output", "md", False
def _get_team_name_from_context(self) -> Optional[str]:
"""
Get team name from the agent's context.
:return: Team name or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get team name from team_context
if hasattr(self.context, 'team_context') and self.context.team_context:
return self.context.team_context.name
# Try direct team_name attribute
if hasattr(self.context, 'name'):
return self.context.name
return None
def _get_task_id_from_context(self) -> Optional[str]:
"""
Get task ID from the agent's context.
:return: Task ID or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get task ID from task object
if hasattr(self.context, 'task') and self.context.task:
return self.context.task.id
# Try team_context's task
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'task') and self.context.team_context.task:
return self.context.team_context.task.id
return None
def _get_task_dir_from_context(self) -> Optional[str]:
"""
Get task directory name from the team context.
:return: Task directory name or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get from team_context
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name:
return self.context.team_context.task_short_name
# Fall back to task ID if available
return self._get_task_id_from_context()
def _extract_content_from_context(self) -> str:
"""
Extract content from the agent's context.
:return: Extracted content
"""
# Check if we have access to the agent's context
if not hasattr(self, 'context') or not self.context:
return ""
# Try to get the most recent final answer from the agent
if hasattr(self.context, 'final_answer') and self.context.final_answer:
return self.context.final_answer
# Try to get the most recent final answer from team context
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs:
latest_output = self.context.team_context.agent_outputs[-1].output
return latest_output
# If we have action history, try to get the most recent final answer
if hasattr(self.context, 'action_history') and self.context.action_history:
for action in reversed(self.context.action_history):
if "final_answer" in action and action["final_answer"]:
return action["final_answer"]
return ""
def _extract_code_blocks(self, content: str) -> str:
"""
Extract code blocks from markdown content.
:param content: The content to extract code blocks from
:return: Extracted code blocks
"""
# Pattern to match markdown code blocks
code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```'
# Find all code blocks
code_blocks = re.findall(code_block_pattern, content)
if code_blocks:
# Join all code blocks with newlines
return '\n\n'.join(code_blocks)
return content # Return original content if no code blocks found
def _infer_file_name(self, content: str) -> str:
"""
Infer a file name from the content.
:param content: The content to analyze.
:return: A suggested file name.
"""
# Check for title patterns in markdown
title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if title_match:
# Convert title to a valid filename
title = title_match.group(1).strip()
return self._sanitize_filename(title)
# Check for class/function definitions in code
code_match = re.search(r'(class|def|function)\s+(\w+)', content)
if code_match:
return self._sanitize_filename(code_match.group(2))
# Default name based on content type
if self._is_likely_code(content):
return "code"
elif self._is_likely_markdown(content):
return "document"
elif self._is_likely_json(content):
return "data"
else:
return "output"
def _infer_file_type(self, content: str) -> str:
"""
Infer the file type/extension from the content.
:param content: The content to analyze.
:return: A suggested file extension.
"""
# Check for common programming language patterns
if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content):
return "py" # Python
elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content):
return "java" # Java
elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content):
return "js" # JavaScript
elif re.search(r'(<html|<body|<div|<p>)', content):
return "html" # HTML
elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content):
return "cpp" # C/C++
# Check for markdown
if self._is_likely_markdown(content):
return "md"
# Check for JSON
if self._is_likely_json(content):
return "json"
# Default to text
return "txt"
def _is_likely_code(self, content: str) -> bool:
"""Check if the content is likely code."""
# First check for common HTML/XML patterns
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml", "<head", "<body")):
return True
code_patterns = [
r'(class|def|function|import|from|public|private|protected|#include)',
r'(\{\s*\n|\}\s*\n|\[\s*\n|\]\s*\n)',
r'(if\s*\(|for\s*\(|while\s*\()',
r'(<\w+>.*?</\w+>)', # HTML/XML tags
r'(var|let|const)\s+\w+\s*=', # JavaScript variable declarations
r'#\s*\w+', # CSS ID selectors or Python comments
r'\.\w+\s*\{', # CSS class selectors
r'@media|@import|@font-face' # CSS at-rules
]
return any(re.search(pattern, content) for pattern in code_patterns)
def _is_likely_markdown(self, content: str) -> bool:
"""Check if the content is likely markdown."""
md_patterns = [
r'^#\s+.+$', # Headers
r'^\*\s+.+$', # Unordered lists
r'^\d+\.\s+.+$', # Ordered lists
r'\[.+\]\(.+\)', # Links
r'!\[.+\]\(.+\)' # Images
]
return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns)
def _is_likely_json(self, content: str) -> bool:
"""Check if the content is likely JSON."""
try:
content = content.strip()
if (content.startswith('{') and content.endswith('}')) or (
content.startswith('[') and content.endswith(']')):
json.loads(content)
return True
        except ValueError:
            pass
return False
def _sanitize_filename(self, name: str) -> str:
"""
Sanitize a string to be used as a filename.
:param name: The string to sanitize.
:return: A sanitized filename.
"""
# Replace spaces with underscores
name = name.replace(' ', '_')
# Remove invalid characters
name = re.sub(r'[^\w\-\.]', '', name)
# Limit length
if len(name) > 50:
name = name[:50]
return name.lower()
def _process_file_path(self, file_path: str) -> Tuple[str, str]:
"""
Process a file path to extract the file name and type, and create directories if needed.
:param file_path: The file path to process
:return: Tuple of (file_name, file_type)
"""
# Get the file name and extension
file_name = os.path.basename(file_path)
file_type = os.path.splitext(file_name)[1].lstrip('.')
return os.path.splitext(file_name)[0], file_type

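The filename inference above leans on several regexes; as a sanity check, the explicit markdown-header pattern can be exercised in isolation:

```python
import re

# Same pattern as used in _get_filename_for_code_block for "## filename.ext" headers.
header_re = re.compile(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$')
assert header_re.search("## snake_game.py").group(1) == "snake_game.py"
assert header_re.search("plain prose line") is None
```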
3
agent/tools/find/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .find import Find
__all__ = ['Find']

177
agent/tools/find/find.py Normal file
View File

@@ -0,0 +1,177 @@
"""
Find tool - Search for files by glob pattern
"""
import os
import glob as glob_module
from typing import Dict, Any, List
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
DEFAULT_LIMIT = 1000
class Find(BaseTool):
"""Tool for finding files by pattern"""
name: str = "find"
description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
params: dict = {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
},
"path": {
"type": "string",
"description": "Directory to search in (default: current directory)"
},
"limit": {
"type": "integer",
"description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
}
},
"required": ["pattern"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file search
:param args: Search parameters
:return: Search results or error
"""
pattern = args.get("pattern", "").strip()
search_path = args.get("path", ".").strip()
limit = args.get("limit", DEFAULT_LIMIT)
if not pattern:
return ToolResult.fail("Error: pattern parameter is required")
# Resolve search path
absolute_path = self._resolve_path(search_path)
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: Path not found: {search_path}")
if not os.path.isdir(absolute_path):
return ToolResult.fail(f"Error: Not a directory: {search_path}")
try:
# Load .gitignore patterns
ignore_patterns = self._load_gitignore(absolute_path)
# Search for files
results = []
search_pattern = os.path.join(absolute_path, pattern)
# Use glob with recursive support
for file_path in glob_module.glob(search_pattern, recursive=True):
# Skip if matches ignore patterns
if self._should_ignore(file_path, absolute_path, ignore_patterns):
continue
# Get relative path
relative_path = os.path.relpath(file_path, absolute_path)
# Add trailing slash for directories
if os.path.isdir(file_path):
relative_path += '/'
results.append(relative_path)
if len(results) >= limit:
break
if not results:
return ToolResult.success({"message": "No files found matching pattern", "files": []})
# Sort results
results.sort()
# Format output
raw_output = '\n'.join(results)
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
output = truncation.content
details = {}
notices = []
result_limit_reached = len(results) >= limit
if result_limit_reached:
notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
details["result_limit_reached"] = limit
if truncation.truncated:
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
details["truncation"] = truncation.to_dict()
if notices:
output += f"\n\n[{'. '.join(notices)}]"
return ToolResult.success({
"output": output,
"file_count": len(results),
"details": details if details else None
})
except Exception as e:
return ToolResult.fail(f"Error executing find: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
def _load_gitignore(self, directory: str) -> List[str]:
"""Load .gitignore patterns from directory"""
patterns = []
gitignore_path = os.path.join(directory, '.gitignore')
if os.path.exists(gitignore_path):
try:
with open(gitignore_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
patterns.append(line)
            except OSError:
                pass
# Add common ignore patterns
patterns.extend([
'.git',
'__pycache__',
'*.pyc',
'node_modules',
'.DS_Store'
])
return patterns
def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
"""Check if file should be ignored based on patterns"""
relative_path = os.path.relpath(file_path, base_path)
for pattern in patterns:
# Simple pattern matching
if pattern in relative_path:
return True
# Check if it's a directory pattern
if pattern.endswith('/'):
if relative_path.startswith(pattern.rstrip('/')):
return True
return False

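A short usage sketch (the project directory is hypothetical; a missing path returns a fail result instead):

```python
from agent.tools.find.find import Find

find = Find(config={"cwd": "/tmp/project"})
result = find.execute({"pattern": "**/*.py", "limit": 50})
if result.status == "success":
    print(result.result.get("file_count", 0), "matches")
```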
3
agent/tools/grep/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .grep import Grep
__all__ = ['Grep']

248
agent/tools/grep/grep.py Normal file
View File

@@ -0,0 +1,248 @@
"""
Grep tool - Search file contents for patterns
Uses ripgrep (rg) for fast searching
"""
import os
import re
import subprocess
import json
from typing import Dict, Any, List, Optional
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import (
truncate_head, truncate_line, format_size,
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
)
DEFAULT_LIMIT = 100
class Grep(BaseTool):
"""Tool for searching file contents"""
name: str = "grep"
description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."
params: dict = {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Search pattern (regex or literal string)"
},
"path": {
"type": "string",
"description": "Directory or file to search (default: current directory)"
},
"glob": {
"type": "string",
"description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
},
"ignoreCase": {
"type": "boolean",
"description": "Case-insensitive search (default: false)"
},
"literal": {
"type": "boolean",
"description": "Treat pattern as literal string instead of regex (default: false)"
},
"context": {
"type": "integer",
"description": "Number of lines to show before and after each match (default: 0)"
},
"limit": {
"type": "integer",
"description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
}
},
"required": ["pattern"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self.rg_path = self._find_ripgrep()
def _find_ripgrep(self) -> Optional[str]:
"""Find the ripgrep executable on PATH (cross-platform, unlike `which`)"""
return shutil.which('rg')
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute grep search
:param args: Search parameters
:return: Search results or error
"""
if not self.rg_path:
return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")
pattern = args.get("pattern", "").strip()
search_path = args.get("path", ".").strip()
glob = args.get("glob")
ignore_case = args.get("ignoreCase", False)
literal = args.get("literal", False)
context = args.get("context", 0)
limit = args.get("limit", DEFAULT_LIMIT)
if not pattern:
return ToolResult.fail("Error: pattern parameter is required")
# Resolve search path
absolute_path = self._resolve_path(search_path)
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: Path not found: {search_path}")
# Build ripgrep command
cmd = [
self.rg_path,
'--json',
'--line-number',
'--color=never',
'--hidden'
]
if ignore_case:
cmd.append('--ignore-case')
if literal:
cmd.append('--fixed-strings')
if glob:
cmd.extend(['--glob', glob])
cmd.extend([pattern, absolute_path])
try:
# Execute ripgrep
result = subprocess.run(
cmd,
cwd=self.cwd,
capture_output=True,
text=True,
timeout=30
)
# Parse JSON output
matches = []
match_count = 0
for line in result.stdout.splitlines():
if not line.strip():
continue
try:
event = json.loads(line)
if event.get('type') == 'match':
data = event.get('data', {})
file_path = data.get('path', {}).get('text')
line_number = data.get('line_number')
if file_path and line_number:
matches.append({
'file': file_path,
'line': line_number
})
match_count += 1
if match_count >= limit:
break
except json.JSONDecodeError:
continue
if match_count == 0:
return ToolResult.success({"message": "No matches found", "matches": []})
# Format output with context
output_lines = []
lines_truncated = False
is_directory = os.path.isdir(absolute_path)
for match in matches:
file_path = match['file']
line_number = match['line']
# Format file path
if is_directory:
relative_path = os.path.relpath(file_path, absolute_path)
else:
relative_path = os.path.basename(file_path)
# Read file and get context
try:
with open(file_path, 'r', encoding='utf-8') as f:
file_lines = f.read().split('\n')
# Calculate context range
start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
end = min(len(file_lines), line_number + context) if context > 0 else line_number
# Format lines with context
for i in range(start, end):
line_text = file_lines[i].replace('\r', '')
# Truncate long lines
truncated_text, was_truncated = truncate_line(line_text)
if was_truncated:
lines_truncated = True
# Format output
current_line = i + 1
if current_line == line_number:
output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
else:
output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")
except Exception:
output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
# Apply byte truncation
raw_output = '\n'.join(output_lines)
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
output = truncation.content
details = {}
notices = []
if match_count >= limit:
notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
details["match_limit_reached"] = limit
if truncation.truncated:
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
details["truncation"] = truncation.to_dict()
if lines_truncated:
notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
details["lines_truncated"] = True
if notices:
output += f"\n\n[{'. '.join(notices)}]"
return ToolResult.success({
"output": output,
"match_count": match_count,
"details": details if details else None
})
except subprocess.TimeoutExpired:
return ToolResult.fail("Error: Search timed out after 30 seconds")
except Exception as e:
return ToolResult.fail(f"Error executing grep: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

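A minimal usage sketch for the tool above, assuming ripgrep is installed and that `ToolResult` exposes the `status`/`result` fields shown in the web_fetch README later in this PR; the project path is illustrative:

```python
from agent.tools.grep.grep import Grep

grep = Grep(config={"cwd": "/path/to/project"})
result = grep.execute({
    "pattern": "TODO",
    "glob": "*.py",        # only search Python files
    "ignoreCase": True,
    "context": 2,          # 2 lines of context before/after each match
    "limit": 20,
})
if result.status == "success":
    # Matches print as "path:line: text"; context lines as "path-line- text"
    print(result.result["output"])
    print(f"{result.result['match_count']} matches")
```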
View File

@@ -0,0 +1,3 @@
from .ls import Ls
__all__ = ['Ls']

125
agent/tools/ls/ls.py Normal file
View File

@@ -0,0 +1,125 @@
"""
Ls tool - List directory contents
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
DEFAULT_LIMIT = 500
class Ls(BaseTool):
"""Tool for listing directory contents"""
name: str = "ls"
description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Directory to list (default: current directory)"
},
"limit": {
"type": "integer",
"description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
}
},
"required": []
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute directory listing
:param args: Listing parameters
:return: Directory contents or error
"""
path = args.get("path", ".").strip()
limit = args.get("limit", DEFAULT_LIMIT)
# Resolve path
absolute_path = self._resolve_path(path)
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: Path not found: {path}")
if not os.path.isdir(absolute_path):
return ToolResult.fail(f"Error: Not a directory: {path}")
try:
# Read directory entries
entries = os.listdir(absolute_path)
# Sort alphabetically (case-insensitive)
entries.sort(key=lambda x: x.lower())
# Format entries with directory indicators
results = []
entry_limit_reached = False
for entry in entries:
if len(results) >= limit:
entry_limit_reached = True
break
full_path = os.path.join(absolute_path, entry)
try:
if os.path.isdir(full_path):
results.append(entry + '/')
else:
results.append(entry)
except OSError:
# Skip entries we can't stat
continue
if not results:
return ToolResult.success({"message": "(empty directory)", "entries": []})
# Format output
raw_output = '\n'.join(results)
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
output = truncation.content
details = {}
notices = []
if entry_limit_reached:
notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
details["entry_limit_reached"] = limit
if truncation.truncated:
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
details["truncation"] = truncation.to_dict()
if notices:
output += f"\n\n[{'. '.join(notices)}]"
return ToolResult.success({
"output": output,
"entry_count": len(results),
"details": details if details else None
})
except PermissionError:
return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
except Exception as e:
return ToolResult.fail(f"Error listing directory: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,10 @@
"""
Memory tools for Agent
Provides memory_search and memory_get tools
"""
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
__all__ = ['MemorySearchTool', 'MemoryGetTool']

View File

@@ -0,0 +1,112 @@
"""
Memory get tool
Allows agents to read specific sections from memory files
"""
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool
class MemoryGetTool(BaseTool):
"""Tool for reading memory file contents"""
name: str = "memory_get"
description: str = (
"Read specific content from memory files. "
"Use this to get full context from a memory file or specific line range."
)
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
},
"start_line": {
"type": "integer",
"description": "Starting line number (optional, default: 1)",
"default": 1
},
"num_lines": {
"type": "integer",
"description": "Number of lines to read (optional, reads all if not specified)"
}
},
"required": ["path"]
}
def __init__(self, memory_manager):
"""
Initialize memory get tool
Args:
memory_manager: MemoryManager instance
"""
super().__init__()
self.memory_manager = memory_manager
def execute(self, args: dict):
"""
Execute memory file read
Args:
args: Dictionary with path, start_line, num_lines
Returns:
ToolResult with file content
"""
from agent.tools.base_tool import ToolResult
path = args.get("path")
start_line = args.get("start_line", 1)
num_lines = args.get("num_lines")
if not path:
return ToolResult.fail("Error: path parameter is required")
try:
workspace_dir = self.memory_manager.config.get_workspace()
# Auto-prepend memory/ if not present and not absolute path
if not path.startswith('memory/') and not path.startswith('/'):
path = f'memory/{path}'
file_path = workspace_dir / path
if not file_path.exists():
return ToolResult.fail(f"Error: File not found: {path}")
content = file_path.read_text(encoding='utf-8')
lines = content.split('\n')
# Handle line range
if start_line < 1:
start_line = 1
start_idx = start_line - 1
if num_lines:
end_idx = start_idx + num_lines
selected_lines = lines[start_idx:end_idx]
else:
selected_lines = lines[start_idx:]
result = '\n'.join(selected_lines)
# Add metadata
total_lines = len(lines)
shown_lines = len(selected_lines)
output = [
f"File: {path}",
f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
"",
result
]
return ToolResult.success('\n'.join(output))
except Exception as e:
return ToolResult.fail(f"Error reading memory file: {str(e)}")

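A usage sketch for `MemoryGetTool`. The `memory_manager` is assumed to be an already-configured `MemoryManager` whose `config.get_workspace()` returns a `pathlib.Path`, as the code above implies:

```python
# Hypothetical caller; construction of the MemoryManager is omitted.
from agent.tools.memory.memory_get import MemoryGetTool

def read_memory_window(memory_manager):
    tool = MemoryGetTool(memory_manager)
    # "2024-01-29.md" is auto-prefixed to "memory/2024-01-29.md"
    result = tool.execute({
        "path": "2024-01-29.md",
        "start_line": 10,
        "num_lines": 5,     # returns lines 10-14
    })
    # On success the payload is a formatted string:
    #   File: memory/2024-01-29.md
    #   Lines: 10-14 (total: <N>)
    #   <content>
    return result
```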
View File

@@ -0,0 +1,102 @@
"""
Memory search tool
Allows agents to search their memory using semantic and keyword search
"""
from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool
class MemorySearchTool(BaseTool):
"""Tool for searching agent memory"""
name: str = "memory_search"
description: str = (
"Search agent's long-term memory using semantic and keyword search. "
"Use this to recall past conversations, preferences, and knowledge."
)
params: dict = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query (can be natural language question or keywords)"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return (default: 10)",
"default": 10
},
"min_score": {
"type": "number",
"description": "Minimum relevance score (0-1, default: 0.1)",
"default": 0.1
}
},
"required": ["query"]
}
def __init__(self, memory_manager, user_id: Optional[str] = None):
"""
Initialize memory search tool
Args:
memory_manager: MemoryManager instance
user_id: Optional user ID for scoped search
"""
super().__init__()
self.memory_manager = memory_manager
self.user_id = user_id
def execute(self, args: dict):
"""
Execute memory search
Args:
args: Dictionary with query, max_results, min_score
Returns:
ToolResult with formatted search results
"""
from agent.tools.base_tool import ToolResult
import asyncio
query = args.get("query")
max_results = args.get("max_results", 10)
min_score = args.get("min_score", 0.1)
if not query:
return ToolResult.fail("Error: query parameter is required")
try:
# Run async search in sync context
results = asyncio.run(self.memory_manager.search(
query=query,
user_id=self.user_id,
max_results=max_results,
min_score=min_score,
include_shared=True
))
if not results:
# Return clear message that no memories exist yet
# This prevents infinite retry loops
return ToolResult.success(
f"No memories found for '{query}'. "
f"This is normal if no memories have been stored yet. "
f"You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
)
# Format results
output = [f"Found {len(results)} relevant memories:\n"]
for i, result in enumerate(results, 1):
output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
output.append(f" Score: {result.score:.3f}")
output.append(f" Snippet: {result.snippet}")
return ToolResult.success("\n".join(output))
except Exception as e:
return ToolResult.fail(f"Error searching memory: {str(e)}")

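Because `MemoryManager.search()` is async while tool dispatch is synchronous, the tool bridges the two with `asyncio.run()`. A sketch of a caller, with an illustrative `user_id`:

```python
from agent.tools.memory.memory_search import MemorySearchTool

def recall(memory_manager, question: str):
    # user_id scopes the search to one user's memories (illustrative value)
    tool = MemorySearchTool(memory_manager, user_id="user-123")
    return tool.execute({
        "query": question,
        "max_results": 5,
        "min_score": 0.2,   # stricter than the 0.1 default to drop weak hits
    })
```

Note that `asyncio.run()` raises if an event loop is already running in the calling thread, so invoking this tool from inside another coroutine would need loop-aware handling.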
View File

@@ -0,0 +1,3 @@
from .read import Read
__all__ = ['Read']

336
agent/tools/read/read.py Normal file
View File

@@ -0,0 +1,336 @@
"""
Read tool - Read file contents
Supports text files, images (jpg, png, gif, webp), and PDF files
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
class Read(BaseTool):
"""Tool for reading file contents"""
name: str = "read"
description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read (relative or absolute)"
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, optional)"
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (optional)"
}
},
"required": ["path"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# Supported image formats
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
# Supported PDF format
self.pdf_extensions = {'.pdf'}
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file read operation
:param args: Contains file path and optional offset/limit parameters
:return: File content or error message
"""
path = args.get("path", "").strip()
offset = args.get("offset")
limit = args.get("limit")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
# Check if file exists
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: File not found: {path}")
# Check if readable
if not os.access(absolute_path, os.R_OK):
return ToolResult.fail(f"Error: File is not readable: {path}")
# Check file type
file_ext = Path(absolute_path).suffix.lower()
# Check if image
if file_ext in self.image_extensions:
return self._read_image(absolute_path, file_ext)
# Check if PDF
if file_ext in self.pdf_extensions:
return self._read_pdf(absolute_path, path, offset, limit)
# Read text file
return self._read_text(absolute_path, path, offset, limit)
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
"""
Read image file
:param absolute_path: Absolute path to the image file
:param file_ext: File extension
:return: Result containing image information
"""
try:
# Read image file
with open(absolute_path, 'rb') as f:
image_data = f.read()
# Get file size
file_size = len(image_data)
# Base64-encode the image so it can be carried in the result payload
import base64
base64_data = base64.b64encode(image_data).decode('utf-8')
# Determine MIME type
mime_type_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp'
}
mime_type = mime_type_map.get(file_ext, 'image/jpeg')
result = {
"type": "image",
"mime_type": mime_type,
"size": file_size,
"size_formatted": format_size(file_size),
"data": base64_data # Base64 encoded image data
}
return ToolResult.success(result)
except Exception as e:
return ToolResult.fail(f"Error reading image file: {str(e)}")
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
"""
Read text file
:param absolute_path: Absolute path to the file
:param display_path: Path to display
:param offset: Starting line number (1-indexed)
:param limit: Maximum number of lines to read
:return: File content or error message
"""
try:
# Read file
with open(absolute_path, 'r', encoding='utf-8') as f:
content = f.read()
all_lines = content.split('\n')
total_file_lines = len(all_lines)
# Apply offset (if specified)
start_line = 0
if offset is not None:
start_line = max(0, offset - 1) # Convert to 0-indexed
if start_line >= total_file_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
)
start_line_display = start_line + 1 # For display (1-indexed)
# If user specified limit, use it
selected_content = content
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_file_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
# Apply truncation (considering line count and byte limits)
truncation = truncate_head(selected_content)
output_text = ""
details = {}
if truncation.first_line_exceeds_limit:
# First line alone exceeds the byte limit (DEFAULT_MAX_BYTES)
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
details["truncation"] = truncation.to_dict()
elif truncation.truncated:
# Truncation occurred
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
if truncation.truncated_by == "lines":
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
else:
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
details["truncation"] = truncation.to_dict()
elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
# User specified limit, more content available, but no truncation
remaining = total_file_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
# No truncation, no exceeding user limit
output_text = truncation.content
result = {
"content": output_text,
"total_lines": total_file_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines
}
if details:
result["details"] = details
return ToolResult.success(result)
except UnicodeDecodeError:
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
except Exception as e:
return ToolResult.fail(f"Error reading file: {str(e)}")
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
"""
Read PDF file content
:param absolute_path: Absolute path to the file
:param display_path: Path to display
:param offset: Starting line number (1-indexed)
:param limit: Maximum number of lines to read
:return: PDF text content or error message
"""
try:
# Try to import pypdf
try:
from pypdf import PdfReader
except ImportError:
return ToolResult.fail(
"Error: pypdf library not installed. Install with: pip install pypdf"
)
# Read PDF
reader = PdfReader(absolute_path)
total_pages = len(reader.pages)
# Extract text from all pages
text_parts = []
for page_num, page in enumerate(reader.pages, 1):
page_text = page.extract_text()
if page_text.strip():
text_parts.append(f"--- Page {page_num} ---\n{page_text}")
if not text_parts:
return ToolResult.success({
"content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
"total_pages": total_pages,
"message": "PDF may contain only images or be encrypted"
})
# Merge all text
full_content = "\n\n".join(text_parts)
all_lines = full_content.split('\n')
total_lines = len(all_lines)
# Apply offset and limit (same logic as text files)
start_line = 0
if offset is not None:
start_line = max(0, offset - 1)
if start_line >= total_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
)
start_line_display = start_line + 1
selected_content = full_content
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
# Apply truncation
truncation = truncate_head(selected_content)
output_text = ""
details = {}
if truncation.truncated:
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
if truncation.truncated_by == "lines":
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
else:
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
details["truncation"] = truncation.to_dict()
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
remaining = total_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
output_text = truncation.content
result = {
"content": output_text,
"total_pages": total_pages,
"total_lines": total_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines
}
if details:
result["details"] = details
return ToolResult.success(result)
except Exception as e:
return ToolResult.fail(f"Error reading PDF file: {str(e)}")

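The truncation notices above form a simple continuation protocol: each response reports `start_line` and `output_lines`, and the appended hint names the next `offset`. A paging sketch (the file name and the `status`/`result` fields are assumptions):

```python
# Page through a large text file 500 lines at a time.
from agent.tools.read.read import Read

read = Read(config={"cwd": "/path/to/project"})
offset = 1
while True:
    result = read.execute({"path": "big.log", "offset": offset, "limit": 500})
    if result.status != "success":
        break
    payload = result.result
    print(payload["content"])
    next_start = payload["start_line"] + payload["output_lines"]
    if next_start > payload["total_lines"]:
        break                       # consumed the whole file
    offset = next_start             # same value as the "Use offset=N" hint
```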
248
agent/tools/tool_manager.py Normal file
View File

@@ -0,0 +1,248 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Dict, Any, Type
from agent.tools.base_tool import BaseTool
from common.log import logger
from config import conf
class ToolManager:
"""
Tool manager for managing tools.
"""
_instance = None
def __new__(cls):
"""Singleton pattern to ensure only one instance of ToolManager exists."""
if cls._instance is None:
cls._instance = super(ToolManager, cls).__new__(cls)
cls._instance.tool_classes = {} # Store tool classes instead of instances
cls._instance._initialized = False
return cls._instance
def __init__(self):
# Initialize only once
if not hasattr(self, 'tool_classes'):
self.tool_classes = {} # Dictionary to store tool classes
def load_tools(self, tools_dir: str = "", config_dict=None):
"""
Load tools from both directory and configuration.
:param tools_dir: Directory to scan for tool modules
"""
if tools_dir:
self._load_tools_from_directory(tools_dir)
self._configure_tools_from_config()
else:
self._load_tools_from_init()
self._configure_tools_from_config(config_dict)
def _load_tools_from_init(self) -> bool:
"""
Load tool classes from tools.__init__.__all__
:return: True if tools were loaded, False otherwise
"""
try:
# Try to import the tools package
tools_package = importlib.import_module("agent.tools")
# Check if __all__ is defined
if hasattr(tools_package, "__all__"):
tool_classes = tools_package.__all__
# Import each tool class directly from the tools package
for class_name in tool_classes:
try:
# Skip base classes
if class_name in ["BaseTool", "ToolManager"]:
continue
# Get the class directly from the tools package
if hasattr(tools_package, class_name):
cls = getattr(tools_package, class_name)
if (
isinstance(cls, type)
and issubclass(cls, BaseTool)
and cls != BaseTool
):
try:
# Skip memory tools (they need special initialization with memory_manager)
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
continue
# Create a temporary instance to get the name
temp_instance = cls()
tool_name = temp_instance.name
# Store the class, not the instance
self.tool_classes[tool_name] = cls
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
except ImportError as e:
# Handle missing dependencies with helpful messages
error_msg = str(e)
if "browser-use" in error_msg or "browser_use" in error_msg:
logger.warning(
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
f" To enable browser tool, run:\n"
f" pip install browser-use markdownify playwright\n"
f" playwright install chromium"
)
elif "markdownify" in error_msg:
logger.warning(
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
f" Install with: pip install markdownify"
)
else:
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
except Exception as e:
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
except Exception as e:
logger.error(f"Error importing class {class_name}: {e}")
return len(self.tool_classes) > 0
return False
except ImportError:
logger.warning("Could not import agent.tools package")
return False
except Exception as e:
logger.error(f"Error loading tools from __init__.__all__: {e}")
return False
def _load_tools_from_directory(self, tools_dir: str):
"""Dynamically load tool classes from directory"""
tools_path = Path(tools_dir)
# Traverse all .py files
for py_file in tools_path.rglob("*.py"):
# Skip initialization files and base tool files
if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
continue
# Get module name
module_name = py_file.stem
try:
# Load module directly from file
spec = importlib.util.spec_from_file_location(module_name, py_file)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find tool classes in the module
for attr_name in dir(module):
cls = getattr(module, attr_name)
if (
isinstance(cls, type)
and issubclass(cls, BaseTool)
and cls != BaseTool
):
try:
# Skip memory tools (they need special initialization with memory_manager)
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
continue
# Create a temporary instance to get the name
temp_instance = cls()
tool_name = temp_instance.name
# Store the class, not the instance
self.tool_classes[tool_name] = cls
except ImportError as e:
# Handle missing dependencies with helpful messages
error_msg = str(e)
if "browser-use" in error_msg or "browser_use" in error_msg:
logger.warning(
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
f" To enable browser tool, run:\n"
f" pip install browser-use markdownify playwright\n"
f" playwright install chromium"
)
elif "markdownify" in error_msg:
logger.warning(
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
f" Install with: pip install markdownify"
)
else:
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
except Exception as e:
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
except Exception as e:
print(f"Error importing module {py_file}: {e}")
def _configure_tools_from_config(self, config_dict=None):
"""Configure tool classes based on configuration file"""
try:
# Get tools configuration
tools_config = config_dict or conf().get("tools", {})
# Record tools that are configured but not loaded
missing_tools = []
# Store configurations for later use when instantiating
self.tool_configs = tools_config
# Check which configured tools are missing
for tool_name in tools_config:
if tool_name not in self.tool_classes:
missing_tools.append(tool_name)
# If there are missing tools, record warnings
if missing_tools:
for tool_name in missing_tools:
if tool_name == "browser":
logger.warning(
f"[ToolManager] Browser tool is configured but not loaded.\n"
f" To enable browser tool, run:\n"
f" pip install browser-use markdownify playwright\n"
f" playwright install chromium"
)
elif tool_name == "google_search":
logger.warning(
f"[ToolManager] Google Search tool is configured but may need API key.\n"
f" Get API key from: https://serper.dev\n"
f" Configure in config.json: tools.google_search.api_key"
)
else:
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
except Exception as e:
logger.error(f"Error configuring tools from config: {e}")
def create_tool(self, name: str) -> BaseTool:
"""
Get a new instance of a tool by name.
:param name: The name of the tool to get.
:return: A new instance of the tool or None if not found.
"""
tool_class = self.tool_classes.get(name)
if tool_class:
# Create a new instance
tool_instance = tool_class()
# Apply configuration if available
if hasattr(self, 'tool_configs') and name in self.tool_configs:
tool_instance.config = self.tool_configs[name]
return tool_instance
return None
def list_tools(self) -> dict:
"""
Get information about all loaded tools.
:return: A dictionary with tool information.
"""
result = {}
for name, tool_class in self.tool_classes.items():
# Create a temporary instance to get schema
temp_instance = tool_class()
result[name] = {
"description": temp_instance.description,
"parameters": temp_instance.get_json_schema()
}
return result

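A usage sketch for `ToolManager`: classes are loaded once, then `create_tool()` mints a fresh instance per call. The `tools` config keys below are illustrative, not a documented schema:

```python
from agent.tools.tool_manager import ToolManager

manager = ToolManager()                    # singleton; repeat calls share state
manager.load_tools(config_dict={
    "grep": {"cwd": "/path/to/project"},   # hypothetical per-tool config
})
print(sorted(manager.list_tools()))        # names of every loadable tool

grep = manager.create_tool("grep")         # new instance each call
if grep:
    print(grep.execute({"pattern": "TODO"}))
```

One caveat visible in the code above: `create_tool()` assigns `config` after `__init__` has already run, so tools like `Grep` that read their config inside `__init__` (e.g. `cwd`) will have captured defaults before the assignment lands.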
View File

@@ -0,0 +1,40 @@
from .truncate import (
truncate_head,
truncate_tail,
truncate_line,
format_size,
TruncationResult,
DEFAULT_MAX_LINES,
DEFAULT_MAX_BYTES,
GREP_MAX_LINE_LENGTH
)
from .diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string,
FuzzyMatchResult
)
__all__ = [
'truncate_head',
'truncate_tail',
'truncate_line',
'format_size',
'TruncationResult',
'DEFAULT_MAX_LINES',
'DEFAULT_MAX_BYTES',
'GREP_MAX_LINE_LENGTH',
'strip_bom',
'detect_line_ending',
'normalize_to_lf',
'restore_line_endings',
'normalize_for_fuzzy_match',
'fuzzy_find_text',
'generate_diff_string',
'FuzzyMatchResult'
]

167
agent/tools/utils/diff.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Diff tools for file editing
Provides fuzzy matching and diff generation functionality
"""
import difflib
import re
from typing import Optional, Tuple
def strip_bom(text: str) -> Tuple[str, str]:
"""
Remove BOM (Byte Order Mark)
:param text: Original text
:return: (BOM, text after removing BOM)
"""
if text.startswith('\ufeff'):
return '\ufeff', text[1:]
return '', text
def detect_line_ending(text: str) -> str:
"""
Detect line ending type
:param text: Text content
:return: Line ending type ('\r\n' or '\n')
"""
if '\r\n' in text:
return '\r\n'
return '\n'
def normalize_to_lf(text: str) -> str:
"""
Normalize all line endings to LF (\n)
:param text: Original text
:return: Normalized text
"""
return text.replace('\r\n', '\n').replace('\r', '\n')
def restore_line_endings(text: str, original_ending: str) -> str:
"""
Restore original line endings
:param text: LF normalized text
:param original_ending: Original line ending
:return: Text with restored line endings
"""
if original_ending == '\r\n':
return text.replace('\n', '\r\n')
return text
def normalize_for_fuzzy_match(text: str) -> str:
"""
Normalize text for fuzzy matching
Remove excess whitespace but preserve basic structure
:param text: Original text
:return: Normalized text
"""
# Compress multiple spaces to one
text = re.sub(r'[ \t]+', ' ', text)
# Remove trailing spaces
text = re.sub(r' +\n', '\n', text)
# Rebuild each line's indentation: the substitutions above already collapsed
# every run of spaces/tabs (including leading ones) to a single space
lines = text.split('\n')
normalized_lines = []
for line in lines:
stripped = line.lstrip()
if stripped:
indent_count = len(line) - len(stripped)
normalized_indent = ' ' * indent_count
normalized_lines.append(normalized_indent + stripped)
else:
normalized_lines.append('')
return '\n'.join(normalized_lines)
class FuzzyMatchResult:
"""Fuzzy match result"""
def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
self.found = found
self.index = index
self.match_length = match_length
self.content_for_replacement = content_for_replacement
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
"""
Find text in content, try exact match first, then fuzzy match
:param content: Content to search in
:param old_text: Text to find
:return: Match result
"""
# First try exact match
index = content.find(old_text)
if index != -1:
return FuzzyMatchResult(
found=True,
index=index,
match_length=len(old_text),
content_for_replacement=content
)
# Try fuzzy match
fuzzy_content = normalize_for_fuzzy_match(content)
fuzzy_old_text = normalize_for_fuzzy_match(old_text)
index = fuzzy_content.find(fuzzy_old_text)
if index != -1:
# Fuzzy match successful, use normalized content for replacement
return FuzzyMatchResult(
found=True,
index=index,
match_length=len(fuzzy_old_text),
content_for_replacement=fuzzy_content
)
# Not found
return FuzzyMatchResult(found=False)
def generate_diff_string(old_content: str, new_content: str) -> dict:
"""
Generate unified diff string
:param old_content: Old content
:param new_content: New content
:return: Dictionary containing diff and first changed line number
"""
old_lines = old_content.split('\n')
new_lines = new_content.split('\n')
# Generate unified diff
diff_lines = list(difflib.unified_diff(
old_lines,
new_lines,
lineterm='',
fromfile='original',
tofile='modified'
))
# Find first changed line number
first_changed_line = None
for line in diff_lines:
if line.startswith('@@'):
# Parse @@ -1,3 +1,3 @@ format
match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
if match:
first_changed_line = int(match.group(1))
break
diff_string = '\n'.join(diff_lines)
return {
'diff': diff_string,
'first_changed_line': first_changed_line
}

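A worked example of the two-stage matching above: the exact search fails on doubled spaces, but after normalization both sides agree, and the replacement is then spliced into the normalized content:

```python
from agent.tools.utils.diff import fuzzy_find_text, generate_diff_string

content = "def add(a,  b):\n    return  a + b\n"   # doubled spaces in source
old_text = "def add(a, b):\n return a + b"

m = fuzzy_find_text(content, old_text)
print(m.found, m.index)                 # True 0 (exact find() returned -1 first)

# On a fuzzy hit, content_for_replacement is the *normalized* text, so the
# edit must be applied to that version:
new = (m.content_for_replacement[:m.index]
       + "def add(a, b):\n return a + b + 1"
       + m.content_for_replacement[m.index + m.match_length:])
print(generate_diff_string(m.content_for_replacement, new)["diff"])
```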
View File

@@ -0,0 +1,292 @@
"""
Shared truncation utilities for tool outputs.
Truncation is based on two independent limits - whichever is hit first wins:
- Line limit (default: 2000 lines)
- Byte limit (default: 50KB)
Never returns partial lines (except bash tail truncation edge case).
"""
from typing import Dict, Any, Optional, Literal
DEFAULT_MAX_LINES = 2000
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
class TruncationResult:
"""Truncation result"""
def __init__(
self,
content: str,
truncated: bool,
truncated_by: Optional[Literal["lines", "bytes"]],
total_lines: int,
total_bytes: int,
output_lines: int,
output_bytes: int,
last_line_partial: bool = False,
first_line_exceeds_limit: bool = False,
max_lines: int = DEFAULT_MAX_LINES,
max_bytes: int = DEFAULT_MAX_BYTES
):
self.content = content
self.truncated = truncated
self.truncated_by = truncated_by
self.total_lines = total_lines
self.total_bytes = total_bytes
self.output_lines = output_lines
self.output_bytes = output_bytes
self.last_line_partial = last_line_partial
self.first_line_exceeds_limit = first_line_exceeds_limit
self.max_lines = max_lines
self.max_bytes = max_bytes
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"content": self.content,
"truncated": self.truncated,
"truncated_by": self.truncated_by,
"total_lines": self.total_lines,
"total_bytes": self.total_bytes,
"output_lines": self.output_lines,
"output_bytes": self.output_bytes,
"last_line_partial": self.last_line_partial,
"first_line_exceeds_limit": self.first_line_exceeds_limit,
"max_lines": self.max_lines,
"max_bytes": self.max_bytes
}
def format_size(bytes_count: int) -> str:
"""Format bytes as human-readable size"""
if bytes_count < 1024:
return f"{bytes_count}B"
elif bytes_count < 1024 * 1024:
return f"{bytes_count / 1024:.1f}KB"
else:
return f"{bytes_count / (1024 * 1024):.1f}MB"
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
"""
Truncate content from the head (keep first N lines/bytes).
Suitable for file reads where you want to see the beginning.
Never returns partial lines. If first line exceeds byte limit,
returns empty content with first_line_exceeds_limit=True.
:param content: Content to truncate
:param max_lines: Maximum number of lines (default: 2000)
:param max_bytes: Maximum number of bytes (default: 50KB)
:return: Truncation result
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_bytes is None:
max_bytes = DEFAULT_MAX_BYTES
total_bytes = len(content.encode('utf-8'))
lines = content.split('\n')
total_lines = len(lines)
# Check if no truncation is needed
if total_lines <= max_lines and total_bytes <= max_bytes:
return TruncationResult(
content=content,
truncated=False,
truncated_by=None,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=total_lines,
output_bytes=total_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
# Check if first line alone exceeds byte limit
first_line_bytes = len(lines[0].encode('utf-8'))
if first_line_bytes > max_bytes:
return TruncationResult(
content="",
truncated=True,
truncated_by="bytes",
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=0,
output_bytes=0,
last_line_partial=False,
first_line_exceeds_limit=True,
max_lines=max_lines,
max_bytes=max_bytes
)
# Collect complete lines that fit
output_lines_arr = []
output_bytes_count = 0
truncated_by = "lines"
for i, line in enumerate(lines):
if i >= max_lines:
break
# Calculate line bytes (add 1 for newline if not first line)
line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0)
if output_bytes_count + line_bytes > max_bytes:
truncated_by = "bytes"
break
output_lines_arr.append(line)
output_bytes_count += line_bytes
# If exited due to line limit
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
truncated_by = "lines"
output_content = '\n'.join(output_lines_arr)
final_output_bytes = len(output_content.encode('utf-8'))
return TruncationResult(
content=output_content,
truncated=True,
truncated_by=truncated_by,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=len(output_lines_arr),
output_bytes=final_output_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
"""
Truncate content from tail (keep last N lines/bytes).
Suitable for bash output where you want to see the ending content (errors, final results).
If the last line of original content exceeds byte limit, may return partial first line.
:param content: Content to truncate
:param max_lines: Maximum lines (default: 2000)
:param max_bytes: Maximum bytes (default: 50KB)
:return: Truncation result
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_bytes is None:
max_bytes = DEFAULT_MAX_BYTES
total_bytes = len(content.encode('utf-8'))
lines = content.split('\n')
total_lines = len(lines)
# Check if no truncation is needed
if total_lines <= max_lines and total_bytes <= max_bytes:
return TruncationResult(
content=content,
truncated=False,
truncated_by=None,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=total_lines,
output_bytes=total_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
# Work backwards from the end
output_lines_arr = []
output_bytes_count = 0
truncated_by = "lines"
last_line_partial = False
for i in range(len(lines) - 1, -1, -1):
if len(output_lines_arr) >= max_lines:
break
line = lines[i]
# Calculate line bytes (add newline if not the first added line)
line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0)
if output_bytes_count + line_bytes > max_bytes:
truncated_by = "bytes"
# Edge case: if we haven't added any lines yet and this line alone
# exceeds max_bytes, take the end portion of this line
if len(output_lines_arr) == 0:
truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
output_lines_arr.insert(0, truncated_line)
output_bytes_count = len(truncated_line.encode('utf-8'))
last_line_partial = True
break
output_lines_arr.insert(0, line)
output_bytes_count += line_bytes
# If exited due to line limit
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
truncated_by = "lines"
output_content = '\n'.join(output_lines_arr)
final_output_bytes = len(output_content.encode('utf-8'))
return TruncationResult(
content=output_content,
truncated=True,
truncated_by=truncated_by,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=len(output_lines_arr),
output_bytes=final_output_bytes,
last_line_partial=last_line_partial,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
"""
Truncate string to fit byte limit (from end).
Properly handles multi-byte UTF-8 characters.
:param text: String to truncate
:param max_bytes: Maximum bytes
:return: Truncated string
"""
encoded = text.encode('utf-8')
if len(encoded) <= max_bytes:
return text
# Keep only the last max_bytes bytes of the encoded string
start = len(encoded) - max_bytes
# Find valid UTF-8 boundary (character start)
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
start += 1
return encoded[start:].decode('utf-8', errors='ignore')
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
"""
Truncate single line to max characters, add [truncated] suffix.
Used for grep match lines.
:param line: Line to truncate
:param max_chars: Maximum characters
:return: (truncated text, whether truncated)
"""
if len(line) <= max_chars:
return line, False
return f"{line[:max_chars]}... [truncated]", True

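A worked example of `truncate_head` with artificially small limits, showing that only whole lines are emitted and that whichever limit trips first is reported:

```python
from agent.tools.utils.truncate import truncate_head

text = "\n".join(f"line {i}" for i in range(1, 11))   # 10 lines, 70 bytes

r = truncate_head(text, max_lines=3)
print(r.truncated, r.truncated_by, r.output_lines)    # True lines 3
print(r.content)                                      # line 1 .. line 3

r = truncate_head(text, max_bytes=20)
print(r.truncated_by, r.output_bytes)   # bytes 20: three lines fit exactly,
                                        # adding "\nline 4" would make 27 bytes
```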
View File

@@ -0,0 +1,212 @@
# WebFetch Tool
A free web-page fetching tool: no API key needed, it fetches page content directly and extracts readable text.
## Features
- ✅ **Completely free** - no API key of any kind
- 🌐 **Smart extraction** - automatically extracts a page's main content
- 📝 **Format conversion** - supports HTML → Markdown/Text
- 🚀 **High performance** - built-in request retry and timeout control
- 🎯 **Graceful degradation** - prefers Readability, falls back to basic extraction
## Installing Dependencies
### Core (required)
```bash
pip install requests
```
### Enhanced (recommended)
```bash
# Install readability-lxml for better content extraction
pip install readability-lxml
# Install html2text for better Markdown conversion
pip install html2text
```
## Usage
### 1. In code
```python
from agent.tools.web_fetch import WebFetch
# Create a tool instance
tool = WebFetch()
# Fetch a page (returns Markdown by default)
result = tool.execute({
    "url": "https://example.com"
})
# Fetch and convert to plain text
result = tool.execute({
    "url": "https://example.com",
    "extract_mode": "text",
    "max_chars": 5000
})
if result.status == "success":
    data = result.result
    print(f"Title: {data['title']}")
    print(f"Content: {data['text']}")
```
### 2. In an Agent
The tool is loaded into the Agent's tool list automatically:
```python
from agent.tools import WebFetch
tools = [
    WebFetch(),
    # ... other tools
]
agent = create_agent(tools=tools)
```
### 3. Via Skills
Create a skill file `skills/web-fetch/SKILL.md`:
```markdown
---
name: web-fetch
emoji: 🌐
always: true
---
# Web content fetching
Use the web_fetch tool to fetch web page content.
## When to use
- Reading the content of a specific web page
- Extracting an article's main body
- Gathering information from a page
## Example
<example>
User: Take a look at https://example.com and tell me what it says
Assistant: <tool_use name="web_fetch">
<url>https://example.com</url>
<extract_mode>markdown</extract_mode>
</tool_use>
</example>
```
## Parameters
| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `url` | string | ✅ | - | URL to fetch (http/https) |
| `extract_mode` | string | ❌ | `markdown` | Extraction mode: `markdown` or `text` |
| `max_chars` | integer | ❌ | `50000` | Maximum characters to return (minimum 100) |
## Return Value
```python
{
    "url": "https://example.com",       # Final URL (after redirects)
    "status": 200,                      # HTTP status code
    "content_type": "text/html",        # Content type
    "title": "Example Domain",          # Page title
    "extractor": "readability",         # Extractor: readability/basic/raw
    "extract_mode": "markdown",         # Extraction mode
    "text": "# Example Domain\n\n...",  # Extracted text content
    "length": 1234,                     # Text length
    "truncated": False,                 # Whether the text was truncated
    "warning": "..."                    # Warning message (if any)
}
```
## Comparison with Other Search Tools
| Tool | API key required | Function | Cost |
|------|-------------|------|------|
| `web_fetch` | ❌ No | Fetches content from a given URL | Free |
| `web_search` (Brave) | ✅ Yes | Search-engine queries | Free tier available |
| `web_search` (Perplexity) | ✅ Yes | AI search + citations | Paid |
| `browser` | ❌ No | Full browser automation | Free, but resource-heavy |
| `google_search` | ✅ Yes | Google Search API | Paid |
## Technical Details
### Content extraction strategy
1. **Readability mode** (preferred)
   - Uses Mozilla's Readability algorithm
   - Automatically identifies the article body
   - Filters out ads, navigation bars, and other noise
2. **Basic mode** (fallback)
   - Simple HTML tag cleanup
   - Regex-based text extraction
   - Suitable for simple pages
3. **Raw mode**
   - Used for non-HTML content
   - Returns the raw content directly
### Error handling
The tool automatically handles:
- ✅ HTTP redirects (up to 3)
- ✅ Request timeouts (30 seconds by default)
- ✅ Automatic retries on network errors
- ✅ Fallback when content extraction fails
## Testing
Run the test script:
```bash
cd agent/tools/web_fetch
python test_web_fetch.py
```
## Configuration
A config can be passed when creating the tool:
```python
tool = WebFetch(config={
    "timeout": 30,        # Request timeout (seconds)
    "max_redirects": 3,   # Maximum number of redirects
    "user_agent": "..."   # Custom User-Agent
})
```
## FAQ
### Q: Why is readability-lxml recommended?
A: readability-lxml gives much better extraction quality. It can:
- automatically identify the article body
- filter out ads and navigation bars
- preserve the article's structure
The tool works without it, but extraction quality drops.
### Q: How does this differ from clawdbot's web_fetch?
A: This implementation follows clawdbot's design. Main differences:
- Python implementation (clawdbot is TypeScript)
- Some advanced features are simplified away (e.g. the Firecrawl integration)
- The core free functionality is kept
- Easier to integrate into an existing project
### Q: Can it fetch pages that require login?
A: Not in the current version. Use the `browser` tool for pages behind a login.
## References
- [Mozilla Readability](https://github.com/mozilla/readability)
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)

View File

@@ -0,0 +1,3 @@
from .web_fetch import WebFetch
__all__ = ['WebFetch']

View File

@@ -0,0 +1,47 @@
#!/bin/bash
# WebFetch tool dependency installer
echo "=================================="
echo "WebFetch dependency installation"
echo "=================================="
echo ""
# Check Python version
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python version: $python_version"
echo ""
# Install core dependencies
echo "📦 Installing core dependencies..."
python3 -m pip install requests
# Check the result
if [ $? -eq 0 ]; then
echo "✅ requests installed"
else
echo "❌ Failed to install requests"
exit 1
fi
echo ""
# Install recommended dependencies
echo "📦 Installing recommended dependencies (better extraction quality)..."
python3 -m pip install readability-lxml html2text
# Check the result
if [ $? -eq 0 ]; then
echo "✅ readability-lxml and html2text installed"
else
echo "⚠️ Recommended dependencies failed to install; core features still work"
fi
echo ""
echo "=================================="
echo "Installation complete!"
echo "=================================="
echo ""
echo "Run the tests:"
echo " python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""

View File

@@ -0,0 +1,365 @@
"""
Web Fetch tool - Fetch and extract readable content from URLs
Supports HTML to Markdown/Text conversion using Mozilla's Readability
"""
import os
import re
from typing import Dict, Any, Optional
from urllib.parse import urlparse
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
class WebFetch(BaseTool):
"""Tool for fetching and extracting readable content from web pages"""
name: str = "web_fetch"
description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."
params: dict = {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "HTTP or HTTPS URL to fetch"
},
"extract_mode": {
"type": "string",
"description": "Extraction mode: 'markdown' (default) or 'text'",
"enum": ["markdown", "text"],
"default": "markdown"
},
"max_chars": {
"type": "integer",
"description": "Maximum characters to return (default: 50000)",
"minimum": 100,
"default": 50000
}
},
"required": ["url"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.timeout = self.config.get("timeout", 30)
self.max_redirects = self.config.get("max_redirects", 3)
self.user_agent = self.config.get(
"user_agent",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
)
# Setup session with retry strategy
self.session = self._create_session()
# Check if readability-lxml is available
self.readability_available = self._check_readability()
def _create_session(self) -> requests.Session:
"""Create a requests session with retry strategy"""
session = requests.Session()
# Retry strategy - handles failed requests, not redirects
retry_strategy = Retry(
total=3,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["GET", "HEAD"]
)
# HTTPAdapter handles retries; requests handles redirects via allow_redirects
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
# Set max redirects on session
session.max_redirects = self.max_redirects
return session
def _check_readability(self) -> bool:
"""Check if readability-lxml is available"""
try:
from readability import Document
return True
except ImportError:
logger.warning(
"readability-lxml not installed. Install with: pip install readability-lxml\n"
"Falling back to basic HTML extraction."
)
return False
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute web fetch operation
:param args: Contains url, extract_mode, and max_chars parameters
:return: Extracted content or error message
"""
url = args.get("url", "").strip()
extract_mode = args.get("extract_mode", "markdown").lower()
max_chars = args.get("max_chars", 50000)
if not url:
return ToolResult.fail("Error: url parameter is required")
# Validate URL
if not self._is_valid_url(url):
return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")
# Validate extract_mode
if extract_mode not in ["markdown", "text"]:
extract_mode = "markdown"
# Validate max_chars
if not isinstance(max_chars, int) or max_chars < 100:
max_chars = 50000
try:
# Fetch the URL
response = self._fetch_url(url)
# Extract content
result = self._extract_content(
html=response.text,
url=response.url,
status_code=response.status_code,
content_type=response.headers.get("content-type", ""),
extract_mode=extract_mode,
max_chars=max_chars
)
return ToolResult.success(result)
except requests.exceptions.Timeout:
return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
except requests.exceptions.TooManyRedirects:
return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
except requests.exceptions.RequestException as e:
return ToolResult.fail(f"Error fetching URL: {str(e)}")
except Exception as e:
logger.error(f"Web fetch error: {e}", exc_info=True)
return ToolResult.fail(f"Error: {str(e)}")
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format"""
try:
result = urlparse(url)
return result.scheme in ["http", "https"] and bool(result.netloc)
except Exception:
return False
def _fetch_url(self, url: str) -> requests.Response:
"""
Fetch URL with proper headers and error handling
:param url: URL to fetch
:return: Response object
"""
headers = {
"User-Agent": self.user_agent,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
"Accept-Encoding": "gzip, deflate",
"Connection": "keep-alive",
}
# Note: requests follows redirects automatically; the limit comes from
# session.max_redirects, set in _create_session()
response = self.session.get(
url,
headers=headers,
timeout=self.timeout,
allow_redirects=True
)
response.raise_for_status()
return response
def _extract_content(
self,
html: str,
url: str,
status_code: int,
content_type: str,
extract_mode: str,
max_chars: int
) -> Dict[str, Any]:
"""
Extract readable content from HTML
:param html: HTML content
:param url: Original URL
:param status_code: HTTP status code
:param content_type: Content type header
:param extract_mode: 'markdown' or 'text'
:param max_chars: Maximum characters to return
:return: Extracted content and metadata
"""
# Check content type
if "text/html" not in content_type.lower():
# Non-HTML content
text = html[:max_chars]
truncated = len(html) > max_chars
return {
"url": url,
"status": status_code,
"content_type": content_type,
"extractor": "raw",
"text": text,
"length": len(text),
"truncated": truncated,
"message": f"Non-HTML content (type: {content_type})"
}
# Extract readable content from HTML
if self.readability_available:
return self._extract_with_readability(
html, url, status_code, content_type, extract_mode, max_chars
)
else:
return self._extract_basic(
html, url, status_code, content_type, extract_mode, max_chars
)
def _extract_with_readability(
self,
html: str,
url: str,
status_code: int,
content_type: str,
extract_mode: str,
max_chars: int
) -> Dict[str, Any]:
"""Extract content using Mozilla's Readability"""
try:
from readability import Document
# Parse with Readability
doc = Document(html)
title = doc.title()
content_html = doc.summary()
# Convert to markdown or text
if extract_mode == "markdown":
text = self._html_to_markdown(content_html)
else:
text = self._html_to_text(content_html)
# Truncate if needed
truncated = len(text) > max_chars
if truncated:
text = text[:max_chars]
return {
"url": url,
"status": status_code,
"content_type": content_type,
"title": title,
"extractor": "readability",
"extract_mode": extract_mode,
"text": text,
"length": len(text),
"truncated": truncated
}
except Exception as e:
logger.warning(f"Readability extraction failed: {e}")
# Fallback to basic extraction
return self._extract_basic(
html, url, status_code, content_type, extract_mode, max_chars
)
def _extract_basic(
self,
html: str,
url: str,
status_code: int,
content_type: str,
extract_mode: str,
max_chars: int
) -> Dict[str, Any]:
"""Basic HTML extraction without Readability"""
# Extract title
title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
title = title_match.group(1).strip() if title_match else "Untitled"
# Remove script and style tags
text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
# Remove HTML tags
text = re.sub(r'<[^>]+>', ' ', text)
# Clean up whitespace
text = re.sub(r'\s+', ' ', text)
text = text.strip()
# Truncate if needed
truncated = len(text) > max_chars
if truncated:
text = text[:max_chars]
return {
"url": url,
"status": status_code,
"content_type": content_type,
"title": title,
"extractor": "basic",
"extract_mode": extract_mode,
"text": text,
"length": len(text),
"truncated": truncated,
"warning": "Using basic extraction. Install readability-lxml for better results."
}
def _html_to_markdown(self, html: str) -> str:
"""Convert HTML to Markdown (basic implementation)"""
try:
# Try to use html2text if available
import html2text
h = html2text.HTML2Text()
h.ignore_links = False
h.ignore_images = False
h.body_width = 0 # Don't wrap lines
return h.handle(html)
except ImportError:
# Fallback to basic conversion
return self._html_to_text(html)
def _html_to_text(self, html: str) -> str:
"""Convert HTML to plain text"""
# Remove script and style tags
text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
# Convert common tags to text equivalents
text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)
# Remove all other HTML tags
text = re.sub(r'<[^>]+>', '', text)
# Decode HTML entities
import html
text = html.unescape(text)
# Clean up whitespace
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
text = re.sub(r' +', ' ', text)
text = text.strip()
return text
def close(self):
"""Close the session"""
if hasattr(self, 'session'):
self.session.close()
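# Usage sketch (hypothetical names; only the parameters in the signature above are known):
#   info = fetch_tool._extract_content(html, url="https://example.com",
#       status_code=200, content_type="text/html; charset=utf-8",
#       extract_mode="markdown", max_chars=2000)
#   info["extractor"]   # "readability" when readability-lxml is installed, else "basic"
#   info["truncated"]   # True when the extracted text exceeded max_chars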

View File

@@ -0,0 +1,3 @@
from .write import Write
__all__ = ['Write']

View File

@@ -0,0 +1,96 @@
"""
Write tool - Write file content
Creates or overwrites files, automatically creates parent directories
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
class Write(BaseTool):
"""Tool for writing file content"""
name: str = "write"
description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to write (relative or absolute)"
},
"content": {
"type": "string",
"description": "Content to write to the file"
}
},
"required": ["path", "content"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self.memory_manager = self.config.get("memory_manager", None)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file write operation
:param args: Contains file path and content
:return: Operation result
"""
path = args.get("path", "").strip()
content = args.get("content", "")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
try:
# Create parent directory (if needed)
parent_dir = os.path.dirname(absolute_path)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
# Write file
with open(absolute_path, 'w', encoding='utf-8') as f:
f.write(content)
# Get bytes written
bytes_written = len(content.encode('utf-8'))
# Auto-sync to memory database if this is a memory file
if self.memory_manager and 'memory/' in path:
self.memory_manager.mark_dirty()
result = {
"message": f"Successfully wrote {bytes_written} bytes to {path}",
"path": path,
"bytes_written": bytes_written
}
return ToolResult.success(result)
except PermissionError:
return ToolResult.fail(f"Error: Permission denied writing to {path}")
except Exception as e:
return ToolResult.fail(f"Error writing file: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
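# Usage sketch (ToolResult internals are assumed, not shown here):
#   tool = Write(config={"cwd": "/tmp/workspace"})
#   result = tool.execute({"path": "notes/todo.md", "content": "- write docs\n"})
#   # /tmp/workspace/notes/ is created automatically; the result carries the
#   # message, path and bytes_written fields built above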

View File

@@ -1,12 +1,14 @@
# encoding:utf-8
import time
import json
import openai
import openai.error
import requests
from common import const
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
@@ -18,7 +20,7 @@ from config import conf, load_config
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
# OpenAI chat model API (available)
class ChatGPTBot(Bot, OpenAIImage):
class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
def __init__(self):
super().__init__()
# set the default api_key
@@ -52,6 +54,18 @@ class ChatGPTBot(Bot, OpenAIImage):
        if conf_model in [const.O1, const.O1_MINI]:  # o1-series models don't support system prompts; reuse the Wenxin-model session
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or const.O1_MINI)
def get_api_config(self):
"""Get API configuration for OpenAI-compatible base class"""
return {
'api_key': conf().get("open_ai_api_key"),
'api_base': conf().get("open_ai_api_base"),
'model': conf().get("model", "gpt-3.5-turbo"),
'default_temperature': conf().get("temperature", 0.9),
'default_top_p': conf().get("top_p", 1.0),
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
'default_presence_penalty': conf().get("presence_penalty", 0.0),
}
def reply(self, query, context=None):
# acquire reply content
if context.type == ContextType.TEXT:
@@ -171,7 +185,6 @@ class ChatGPTBot(Bot, OpenAIImage):
else:
return result
class AzureChatGPTBot(ChatGPTBot):
def __init__(self):
super().__init__()

View File

@@ -1,222 +0,0 @@
import re
import time
import json
import uuid
from curl_cffi import requests
from bot.bot import Bot
from bot.claude.claude_ai_session import ClaudeAiSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
class ClaudeAIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
self.claude_api_cookie = conf().get("claude_api_cookie")
self.proxy = conf().get("proxy")
self.con_uuid_dic = {}
if self.proxy:
self.proxies = {
"http": self.proxy,
"https": self.proxy
}
else:
self.proxies = None
self.error = ""
self.org_uuid = self.get_organization_id()
def generate_uuid(self):
random_uuid = uuid.uuid4()
random_uuid_str = str(random_uuid)
formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
return formatted_uuid
def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
return self._chat(query, context)
elif context.type == ContextType.IMAGE_CREATE:
ok, res = self.create_img(query, 0)
if ok:
reply = Reply(ReplyType.IMAGE_URL, res)
else:
reply = Reply(ReplyType.ERROR, res)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply
def get_organization_id(self):
url = "https://claude.ai/api/organizations"
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'Connection': 'keep-alive',
'Cookie': f'{self.claude_api_cookie}'
}
try:
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
res = json.loads(response.text)
uuid = res[0]['uuid']
except:
if "App unavailable" in response.text:
logger.error("IP error: The IP is not allowed to be used on Claude")
self.error = "ip所在地区不被claude支持"
elif "Invalid authorization" in response.text:
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
self.error = "无法通过claude身份验证请检查cookie"
return None
return uuid
def conversation_share_check(self,session_id):
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
con_uuid = conf().get("claude_uuid")
return con_uuid
if session_id not in self.con_uuid_dic:
self.con_uuid_dic[session_id] = self.generate_uuid()
self.create_new_chat(self.con_uuid_dic[session_id])
return self.con_uuid_dic[session_id]
def check_cookie(self):
flag = self.get_organization_id()
return flag
def create_new_chat(self, con_uuid):
"""
新建claude对话实体
:param con_uuid: 对话id
:return:
"""
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
payload = json.dumps({"uuid": con_uuid, "name": ""})
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Origin': 'https://claude.ai',
'DNT': '1',
'Connection': 'keep-alive',
'Cookie': self.claude_api_cookie,
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'TE': 'trailers'
}
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
# Returns JSON of the newly created conversation information
return response.json()
def _chat(self, query, context, retry_count=0) -> Reply:
"""
发起对话请求
:param query: 请求提示词
:param context: 对话上下文
:param retry_count: 当前递归重试次数
:return: 回复
"""
if retry_count >= 2:
# exit from retry 2 times
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
return Reply(ReplyType.ERROR, "请再问我一次吧")
try:
session_id = context["session_id"]
if self.org_uuid is None:
return Reply(ReplyType.ERROR, self.error)
session = self.sessions.session_query(query, session_id)
con_uuid = self.conversation_share_check(session_id)
model = conf().get("model") or "gpt-3.5-turbo"
# remove system message
if session.messages[0].get("role") == "system":
if model == "wenxin" or model == "claude":
session.messages.pop(0)
logger.info(f"[CLAUDEAI] query={query}")
# do http request
base_url = "https://claude.ai"
payload = json.dumps({
"completion": {
"prompt": f"{query}",
"timezone": "Asia/Kolkata",
"model": "claude-2"
},
"organization_uuid": f"{self.org_uuid}",
"conversation_uuid": f"{con_uuid}",
"text": f"{query}",
"attachments": []
})
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept': 'text/event-stream, text/event-stream',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Origin': 'https://claude.ai',
'DNT': '1',
'Connection': 'keep-alive',
'Cookie': f'{self.claude_api_cookie}',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'TE': 'trailers'
}
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
if res.status_code == 200 or "pemission" in res.text:
# execute success
decoded_data = res.content.decode("utf-8")
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
data_strings = decoded_data.split('\n')
completions = []
for data_string in data_strings:
json_str = data_string[6:].strip()
data = json.loads(json_str)
if 'completion' in data:
completions.append(data['completion'])
reply_content = ''.join(completions)
if "rate limi" in reply_content:
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
return Reply(ReplyType.ERROR, "对话达到系统速率限制与cladue同步请进入官网查看解除限制时间")
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
self.sessions.session_reply(reply_content, session_id, 100)
return Reply(ReplyType.TEXT, reply_content)
else:
flag = self.check_cookie()
if flag == None:
return Reply(ReplyType.ERROR, self.error)
response = res.json()
error = response.get("error")
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
if res.status_code >= 500:
# server error, need retry
time.sleep(2)
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
except Exception as e:
logger.exception(e)
# retry
time.sleep(2)
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)

View File

@@ -1,9 +0,0 @@
from bot.session_manager import Session
class ClaudeAiSession(Session):
def __init__(self, session_id, system_prompt=None, model="claude"):
super().__init__(session_id, system_prompt)
self.model = model
# claude逆向不支持role prompt
# self.reset()

View File

@@ -1,21 +1,28 @@
# encoding:utf-8
import json
import time
import openai
import openai.error
import anthropic
import requests
from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bot.bot import Bot
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common import const
from common.log import logger
from config import conf
# Optional OpenAI image support
try:
from bot.openai.open_ai_image import OpenAIImage
_openai_image_available = True
except Exception as e:
logger.warning(f"OpenAI image support not available: {e}")
_openai_image_available = False
OpenAIImage = object # Fallback to object
user_session = dict()
@@ -23,13 +30,9 @@ user_session = dict()
class ClaudeAPIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
proxy = conf().get("proxy", None)
        base_url = conf().get("open_ai_api_base", None)  # reuse the "open_ai_api_base" setting as the base_url
self.claudeClient = anthropic.Anthropic(
api_key=conf().get("claude_api_key"),
proxies=proxy if proxy else None,
base_url=base_url if base_url else None
)
self.api_key = conf().get("claude_api_key")
self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
self.proxy = conf().get("proxy", None)
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
def reply(self, query, context=None):
@@ -73,39 +76,104 @@ class ClaudeAPIBot(Bot, OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring)
return reply
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None):
try:
actual_model = self._model_mapping(conf().get("model"))
response = self.claudeClient.messages.create(
model=actual_model,
max_tokens=4096,
system=conf().get("character_desc", ""),
messages=session.messages
# Prepare headers
headers = {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json"
}
# Extract system prompt if present and prepare Claude-compatible messages
system_prompt = conf().get("character_desc", "")
claude_messages = []
for msg in session.messages:
if msg.get("role") == "system":
system_prompt = msg["content"]
else:
claude_messages.append(msg)
# Prepare request data
data = {
"model": actual_model,
"messages": claude_messages,
"max_tokens": self._get_max_tokens(actual_model)
}
if system_prompt:
data["system"] = system_prompt
if tools:
data["tools"] = tools
# Make HTTP request
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
response = requests.post(
f"{self.api_base}/messages",
headers=headers,
json=data,
proxies=proxies
)
# response = openai.Completion.create(prompt=str(session), **self.args)
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
total_tokens = response.usage.input_tokens+response.usage.output_tokens
completion_tokens = response.usage.output_tokens
if response.status_code != 200:
raise Exception(f"API request failed: {response.status_code} - {response.text}")
claude_response = response.json()
# Handle response content and tool calls
res_content = ""
tool_calls = []
content_blocks = claude_response.get("content", [])
for block in content_blocks:
if block.get("type") == "text":
res_content += block.get("text", "")
elif block.get("type") == "tool_use":
tool_calls.append({
"id": block.get("id", ""),
"name": block.get("name", ""),
"arguments": block.get("input", {})
})
res_content = res_content.strip().replace("<|endoftext|>", "")
usage = claude_response.get("usage", {})
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
completion_tokens = usage.get("output_tokens", 0)
logger.info("[CLAUDE_API] reply={}".format(res_content))
return {
if tool_calls:
logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls))
result = {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
if tool_calls:
result["tool_calls"] = tool_calls
return result
except Exception as e:
need_retry = retry_count < 2
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
# Handle different types of errors
error_str = str(e).lower()
if "rate" in error_str or "limit" in error_str:
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
elif "timeout" in error_str:
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
elif "connection" in error_str or "network" in error_str:
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
need_retry = False
result["content"] = "我连接不到你的网络"
@@ -116,7 +184,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
if need_retry:
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, retry_count + 1)
return self.reply_text(session, retry_count + 1, tools)
else:
return result
@@ -130,3 +198,288 @@ class ClaudeAPIBot(Bot, OpenAIImage):
elif model == "claude-3.5-sonnet":
return const.CLAUDE_35_SONNET
return model
def _get_max_tokens(self, model: str) -> int:
"""
Get max_tokens for the model.
Reference from pi-mono:
        - Claude 3.5/3.7: 8192
        - Claude 3 Opus: 4096
        - Claude 4 (sonnet/opus): 64000
        - Default: 8192
"""
if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")):
return 8192
elif model and model.startswith("claude-3") and "opus" in model:
return 4096
elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")):
return 64000
return 8192
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"""
Call Claude API with tool support for agent integration
Args:
messages: List of messages
tools: List of tool definitions
stream: Whether to use streaming
**kwargs: Additional parameters
Returns:
Formatted response compatible with OpenAI format or generator for streaming
"""
actual_model = self._model_mapping(conf().get("model"))
# Extract system prompt from messages if present
system_prompt = kwargs.get("system", conf().get("character_desc", ""))
claude_messages = []
for msg in messages:
if msg.get("role") == "system":
system_prompt = msg["content"]
else:
claude_messages.append(msg)
request_params = {
"model": actual_model,
"max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)),
"messages": claude_messages,
"stream": stream
}
if system_prompt:
request_params["system"] = system_prompt
if tools:
request_params["tools"] = tools
try:
if stream:
return self._handle_stream_response(request_params)
else:
return self._handle_sync_response(request_params)
        except Exception as e:
            logger.error(f"Claude API call error: {e}")
            error_msg = str(e)  # Capture now: `e` is unbound once the except block exits, so the generator could not reference it
            if stream:
                # Return error generator for stream
                def error_generator():
                    yield {
                        "error": True,
                        "message": error_msg,
                        "status_code": 500
                    }
                return error_generator()
else:
# Return error response for sync
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _handle_sync_response(self, request_params):
"""Handle synchronous Claude API response"""
# Prepare headers
headers = {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json"
}
# Make HTTP request
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
response = requests.post(
f"{self.api_base}/messages",
headers=headers,
json=request_params,
proxies=proxies
)
if response.status_code != 200:
raise Exception(f"API request failed: {response.status_code} - {response.text}")
claude_response = response.json()
# Extract content blocks
text_content = ""
tool_calls = []
content_blocks = claude_response.get("content", [])
for block in content_blocks:
if block.get("type") == "text":
text_content += block.get("text", "")
elif block.get("type") == "tool_use":
tool_calls.append({
"id": block.get("id", ""),
"type": "function",
"function": {
"name": block.get("name", ""),
"arguments": json.dumps(block.get("input", {}))
}
})
# Build message in OpenAI format
message = {
"role": "assistant",
"content": text_content
}
if tool_calls:
message["tool_calls"] = tool_calls
# Format response to match OpenAI structure
usage = claude_response.get("usage", {})
formatted_response = {
"id": claude_response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": claude_response.get("model", request_params["model"]),
"choices": [
{
"index": 0,
"message": message,
"finish_reason": claude_response.get("stop_reason", "stop")
}
],
"usage": {
"prompt_tokens": usage.get("input_tokens", 0),
"completion_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
}
}
return formatted_response
def _handle_stream_response(self, request_params):
"""Handle streaming Claude API response using HTTP requests"""
# Prepare headers
headers = {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json"
}
# Add stream parameter
request_params["stream"] = True
# Track tool use state
tool_uses_map = {} # {index: {id, name, input}}
current_tool_use_index = -1
try:
# Make streaming HTTP request
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
response = requests.post(
f"{self.api_base}/messages",
headers=headers,
json=request_params,
proxies=proxies,
stream=True
)
if response.status_code != 200:
error_text = response.text
try:
error_data = json.loads(error_text)
error_msg = error_data.get("error", {}).get("message", error_text)
                except Exception:
error_msg = error_text or "Unknown error"
yield {
"error": True,
"status_code": response.status_code,
"message": error_msg
}
return
# Process streaming response
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: '):
line = line[6:] # Remove 'data: ' prefix
if line == '[DONE]':
break
try:
event = json.loads(line)
event_type = event.get("type")
if event_type == "content_block_start":
# New content block
block = event.get("content_block", {})
if block.get("type") == "tool_use":
current_tool_use_index = event.get("index", 0)
tool_uses_map[current_tool_use_index] = {
"id": block.get("id", ""),
"name": block.get("name", ""),
"input": ""
}
elif event_type == "content_block_delta":
delta = event.get("delta", {})
delta_type = delta.get("type")
if delta_type == "text_delta":
# Text content
content = delta.get("text", "")
yield {
"id": event.get("id", ""),
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": request_params["model"],
"choices": [{
"index": 0,
"delta": {"content": content},
"finish_reason": None
}]
}
elif delta_type == "input_json_delta":
# Tool input accumulation
if current_tool_use_index >= 0:
tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "")
elif event_type == "message_delta":
# Message complete - yield tool calls if any
if tool_uses_map:
for idx in sorted(tool_uses_map.keys()):
tool_data = tool_uses_map[idx]
yield {
"id": event.get("id", ""),
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": request_params["model"],
"choices": [{
"index": 0,
"delta": {
"tool_calls": [{
"index": idx,
"id": tool_data["id"],
"type": "function",
"function": {
"name": tool_data["name"],
"arguments": tool_data["input"]
}
}]
},
"finish_reason": None
}]
}
except json.JSONDecodeError:
continue
except requests.RequestException as e:
logger.error(f"Claude streaming request error: {e}")
yield {
"error": True,
"message": f"Connection error: {str(e)}",
"status_code": 0
}
except Exception as e:
logger.error(f"Claude streaming error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}
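# Usage sketch (hypothetical tool definition; claude_api_key must be configured):
#   bot = ClaudeAPIBot()
#   tools = [{"name": "get_time", "description": "Return the current time",
#             "input_schema": {"type": "object", "properties": {}}}]
#   resp = bot.call_with_tools([{"role": "user", "content": "What time is it?"}],
#                              tools=tools)
#   # resp mirrors OpenAI's chat.completion shape, so callers can read
#   # resp["choices"][0]["message"].get("tool_calls") uniformly across bots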

View File

@@ -6,6 +6,9 @@ Google gemini bot
"""
# encoding:utf-8
import json
import time
import requests
from bot.bot import Bot
import google.generativeai as genai
from bot.session_manager import SessionManager
@@ -29,6 +32,19 @@ class GoogleGeminiBot(Bot):
self.model = conf().get("model") or "gemini-pro"
if self.model == "gemini":
self.model = "gemini-pro"
        # Support a custom API base by reusing the open_ai_api_base setting
        self.api_base = conf().get("open_ai_api_base", "").strip()
        if self.api_base:
            # Strip trailing slashes
            self.api_base = self.api_base.rstrip('/')
            # Fall back to the default Gemini address if an OpenAI address (or nothing) was configured
            if "api.openai.com" in self.api_base or not self.api_base:
                self.api_base = "https://generativelanguage.googleapis.com"
            else:
                logger.info(f"[Gemini] Using custom API base: {self.api_base}")
        else:
            self.api_base = "https://generativelanguage.googleapis.com"
def reply(self, query, context: Context = None) -> Reply:
try:
if context.type != ContextType.TEXT:
@@ -113,3 +129,557 @@ class GoogleGeminiBot(Bot):
elif turn == "assistant":
turn = "user"
return res
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"""
Call Gemini API with tool support using REST API (following official docs)
Args:
messages: List of messages (OpenAI format)
tools: List of tool definitions (OpenAI/Claude format)
stream: Whether to use streaming
**kwargs: Additional parameters (system, max_tokens, temperature, etc.)
Returns:
Formatted response compatible with OpenAI format or generator for streaming
"""
try:
model_name = kwargs.get("model", self.model or "gemini-1.5-flash")
# Build REST API payload
payload = {"contents": []}
# Extract and set system instruction
system_prompt = kwargs.get("system", "")
if not system_prompt:
for msg in messages:
if msg.get("role") == "system":
system_prompt = msg["content"]
break
if system_prompt:
payload["system_instruction"] = {
"parts": [{"text": system_prompt}]
}
# Convert messages to Gemini format
for msg in messages:
role = msg.get("role")
content = msg.get("content", "")
if role == "system":
continue
# Convert role
gemini_role = "user" if role in ["user", "tool"] else "model"
# Handle different content formats
parts = []
if isinstance(content, str):
# Simple text content
parts.append({"text": content})
elif isinstance(content, list):
# List of content blocks (Claude format)
for block in content:
if not isinstance(block, dict):
if isinstance(block, str):
parts.append({"text": block})
continue
block_type = block.get("type")
if block_type == "text":
# Text block
parts.append({"text": block.get("text", "")})
elif block_type == "tool_result":
# Convert Claude tool_result to Gemini functionResponse
tool_use_id = block.get("tool_use_id")
tool_content = block.get("content", "")
# Try to parse tool content as JSON
try:
if isinstance(tool_content, str):
tool_result_data = json.loads(tool_content)
else:
tool_result_data = tool_content
                            except Exception:
tool_result_data = {"result": tool_content}
# Find the tool name from previous messages
# Look for the corresponding tool_call in model's message
tool_name = None
for prev_msg in reversed(messages):
if prev_msg.get("role") == "assistant":
prev_content = prev_msg.get("content", [])
if isinstance(prev_content, list):
for prev_block in prev_content:
if isinstance(prev_block, dict) and prev_block.get("type") == "tool_use":
if prev_block.get("id") == tool_use_id:
tool_name = prev_block.get("name")
break
if tool_name:
break
# Gemini functionResponse format
parts.append({
"functionResponse": {
"name": tool_name or "unknown",
"response": tool_result_data
}
})
elif "text" in block:
# Generic text field
parts.append({"text": block["text"]})
if parts:
payload["contents"].append({
"role": gemini_role,
"parts": parts
})
# Generation config
gen_config = {}
if kwargs.get("temperature") is not None:
gen_config["temperature"] = kwargs["temperature"]
if kwargs.get("max_tokens"):
gen_config["maxOutputTokens"] = kwargs["max_tokens"]
if gen_config:
payload["generationConfig"] = gen_config
# Convert tools to Gemini format (REST API style)
if tools:
gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
if gemini_tools:
payload["tools"] = gemini_tools
logger.info(f"[Gemini] Added {len(tools)} tools to request")
# Make REST API call
base_url = f"{self.api_base}/v1beta"
endpoint = f"{base_url}/models/{model_name}:generateContent"
if stream:
endpoint = f"{base_url}/models/{model_name}:streamGenerateContent?alt=sse"
headers = {
"x-goog-api-key": self.api_key,
"Content-Type": "application/json"
}
logger.debug(f"[Gemini] REST API call: {endpoint}")
response = requests.post(
endpoint,
headers=headers,
json=payload,
stream=stream,
timeout=60
)
# Check HTTP status for stream mode (for non-stream, it's checked in handler)
if stream and response.status_code != 200:
error_text = response.text
logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")
def error_generator():
yield {
"error": True,
"message": f"Gemini API error: {error_text}",
"status_code": response.status_code
}
return error_generator()
if stream:
return self._handle_gemini_rest_stream_response(response, model_name)
else:
return self._handle_gemini_rest_sync_response(response, model_name)
except Exception as e:
logger.error(f"[Gemini] call_with_tools error: {e}", exc_info=True)
error_msg = str(e) # Capture error message before creating generator
if stream:
def error_generator():
yield {
"error": True,
"message": error_msg,
"status_code": 500
}
return error_generator()
else:
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _convert_tools_to_gemini_rest_format(self, tools_list):
"""
Convert tools to Gemini REST API format
Handles both OpenAI and Claude/Agent formats.
Returns: [{"functionDeclarations": [...]}]
"""
function_declarations = []
for tool in tools_list:
# Extract name, description, and parameters based on format
if tool.get("type") == "function":
# OpenAI format: {"type": "function", "function": {...}}
func = tool.get("function", {})
name = func.get("name")
description = func.get("description", "")
parameters = func.get("parameters", {})
else:
# Claude/Agent format: {"name": "...", "description": "...", "input_schema": {...}}
name = tool.get("name")
description = tool.get("description", "")
parameters = tool.get("input_schema", {})
if not name:
logger.warning(f"[Gemini] Skipping tool without name: {tool}")
continue
logger.debug(f"[Gemini] Converting tool: {name}")
function_declarations.append({
"name": name,
"description": description,
"parameters": parameters
})
# All functionDeclarations must be in a single tools object (per Gemini REST API spec)
return [{
"functionDeclarations": function_declarations
}] if function_declarations else []
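    # Example of the conversion (shapes as assumed above): an OpenAI-format tool
    #   {"type": "function", "function": {"name": "search", "description": "Web search",
    #    "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}}}
    # becomes a single Gemini tools object
    #   [{"functionDeclarations": [{"name": "search", "description": "Web search",
    #     "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}}]}]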
def _handle_gemini_rest_sync_response(self, response, model_name):
"""Handle Gemini REST API sync response and convert to OpenAI format"""
try:
if response.status_code != 200:
error_text = response.text
logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")
return {
"error": True,
"message": f"Gemini API error: {error_text}",
"status_code": response.status_code
}
data = response.json()
logger.debug(f"[Gemini] Response data: {json.dumps(data, ensure_ascii=False)[:500]}")
# Extract from Gemini response format
candidates = data.get("candidates", [])
if not candidates:
logger.warning("[Gemini] No candidates in response")
return {
"error": True,
"message": "No candidates in response",
"status_code": 500
}
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
logger.debug(f"[Gemini] Candidate parts count: {len(parts)}")
# Extract text and function calls
text_content = ""
tool_calls = []
for part in parts:
# Check for text
if "text" in part:
text_content += part["text"]
logger.debug(f"[Gemini] Text part: {part['text'][:100]}...")
# Check for functionCall (per REST API docs)
if "functionCall" in part:
fc = part["functionCall"]
logger.info(f"[Gemini] Function call detected: {fc.get('name')}")
tool_calls.append({
"id": f"call_{int(time.time() * 1000000)}",
"type": "function",
"function": {
"name": fc.get("name"),
"arguments": json.dumps(fc.get("args", {}))
}
})
logger.info(f"[Gemini] Response: text={len(text_content)} chars, tool_calls={len(tool_calls)}")
# Build OpenAI format response
message_dict = {
"role": "assistant",
"content": text_content or None
}
if tool_calls:
message_dict["tool_calls"] = tool_calls
return {
"id": f"chatcmpl-{time.time()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"message": message_dict,
"finish_reason": "tool_calls" if tool_calls else "stop"
}],
"usage": data.get("usageMetadata", {})
}
except Exception as e:
logger.error(f"[Gemini] sync response error: {e}", exc_info=True)
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _handle_gemini_rest_stream_response(self, response, model_name):
"""Handle Gemini REST API stream response"""
try:
all_tool_calls = []
has_sent_tool_calls = False
has_content = False # Track if any content was sent
for line in response.iter_lines():
if not line:
continue
line = line.decode('utf-8')
# Skip SSE prefixes
if line.startswith('data: '):
line = line[6:]
if not line or line == '[DONE]':
continue
try:
chunk_data = json.loads(line)
logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")
candidates = chunk_data.get("candidates", [])
if not candidates:
logger.debug("[Gemini] No candidates in chunk")
continue
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
if not parts:
logger.debug("[Gemini] No parts in candidate content")
# Stream text content
for part in parts:
if "text" in part and part["text"]:
has_content = True
logger.debug(f"[Gemini] Streaming text: {part['text'][:50]}...")
yield {
"id": f"chatcmpl-{time.time()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": part["text"]},
"finish_reason": None
}]
}
# Collect function calls
if "functionCall" in part:
fc = part["functionCall"]
logger.debug(f"[Gemini] Function call detected: {fc.get('name')}")
all_tool_calls.append({
"index": len(all_tool_calls), # Add index to differentiate multiple tool calls
"id": f"call_{int(time.time() * 1000000)}_{len(all_tool_calls)}",
"type": "function",
"function": {
"name": fc.get("name"),
"arguments": json.dumps(fc.get("args", {}))
}
})
except json.JSONDecodeError as je:
logger.debug(f"[Gemini] JSON decode error: {je}")
continue
# Send tool calls if any were collected
if all_tool_calls and not has_sent_tool_calls:
logger.info(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
yield {
"id": f"chatcmpl-{time.time()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": all_tool_calls},
"finish_reason": None
}]
}
has_sent_tool_calls = True
# Log summary
logger.info(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
# Final chunk
yield {
"id": f"chatcmpl-{time.time()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "tool_calls" if all_tool_calls else "stop"
}]
}
except Exception as e:
logger.error(f"[Gemini] stream response error: {e}", exc_info=True)
error_msg = str(e)
yield {
"error": True,
"message": error_msg,
"status_code": 500
}
def _convert_tools_to_gemini_format(self, openai_tools):
"""Convert OpenAI tool format to Gemini function declarations"""
import google.generativeai as genai
gemini_functions = []
for tool in openai_tools:
if tool.get("type") == "function":
func = tool.get("function", {})
gemini_functions.append(
genai.protos.FunctionDeclaration(
name=func.get("name"),
description=func.get("description", ""),
parameters=func.get("parameters", {})
)
)
if gemini_functions:
return [genai.protos.Tool(function_declarations=gemini_functions)]
return None
def _handle_gemini_sync_response(self, model, messages, request_params, model_name):
"""Handle synchronous Gemini API response"""
import json
response = model.generate_content(messages, **request_params)
# Extract text content and function calls
text_content = ""
tool_calls = []
if response.candidates and response.candidates[0].content:
for part in response.candidates[0].content.parts:
if hasattr(part, 'text') and part.text:
text_content += part.text
elif hasattr(part, 'function_call') and part.function_call:
# Convert Gemini function call to OpenAI format
func_call = part.function_call
tool_calls.append({
"id": f"call_{hash(func_call.name)}",
"type": "function",
"function": {
"name": func_call.name,
"arguments": json.dumps(dict(func_call.args))
}
})
# Build message in OpenAI format
message = {
"role": "assistant",
"content": text_content
}
if tool_calls:
message["tool_calls"] = tool_calls
# Format response to match OpenAI structure
formatted_response = {
"id": f"gemini_{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"message": message,
"finish_reason": "stop" if not tool_calls else "tool_calls"
}
],
"usage": {
"prompt_tokens": 0, # Gemini doesn't provide token counts in the same way
"completion_tokens": 0,
"total_tokens": 0
}
}
logger.info(f"[Gemini] call_with_tools reply, model={model_name}")
return formatted_response
def _handle_gemini_stream_response(self, model, messages, request_params, model_name):
"""Handle streaming Gemini API response"""
import json
try:
response_stream = model.generate_content(messages, stream=True, **request_params)
for chunk in response_stream:
if chunk.candidates and chunk.candidates[0].content:
for part in chunk.candidates[0].content.parts:
if hasattr(part, 'text') and part.text:
# Text content
yield {
"id": f"gemini_{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"content": part.text},
"finish_reason": None
}]
}
elif hasattr(part, 'function_call') and part.function_call:
# Function call
func_call = part.function_call
yield {
"id": f"gemini_{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"delta": {
"tool_calls": [{
"index": 0,
"id": f"call_{hash(func_call.name)}",
"type": "function",
"function": {
"name": func_call.name,
"arguments": json.dumps(dict(func_call.args))
}
}]
},
"finish_reason": None
}]
}
except Exception as e:
logger.error(f"[Gemini] stream response error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}

View File

@@ -6,6 +6,7 @@ import time
import requests
import config
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
@@ -17,7 +18,7 @@ from common import memory, utils
import base64
import os
class LinkAIBot(Bot):
class LinkAIBot(Bot, OpenAICompatibleBot):
# authentication failed
AUTH_FAILED_CODE = 401
NO_QUOTA_CODE = 406
@@ -26,6 +27,18 @@ class LinkAIBot(Bot):
super().__init__()
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {}
def get_api_config(self):
"""Get API configuration for OpenAI-compatible base class"""
return {
'api_key': conf().get("open_ai_api_key"), # LinkAI uses OpenAI-compatible key
'api_base': conf().get("open_ai_api_base", "https://api.link-ai.tech/v1"),
'model': conf().get("model", "gpt-3.5-turbo"),
'default_temperature': conf().get("temperature", 0.9),
'default_top_p': conf().get("top_p", 1.0),
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
'default_presence_penalty': conf().get("presence_penalty", 0.0),
}
def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
@@ -473,3 +486,150 @@ class LinkAISession(ChatGPTSession):
self.messages.pop(i - 1)
return self.calc_tokens()
return cur_tokens
# Add call_with_tools method to LinkAIBot class
def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"""
Call LinkAI API with tool support for agent integration
LinkAI is fully compatible with OpenAI's tool calling format
Args:
messages: List of messages
tools: List of tool definitions (OpenAI format)
stream: Whether to use streaming
**kwargs: Additional parameters (max_tokens, temperature, etc.)
Returns:
Formatted response in OpenAI format or generator for streaming
"""
try:
# Build request parameters (LinkAI uses OpenAI-compatible format)
body = {
"messages": messages,
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
"top_p": kwargs.get("top_p", conf().get("top_p", 1)),
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
"stream": stream
}
# Add max_tokens if specified
if kwargs.get("max_tokens"):
body["max_tokens"] = kwargs["max_tokens"]
# Add app_code if provided
app_code = kwargs.get("app_code", conf().get("linkai_app_code"))
if app_code:
body["app_code"] = app_code
# Add tools if provided (OpenAI-compatible format)
if tools:
body["tools"] = tools
body["tool_choice"] = kwargs.get("tool_choice", "auto")
# Prepare headers
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
if stream:
return self._handle_linkai_stream_response(base_url, headers, body)
else:
return self._handle_linkai_sync_response(base_url, headers, body)
    except Exception as e:
        logger.error(f"[LinkAI] call_with_tools error: {e}")
        error_msg = str(e)  # Capture now: `e` is unbound once the except block exits, so the generator could not reference it
        if stream:
            def error_generator():
                yield {
                    "error": True,
                    "message": error_msg,
                    "status_code": 500
                }
            return error_generator()
else:
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _handle_linkai_sync_response(self, base_url, headers, body):
"""Handle synchronous LinkAI API response"""
try:
res = requests.post(
url=base_url + "/v1/chat/completions",
json=body,
headers=headers,
timeout=conf().get("request_timeout", 180)
)
if res.status_code == 200:
response = res.json()
logger.info(f"[LinkAI] call_with_tools reply, model={response.get('model')}, "
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
# LinkAI response is already in OpenAI-compatible format
return response
else:
error_data = res.json()
error_msg = error_data.get("error", {}).get("message", "Unknown error")
raise Exception(f"LinkAI API error: {res.status_code} - {error_msg}")
except Exception as e:
logger.error(f"[LinkAI] sync response error: {e}")
raise
def _handle_linkai_stream_response(self, base_url, headers, body):
"""Handle streaming LinkAI API response"""
try:
res = requests.post(
url=base_url + "/v1/chat/completions",
json=body,
headers=headers,
timeout=conf().get("request_timeout", 180),
stream=True
)
if res.status_code != 200:
error_text = res.text
try:
error_data = json.loads(error_text)
error_msg = error_data.get("error", {}).get("message", error_text)
            except Exception:
error_msg = error_text or "Unknown error"
yield {
"error": True,
"status_code": res.status_code,
"message": error_msg
}
return
# Process streaming response (OpenAI-compatible SSE format)
for line in res.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: '):
line = line[6:] # Remove 'data: ' prefix
if line == '[DONE]':
break
try:
chunk = json.loads(line)
yield chunk
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"[LinkAI] stream response error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}
# Attach methods to LinkAIBot class
LinkAIBot.call_with_tools = _linkai_call_with_tools
LinkAIBot._handle_linkai_sync_response = _handle_linkai_sync_response
LinkAIBot._handle_linkai_stream_response = _handle_linkai_stream_response
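# Usage sketch (assumes linkai_api_key is configured; tools are OpenAI-format):
#   bot = LinkAIBot()
#   resp = bot.call_with_tools([{"role": "user", "content": "hi"}],
#                              tools=[{"type": "function", "function": {...}}])
#   # With stream=True the call instead yields OpenAI-style SSE chunks:
#   #   for chunk in bot.call_with_tools(messages, tools=tools, stream=True): ...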

View File

@@ -6,6 +6,7 @@ import openai
import openai.error
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager
@@ -18,7 +19,7 @@ user_session = dict()
# OpenAI chat model API (available)
class OpenAIBot(Bot, OpenAIImage):
class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
def __init__(self):
super().__init__()
openai.api_key = conf().get("open_ai_api_key")
@@ -40,6 +41,18 @@ class OpenAIBot(Bot, OpenAIImage):
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"],
}
def get_api_config(self):
"""Get API configuration for OpenAI-compatible base class"""
return {
'api_key': conf().get("open_ai_api_key"),
'api_base': conf().get("open_ai_api_base"),
'model': conf().get("model", "text-davinci-003"),
'default_temperature': conf().get("temperature", 0.9),
'default_top_p': conf().get("top_p", 1.0),
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
'default_presence_penalty': conf().get("presence_penalty", 0.0),
}
def reply(self, query, context=None):
# acquire reply content
@@ -120,3 +133,98 @@ class OpenAIBot(Bot, OpenAIImage):
return self.reply_text(session, retry_count + 1)
else:
return result
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"""
Call OpenAI API with tool support for agent integration
Note: This bot uses the old Completion API which doesn't support tools.
For tool support, use ChatGPTBot instead.
This method converts to ChatCompletion API when tools are provided.
Args:
messages: List of messages
tools: List of tool definitions (OpenAI format)
stream: Whether to use streaming
**kwargs: Additional parameters
Returns:
Formatted response in OpenAI format or generator for streaming
"""
try:
# The old Completion API doesn't support tools
# We need to use ChatCompletion API instead
logger.info("[OPEN_AI] Using ChatCompletion API for tool support")
# Build request parameters for ChatCompletion
request_params = {
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
"messages": messages,
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
"top_p": kwargs.get("top_p", 1),
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
"stream": stream
}
# Add max_tokens if specified
if kwargs.get("max_tokens"):
request_params["max_tokens"] = kwargs["max_tokens"]
# Add tools if provided
if tools:
request_params["tools"] = tools
request_params["tool_choice"] = kwargs.get("tool_choice", "auto")
# Make API call using ChatCompletion
if stream:
return self._handle_stream_response(request_params)
else:
return self._handle_sync_response(request_params)
        except Exception as e:
            logger.error(f"[OPEN_AI] call_with_tools error: {e}")
            error_msg = str(e)  # Capture now: `e` is unbound once the except block exits, so the generator could not reference it
            if stream:
                def error_generator():
                    yield {
                        "error": True,
                        "message": error_msg,
                        "status_code": 500
                    }
                return error_generator()
else:
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _handle_sync_response(self, request_params):
"""Handle synchronous OpenAI ChatCompletion API response"""
try:
response = openai.ChatCompletion.create(**request_params)
logger.info(f"[OPEN_AI] call_with_tools reply, model={response.get('model')}, "
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
return response
except Exception as e:
logger.error(f"[OPEN_AI] sync response error: {e}")
raise
def _handle_stream_response(self, request_params):
"""Handle streaming OpenAI ChatCompletion API response"""
try:
stream = openai.ChatCompletion.create(**request_params)
for chunk in stream:
yield chunk
except Exception as e:
logger.error(f"[OPEN_AI] stream response error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}
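# Usage sketch: although OpenAIBot normally uses the legacy Completion API,
# call_with_tools switches to ChatCompletion as described above, e.g.
#   bot = OpenAIBot()
#   resp = bot.call_with_tools([{"role": "user", "content": "ping"}], tools=my_tools)
#   # my_tools is a hypothetical list of OpenAI-format tool definitions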

View File

@@ -1,7 +1,7 @@
import time
import openai
import openai.error
from bot.openai.openai_compat import RateLimitError
from common.log import logger
from common.token_bucket import TokenBucket
@@ -30,7 +30,7 @@ class OpenAIImage(object):
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
return True, image_url
except openai.error.RateLimitError as e:
except RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)

102
bot/openai/openai_compat.py Normal file
View File

@@ -0,0 +1,102 @@
"""
OpenAI compatibility layer for different versions.
This module provides a compatibility layer between OpenAI library versions:
- OpenAI < 1.0 (old API with openai.error module)
- OpenAI >= 1.0 (new API with direct exception imports)
"""
try:
# Try new OpenAI >= 1.0 API
from openai import (
OpenAIError,
RateLimitError,
APIError,
APIConnectionError,
AuthenticationError,
APITimeoutError,
BadRequestError,
)
# Create a mock error module for backward compatibility
class ErrorModule:
OpenAIError = OpenAIError
RateLimitError = RateLimitError
APIError = APIError
APIConnectionError = APIConnectionError
AuthenticationError = AuthenticationError
Timeout = APITimeoutError # Renamed in new version
InvalidRequestError = BadRequestError # Renamed in new version
error = ErrorModule()
# Also export with new names
Timeout = APITimeoutError
InvalidRequestError = BadRequestError
except ImportError:
# Fall back to old OpenAI < 1.0 API
try:
import openai.error as error
# Export individual exceptions for direct import
OpenAIError = error.OpenAIError
RateLimitError = error.RateLimitError
APIError = error.APIError
APIConnectionError = error.APIConnectionError
AuthenticationError = error.AuthenticationError
InvalidRequestError = error.InvalidRequestError
Timeout = error.Timeout
BadRequestError = error.InvalidRequestError # Alias
APITimeoutError = error.Timeout # Alias
except (ImportError, AttributeError):
# Neither version works, create dummy classes
class OpenAIError(Exception):
pass
class RateLimitError(OpenAIError):
pass
class APIError(OpenAIError):
pass
class APIConnectionError(OpenAIError):
pass
class AuthenticationError(OpenAIError):
pass
class InvalidRequestError(OpenAIError):
pass
class Timeout(OpenAIError):
pass
BadRequestError = InvalidRequestError
APITimeoutError = Timeout
# Create error module
class ErrorModule:
OpenAIError = OpenAIError
RateLimitError = RateLimitError
APIError = APIError
APIConnectionError = APIConnectionError
AuthenticationError = AuthenticationError
InvalidRequestError = InvalidRequestError
Timeout = Timeout
error = ErrorModule()
# Export all for easy import
__all__ = [
'error',
'OpenAIError',
'RateLimitError',
'APIError',
'APIConnectionError',
'AuthenticationError',
'InvalidRequestError',
'Timeout',
'BadRequestError',
'APITimeoutError',
]
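# Usage sketch: callers import exception names from this shim instead of from
# openai directly, so the same handler runs on openai<1.0 and >=1.0:
#   from bot.openai.openai_compat import RateLimitError
#   try:
#       ...  # make an API call
#   except RateLimitError:
#       ...  # back off and retry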

View File

@@ -0,0 +1,278 @@
# encoding:utf-8
"""
OpenAI-Compatible Bot Base Class
Provides a common implementation for bots that are compatible with OpenAI's API format.
This includes: OpenAI, LinkAI, Azure OpenAI, and many third-party providers.
"""
import json
import openai
from common.log import logger
class OpenAICompatibleBot:
"""
Base class for OpenAI-compatible bots.
Provides common tool calling implementation that can be inherited by:
- ChatGPTBot
- LinkAIBot
- OpenAIBot
- AzureChatGPTBot
- Other OpenAI-compatible providers
Subclasses only need to override get_api_config() to provide their specific API settings.
"""
def get_api_config(self):
"""
Get API configuration for this bot.
Subclasses should override this to provide their specific config.
Returns:
dict: {
'api_key': str,
'api_base': str (optional),
'model': str,
'default_temperature': float,
'default_top_p': float,
'default_frequency_penalty': float,
'default_presence_penalty': float,
}
"""
raise NotImplementedError("Subclasses must implement get_api_config()")
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"""
Call OpenAI-compatible API with tool support for agent integration
This method handles:
1. Format conversion (Claude format → OpenAI format)
2. System prompt injection
3. API calling with proper configuration
4. Error handling
Args:
messages: List of messages (may be in Claude format from agent)
tools: List of tool definitions (may be in Claude format from agent)
stream: Whether to use streaming
**kwargs: Additional parameters (max_tokens, temperature, system, etc.)
Returns:
Formatted response in OpenAI format or generator for streaming
"""
try:
# Get API configuration from subclass
api_config = self.get_api_config()
# Convert messages from Claude format to OpenAI format
messages = self._convert_messages_to_openai_format(messages)
# Convert tools from Claude format to OpenAI format
if tools:
tools = self._convert_tools_to_openai_format(tools)
# Handle system prompt (OpenAI uses system message, Claude uses separate parameter)
system_prompt = kwargs.get('system')
if system_prompt:
# Add system message at the beginning if not already present
if not messages or messages[0].get('role') != 'system':
messages = [{"role": "system", "content": system_prompt}] + messages
else:
# Replace existing system message
messages[0] = {"role": "system", "content": system_prompt}
# Build request parameters
request_params = {
"model": kwargs.get("model", api_config.get('model', 'gpt-3.5-turbo')),
"messages": messages,
"temperature": kwargs.get("temperature", api_config.get('default_temperature', 0.9)),
"top_p": kwargs.get("top_p", api_config.get('default_top_p', 1.0)),
"frequency_penalty": kwargs.get("frequency_penalty", api_config.get('default_frequency_penalty', 0.0)),
"presence_penalty": kwargs.get("presence_penalty", api_config.get('default_presence_penalty', 0.0)),
"stream": stream
}
# Add max_tokens if specified
if kwargs.get("max_tokens"):
request_params["max_tokens"] = kwargs["max_tokens"]
# Add tools if provided
if tools:
request_params["tools"] = tools
request_params["tool_choice"] = kwargs.get("tool_choice", "auto")
# Make API call with proper configuration
api_key = api_config.get('api_key')
api_base = api_config.get('api_base')
if stream:
return self._handle_stream_response(request_params, api_key, api_base)
else:
return self._handle_sync_response(request_params, api_key, api_base)
except Exception as e:
error_msg = str(e)
logger.error(f"[{self.__class__.__name__}] call_with_tools error: {error_msg}")
if stream:
def error_generator():
yield {
"error": True,
"message": error_msg,
"status_code": 500
}
return error_generator()
else:
return {
"error": True,
"message": error_msg,
"status_code": 500
}
def _handle_sync_response(self, request_params, api_key, api_base):
"""Handle synchronous OpenAI API response"""
try:
# Build kwargs with explicit API configuration
kwargs = dict(request_params)
if api_key:
kwargs["api_key"] = api_key
if api_base:
kwargs["api_base"] = api_base
response = openai.ChatCompletion.create(**kwargs)
return response
except Exception as e:
logger.error(f"[{self.__class__.__name__}] sync response error: {e}")
return {
"error": True,
"message": str(e),
"status_code": 500
}
def _handle_stream_response(self, request_params, api_key, api_base):
"""Handle streaming OpenAI API response"""
try:
# Build kwargs with explicit API configuration
kwargs = dict(request_params)
if api_key:
kwargs["api_key"] = api_key
if api_base:
kwargs["api_base"] = api_base
stream = openai.ChatCompletion.create(**kwargs)
# Stream chunks to caller
for chunk in stream:
yield chunk
except Exception as e:
logger.error(f"[{self.__class__.__name__}] stream response error: {e}")
yield {
"error": True,
"message": str(e),
"status_code": 500
}
def _convert_tools_to_openai_format(self, tools):
"""
Convert tools from Claude format to OpenAI format
Claude format: {name, description, input_schema}
OpenAI format: {type: "function", function: {name, description, parameters}}
"""
if not tools:
return None
openai_tools = []
for tool in tools:
# Check if already in OpenAI format
if 'type' in tool and tool['type'] == 'function':
openai_tools.append(tool)
else:
# Convert from Claude format
openai_tools.append({
"type": "function",
"function": {
"name": tool.get("name"),
"description": tool.get("description"),
"parameters": tool.get("input_schema", {})
}
})
return openai_tools
def _convert_messages_to_openai_format(self, messages):
"""
Convert messages from Claude format to OpenAI format
Claude uses content blocks with types like 'tool_use', 'tool_result'
OpenAI uses 'tool_calls' in assistant messages and 'tool' role for results
"""
if not messages:
return []
openai_messages = []
for msg in messages:
role = msg.get("role")
content = msg.get("content")
# Handle string content (already in correct format)
if isinstance(content, str):
openai_messages.append(msg)
continue
# Handle list content (Claude format with content blocks)
if isinstance(content, list):
# Check if this is a tool result message (user role with tool_result blocks)
if role == "user" and any(block.get("type") == "tool_result" for block in content):
# Convert each tool_result block to a separate tool message
for block in content:
if block.get("type") == "tool_result":
openai_messages.append({
"role": "tool",
"tool_call_id": block.get("tool_use_id"),
"content": block.get("content", "")
})
# Check if this is an assistant message with tool_use blocks
elif role == "assistant":
# Separate text content and tool_use blocks
text_parts = []
tool_calls = []
for block in content:
if block.get("type") == "text":
text_parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
tool_calls.append({
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(block.get("input", {}))
}
})
# Build OpenAI format assistant message
openai_msg = {
"role": "assistant",
"content": " ".join(text_parts) if text_parts else None
}
if tool_calls:
openai_msg["tool_calls"] = tool_calls
openai_messages.append(openai_msg)
else:
# Other list content, keep as is
openai_messages.append(msg)
else:
# Other formats, keep as is
openai_messages.append(msg)
return openai_messages
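# Illustrative sketch of the message conversion (hypothetical blocks, for reference only):
#   claude_msg = {"role": "assistant", "content": [
#       {"type": "text", "text": "Checking."},
#       {"type": "tool_use", "id": "t1", "name": "web_search", "input": {"q": "weather"}}]}
#   _convert_messages_to_openai_format([claude_msg])
#   -> [{"role": "assistant", "content": "Checking.",
#        "tool_calls": [{"id": "t1", "type": "function",
#                        "function": {"name": "web_search", "arguments": '{"q": "weather"}'}}]}]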

378
bridge/agent_bridge.py Normal file
View File

@@ -0,0 +1,378 @@
"""
Agent Bridge - Integrates Agent system with existing COW bridge
"""
from typing import Optional, List
from agent.protocol import Agent, LLMModel, LLMRequest
from bot.openai_compatible_bot import OpenAICompatibleBot
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
def add_openai_compatible_support(bot_instance):
"""
Dynamically add OpenAI-compatible tool calling support to a bot instance.
This allows any bot to gain tool calling capability without modifying its code,
as long as it uses OpenAI-compatible API format.
"""
if hasattr(bot_instance, 'call_with_tools'):
# Bot already has tool calling support
return bot_instance
# Create a temporary mixin class that combines the bot with OpenAI compatibility
class EnhancedBot(bot_instance.__class__, OpenAICompatibleBot):
"""Dynamically enhanced bot with OpenAI-compatible tool calling"""
def get_api_config(self):
"""
Infer API config from common configuration patterns.
Most OpenAI-compatible bots use similar configuration.
"""
from config import conf
return {
'api_key': conf().get("open_ai_api_key"),
'api_base': conf().get("open_ai_api_base"),
'model': conf().get("model", "gpt-3.5-turbo"),
'default_temperature': conf().get("temperature", 0.9),
'default_top_p': conf().get("top_p", 1.0),
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
'default_presence_penalty': conf().get("presence_penalty", 0.0),
}
# Change the bot's class to the enhanced version
bot_instance.__class__ = EnhancedBot
logger.info(
f"[AgentBridge] Enhanced {bot_instance.__class__.__bases__[0].__name__} with OpenAI-compatible tool calling")
return bot_instance
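# Illustrative usage (sketch; assumes the bot targets an OpenAI-compatible API):
#   bot = Bridge().get_bot("chat")
#   bot = add_openai_compatible_support(bot)
#   result = bot.call_with_tools(messages=[{"role": "user", "content": "hi"}],
#                                tools=None, stream=False)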
class AgentLLMModel(LLMModel):
"""
LLM Model adapter that uses COW's existing bot infrastructure
"""
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
# Get model name directly from config
from config import conf
model_name = conf().get("model", const.GPT_41)
super().__init__(model=model_name)
self.bridge = bridge
self.bot_type = bot_type
self._bot = None
@property
def bot(self):
"""Lazy load the bot and enhance it with tool calling if needed"""
if self._bot is None:
self._bot = self.bridge.get_bot(self.bot_type)
# Automatically add tool calling support if not present
self._bot = add_openai_compatible_support(self._bot)
return self._bot
def call(self, request: LLMRequest):
"""
Call the model using COW's bot infrastructure
"""
try:
# For non-streaming calls, we'll use the existing reply method
# This is a simplified implementation
if hasattr(self.bot, 'call_with_tools'):
# Use tool-enabled call if available
kwargs = {
'messages': request.messages,
'tools': getattr(request, 'tools', None),
'stream': False
}
# Only pass max_tokens if it's explicitly set
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
response = self.bot.call_with_tools(**kwargs)
return self._format_response(response)
else:
# Fallback to regular call
# This would need to be implemented based on your specific needs
raise NotImplementedError("Regular call not implemented yet")
except Exception as e:
logger.error(f"AgentLLMModel call error: {e}")
raise
def call_stream(self, request: LLMRequest):
"""
Call the model with streaming using COW's bot infrastructure
"""
try:
if hasattr(self.bot, 'call_with_tools'):
# Use tool-enabled streaming call if available
# Ensure max_tokens is an integer, use default if None
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
# Build kwargs for call_with_tools
kwargs = {
'messages': request.messages,
'tools': getattr(request, 'tools', None),
'stream': True,
'max_tokens': max_tokens
}
# Add system prompt if present
if system_prompt:
kwargs['system'] = system_prompt
stream = self.bot.call_with_tools(**kwargs)
# Convert stream format to our expected format
for chunk in stream:
yield self._format_stream_chunk(chunk)
else:
bot_type = type(self.bot).__name__
raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")
except Exception as e:
logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
raise
def _format_response(self, response):
"""Format Claude response to our expected format"""
# This would need to be implemented based on Claude's response format
return response
def _format_stream_chunk(self, chunk):
"""Format Claude stream chunk to our expected format"""
# This would need to be implemented based on Claude's stream format
return chunk
class AgentBridge:
"""
Bridge class that integrates single super Agent with COW
"""
def __init__(self, bridge: Bridge):
self.bridge = bridge
self.agent: Optional[Agent] = None
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
"""
Create the super agent with COW integration
Args:
system_prompt: System prompt
tools: List of tools (optional)
**kwargs: Additional agent parameters
Returns:
Agent instance
"""
# Create LLM model that uses COW's bot infrastructure
model = AgentLLMModel(self.bridge)
# Default tools if none provided
if tools is None:
# Use ToolManager to load all available tools
from agent.tools import ToolManager
tool_manager = ToolManager()
tool_manager.load_tools()
tools = []
for tool_name in tool_manager.tool_classes.keys():
try:
tool = tool_manager.create_tool(tool_name)
if tool:
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
# Create the single super agent
self.agent = Agent(
system_prompt=system_prompt,
description=kwargs.get("description", "AI Super Agent"),
model=model,
tools=tools,
max_steps=kwargs.get("max_steps", 15),
output_mode=kwargs.get("output_mode", "logger"),
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
max_context_tokens=kwargs.get("max_context_tokens"),
context_reserve_tokens=kwargs.get("context_reserve_tokens")
)
# Log skill loading details
if self.agent.skill_manager:
logger.info(f"[AgentBridge] SkillManager initialized:")
logger.info(f"[AgentBridge] - Managed dir: {self.agent.skill_manager.managed_skills_dir}")
logger.info(f"[AgentBridge] - Workspace dir: {self.agent.skill_manager.workspace_dir}")
logger.info(f"[AgentBridge] - Total skills: {len(self.agent.skill_manager.skills)}")
for skill_name in self.agent.skill_manager.skills.keys():
logger.info(f"[AgentBridge] * {skill_name}")
return self.agent
def get_agent(self) -> Optional[Agent]:
"""Get the super agent, create if not exists"""
if self.agent is None:
self._init_default_agent()
return self.agent
def _init_default_agent(self):
"""Initialize default super agent with new prompt system"""
from config import conf
import os
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
# Initialize workspace and create template files
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
workspace_files = ensure_workspace(workspace_root, create_templates=True)
logger.info(f"[AgentBridge] Workspace initialized at: {workspace_root}")
# Setup memory system
memory_manager = None
memory_tools = []
try:
# Try to initialize memory system
from agent.memory import MemoryManager, MemoryConfig
from agent.tools import MemorySearchTool, MemoryGetTool
memory_config = MemoryConfig(
workspace_root=workspace_root,
embedding_provider="local", # Use local embedding (no API key needed)
embedding_model="all-MiniLM-L6-v2"
)
# Create memory manager with the config
memory_manager = MemoryManager(memory_config)
# Create memory tools
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
logger.info(f"[AgentBridge] Memory system initialized")
except Exception as e:
logger.warning(f"[AgentBridge] Memory system not available: {e}")
logger.info("[AgentBridge] Continuing without memory features")
# Use ToolManager to dynamically load all available tools
from agent.tools import ToolManager
tool_manager = ToolManager()
tool_manager.load_tools()
# Create tool instances for all available tools
tools = []
file_config = {
"cwd": workspace_root,
"memory_manager": memory_manager
} if memory_manager else {"cwd": workspace_root}
for tool_name in tool_manager.tool_classes.keys():
try:
tool = tool_manager.create_tool(tool_name)
if tool:
# Apply workspace config to file operation tools
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
tool.config = file_config
tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
tools.append(tool)
logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
# Add memory tools
if memory_tools:
tools.extend(memory_tools)
logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")
logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")
# Load context files (SOUL.md, USER.md, etc.)
context_files = load_context_files(workspace_root)
logger.info(f"[AgentBridge] Loaded {len(context_files)} context files: {[f.path for f in context_files]}")
# Build system prompt using new prompt builder
prompt_builder = PromptBuilder(
workspace_dir=workspace_root,
language="zh"
)
# Get runtime info
runtime_info = {
"model": conf().get("model", "unknown"),
"workspace": workspace_root,
"channel": "web" # TODO: get from actual channel, default to "web" to hide if not specified
}
system_prompt = prompt_builder.build(
tools=tools,
context_files=context_files,
memory_manager=memory_manager,
runtime_info=runtime_info
)
logger.info("[AgentBridge] System prompt built successfully")
# Create agent with configured tools and workspace
agent = self.create_agent(
system_prompt=system_prompt,
tools=tools,
max_steps=50,
output_mode="logger",
workspace_dir=workspace_root, # Pass workspace to agent for skills loading
enable_skills=True # Enable skills auto-loading
)
# Attach memory manager to agent if available
if memory_manager:
agent.memory_manager = memory_manager
logger.info(f"[AgentBridge] Memory manager attached to agent")
def agent_reply(self, query: str, context: Context = None,
on_event=None, clear_history: bool = False) -> Reply:
"""
Use super agent to reply to a query
Args:
query: User query
context: COW context (optional)
on_event: Event callback (optional)
clear_history: Whether to clear conversation history
Returns:
Reply object
"""
try:
# Get agent (will auto-initialize if needed)
agent = self.get_agent()
if not agent:
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
# Use agent's run_stream method
response = agent.run_stream(
user_message=query,
on_event=on_event,
clear_history=clear_history
)
return Reply(ReplyType.TEXT, response)
except Exception as e:
logger.error(f"Agent reply error: {e}")
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")

View File

@@ -23,7 +23,7 @@ class Bridge(object):
if bot_type:
self.btype["chat"] = bot_type
else:
model_type = conf().get("model") or const.GPT35
model_type = conf().get("model") or const.GPT_41_MINI
if model_type in ["text-davinci-003"]:
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
@@ -64,6 +64,7 @@ class Bridge(object):
self.bots = {}
self.chat_bots = {}
self._agent_bridge = None
# Interface mapping for each model
def get_bot(self, typename):
@@ -104,3 +105,29 @@ class Bridge(object):
Reset bot routing
"""
self.__init__()
def get_agent_bridge(self):
"""
Get agent bridge for agent-based conversations
"""
if self._agent_bridge is None:
from bridge.agent_bridge import AgentBridge
self._agent_bridge = AgentBridge(self)
return self._agent_bridge
def fetch_agent_reply(self, query: str, context: Context = None,
on_event=None, clear_history: bool = False) -> Reply:
"""
Use super agent to handle the query
Args:
query: User query
context: Context object
on_event: Event callback for streaming
clear_history: Whether to clear conversation history
Returns:
Reply object
"""
agent_bridge = self.get_agent_bridge()
return agent_bridge.agent_reply(query, context, on_event, clear_history)
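# Illustrative usage (sketch; the query text is hypothetical):
#   reply = Bridge().fetch_agent_reply("What's on my schedule today?")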

View File

@@ -5,6 +5,8 @@ Message sending channel abstract class
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *
from common.log import logger
from config import conf
class Channel(object):
@@ -35,7 +37,30 @@ class Channel(object):
raise NotImplementedError
def build_reply_content(self, query, context: Context = None) -> Reply:
return Bridge().fetch_reply_content(query, context)
"""
Build reply content, using agent if enabled in config
"""
# Check if agent mode is enabled
use_agent = conf().get("agent", False)
if use_agent:
try:
logger.info("[Channel] Using agent mode")
# Use agent bridge to handle the query
return Bridge().fetch_agent_reply(
query=query,
context=context,
on_event=None,
clear_history=False
)
except Exception as e:
logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
# Fallback to normal mode if agent fails
return Bridge().fetch_reply_content(query, context)
else:
# Normal mode
return Bridge().fetch_reply_content(query, context)
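# Agent mode is toggled via config; e.g. in config.json (keys defined in config.py below):
#   "agent": true,
#   "agent_workspace": "~/cow"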
def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file)

167
channel/feishu/README.md Normal file
View File

@@ -0,0 +1,167 @@
# Feishu Channel Usage Guide
The Feishu channel supports two event receiving modes; pick the one that fits your deployment environment.
## Mode Comparison
| Mode | Best for | Pros | Cons |
|------|---------|------|------|
| **webhook** | Production | Stable and reliable, officially recommended | Requires a public IP or domain |
| **websocket** | Local development | No public IP needed, convenient for development | Requires an extra dependency |
## Configuration
### Basic Configuration
Add the following to `config.json`:
```json
{
  "channel_type": "feishu",
  "feishu_app_id": "cli_xxxxx",
  "feishu_app_secret": "your_app_secret",
  "feishu_token": "your_verification_token",
  "feishu_bot_name": "your_bot_name",
  "feishu_event_mode": "webhook",
  "feishu_port": 9891
}
```
### Configuration Options
- `feishu_app_id`: App ID of the Feishu app
- `feishu_app_secret`: App Secret of the Feishu app
- `feishu_token`: Verification Token for event subscription
- `feishu_bot_name`: bot name (used to detect @-mentions in group chats)
- `feishu_event_mode`: event receiving mode, one of:
  - `"websocket"`: long-connection mode (default)
  - `"webhook"`: HTTP server mode
- `feishu_port`: HTTP port in webhook mode (default 9891)
## Mode 1: Webhook (recommended for production)
### 1. Configure
```json
{
  "feishu_event_mode": "webhook",
  "feishu_port": 9891
}
```
### 2. Start the service
```bash
python3 app.py
```
The service will start on `http://0.0.0.0:9891`.
### 3. Configure the Feishu app
1. Log in to the [Feishu Open Platform](https://open.feishu.cn/)
2. Go to the app details -> Event Subscriptions
3. Select **Send events to the developer server**
4. Fill in the request URL: `http://your-domain:9891/`
5. Add the event: `im.message.receive_v1` (receive message v2.0)
6. Save the configuration
### 4. Notes
- A public IP or domain is required
- Make sure the firewall allows the corresponding port
- HTTPS is recommended (requires a reverse proxy; see the sketch below)
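For the HTTPS setup, here is a minimal nginx reverse-proxy sketch (assuming nginx with certificates already issued; the domain and certificate paths are placeholders):
```nginx
server {
    listen 443 ssl;
    server_name your-domain.com;
    ssl_certificate     /etc/ssl/certs/your-domain.crt;
    ssl_certificate_key /etc/ssl/private/your-domain.key;

    location / {
        # Forward to the COW webhook server on feishu_port
        proxy_pass http://127.0.0.1:9891;
        proxy_set_header Host $host;
        proxy_set_header X-Forwarded-For $remote_addr;
    }
}
```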
## Mode 2: WebSocket (recommended for local development)
### 1. Install the dependency
```bash
pip install lark-oapi
```
### 2. Configure
```json
{
  "feishu_event_mode": "websocket"
}
```
### 3. Start the service
```bash
python3 app.py
```
The program will automatically establish a long connection to the Feishu Open Platform.
### 4. Configure the Feishu app
1. Log in to the [Feishu Open Platform](https://open.feishu.cn/)
2. Go to the app details -> Event Subscriptions
3. Select **Receive events over a long connection**
4. Add the event: `im.message.receive_v1` (receive message v2.0)
5. Save the configuration
### 5. Notes
- No public IP required
- Outbound internet access is needed (to establish the WebSocket connection)
- Each app supports at most 50 connections
- In cluster mode, each message is dispatched to one random client
## Smooth Migration
To switch from webhook mode to websocket mode (or the other way around):
1. Change `feishu_event_mode` in `config.json`
2. If switching to websocket mode, install the `lark-oapi` dependency
3. Restart the service
4. Update the event subscription method on the Feishu Open Platform
**Important**: Only one mode can be active at a time, otherwise messages will be received twice.
## Message Deduplication
Both modes share the same deduplication mechanism:
- Processed message IDs are stored in an `ExpiredDict`
- Expiry time: 7.1 hours
- Ensures no message is processed twice (see the sketch below)
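A minimal sketch of how the dedup check works (simplified from `FeiShuChanel`; the helper function is illustrative):
```python
from common.expired_dict import ExpiredDict

received_msgs = ExpiredDict(60 * 60 * 7.1)  # entries expire after 7.1 hours

def is_duplicate(msg_id: str) -> bool:
    """Return True if this message ID was already handled within the window."""
    if received_msgs.get(msg_id):
        return True
    received_msgs[msg_id] = True
    return False
```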
## Troubleshooting
### WebSocket mode fails to connect
```
[FeiShu] lark_oapi not installed
```
**Fix**: install the dependency with `pip install lark-oapi`
### Webhook mode port already in use
```
Address already in use
```
**Fix**: change the `feishu_port` setting or stop the process occupying the port
### No messages received
1. Check the event subscription settings of the Feishu app
2. Confirm the `im.message.receive_v1` event has been added
3. Check app permissions: the `im:message` scope is required
4. Look for error messages in the logs
## Development Tips
- **Local development**: use websocket mode for fast iteration
- **Testing**: webhook mode plus a tunneling tool (e.g. ngrok) works well
- **Production**: use webhook mode with a proper domain and HTTPS
## References
- [Feishu Open Platform - Event Subscription](https://open.feishu.cn/document/ukTMukTMukTM/uUTNz4SN1MjL1UzM)
- [Feishu SDK - Python](https://github.com/larksuite/oapi-sdk-python)

View File

@@ -1,48 +1,80 @@
"""
Feishu channel integration
Supports two event receiving modes:
1. webhook mode: receive events via an HTTP server (requires a public IP)
2. websocket mode: receive events via a long connection (friendly for local development)
Select the mode via the feishu_event_mode config option: "webhook" or "websocket"
@author Saboteur7
@Date 2023/11/19
"""
import json
import os
import threading
# -*- coding=utf-8 -*-
import uuid
import requests
import web
from channel.feishu.feishu_message import FeishuMessage
from bridge.context import Context
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.feishu.feishu_message import FeishuMessage
from common import utils
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from config import conf
from common.expired_dict import ExpiredDict
from bridge.context import ContextType
from channel.chat_channel import ChatChannel, check_prefix
from common import utils
import json
import os
URL_VERIFICATION = "url_verification"
# Try to import the Feishu SDK; if it is not installed, websocket mode is unavailable
try:
import lark_oapi as lark
LARK_SDK_AVAILABLE = True
except ImportError:
LARK_SDK_AVAILABLE = False
logger.warning(
"[FeiShu] lark_oapi not installed, websocket mode is not available. Install with: pip install lark-oapi")
@singleton
class FeiShuChanel(ChatChannel):
feishu_app_id = conf().get('feishu_app_id')
feishu_app_secret = conf().get('feishu_app_secret')
feishu_token = conf().get('feishu_token')
feishu_event_mode = conf().get('feishu_event_mode', 'websocket')  # webhook or websocket
def __init__(self):
super().__init__()
# Cache of recently handled message IDs, used for idempotency control
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
logger.info("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
# No group whitelist check or chat prefix required
conf()["group_name_white_list"] = ["ALL_GROUP"]
conf()["single_chat_prefix"] = [""]
# Validate the configuration
if self.feishu_event_mode == 'websocket' and not LARK_SDK_AVAILABLE:
logger.error("[FeiShu] websocket mode requires lark_oapi. Please install: pip install lark-oapi")
raise Exception("lark_oapi not installed")
def startup(self):
if self.feishu_event_mode == 'websocket':
self._startup_websocket()
else:
self._startup_webhook()
def _startup_webhook(self):
"""启动HTTP服务器接收事件(webhook模式)"""
logger.info("[FeiShu] Starting in webhook mode...")
urls = (
'/', 'channel.feishu.feishu_channel.FeishuController'
)
@@ -50,6 +82,109 @@ class FeiShuChanel(ChatChannel):
port = conf().get("feishu_port", 9891)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
def _startup_websocket(self):
"""启动长连接接收事件(websocket模式)"""
logger.info("[FeiShu] Starting in websocket mode...")
# Create the event handler
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
"""处理接收消息事件 v2.0"""
try:
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
# Convert to the standard event format
event_dict = json.loads(lark.JSON.marshal(data))
event = event_dict.get("event", {})
# Handle the message
self._handle_message_event(event)
except Exception as e:
logger.error(f"[FeiShu] websocket handle message error: {e}", exc_info=True)
# Build the event dispatcher
event_handler = lark.EventDispatcherHandler.builder("", "") \
.register_p2_im_message_receive_v1(handle_message_event) \
.build()
# Create the long-connection client
ws_client = lark.ws.Client(
self.feishu_app_id,
self.feishu_app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.DEBUG if conf().get("debug") else lark.LogLevel.INFO
)
# Start the client in a new thread to avoid blocking the main thread
def start_client():
try:
logger.info("[FeiShu] Websocket client starting...")
ws_client.start()
except Exception as e:
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
ws_thread = threading.Thread(target=start_client, daemon=True)
ws_thread.start()
# Keep the main thread running
logger.info("[FeiShu] Websocket mode started, waiting for events...")
ws_thread.join()
def _handle_message_event(self, event: dict):
"""
Core logic for handling message events
Shared by both webhook and websocket modes
"""
if not event.get("message") or not event.get("sender"):
logger.warning(f"[FeiShu] invalid message, event={event}")
return
msg = event.get("message")
# Idempotency check
msg_id = msg.get("message_id")
if self.receivedMsgs.get(msg_id):
logger.warning(f"[FeiShu] repeat msg filtered, msg_id={msg_id}")
return
self.receivedMsgs[msg_id] = True
is_group = False
chat_type = msg.get("chat_type")
if chat_type == "group":
if not msg.get("mentions") and msg.get("message_type") == "text":
# In a group chat, do not respond unless the bot is mentioned
return
if msg.get("mentions") and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get(
"message_type") == "text":
# The mention is not for this bot, do not respond
return
# Group chat
is_group = True
receive_id_type = "chat_id"
elif chat_type == "p2p":
receive_id_type = "open_id"
else:
logger.warning("[FeiShu] message ignore")
return
# Build the Feishu message object
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=self.fetch_access_token())
if not feishu_msg:
return
context = self._compose_context(
feishu_msg.ctype,
feishu_msg.content,
isgroup=is_group,
msg=feishu_msg,
receive_id_type=receive_id_type,
no_need_at=True
)
if context:
self.produce(context)
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
def send(self, reply: Reply, context: Context):
msg = context.get("msg")
is_group = context["isgroup"]
@@ -143,9 +278,39 @@ class FeiShuChanel(ChatChannel):
os.remove(temp_name)
return upload_response.json().get("data").get("image_key")
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
# 1. Text request
# Image generation handling
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
elif context.type == ContextType.VOICE:
# 2. Voice request
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
context["desire_rtype"] = ReplyType.VOICE
return context
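# Example (assuming image_create_prefix in config contains "画", as in the default template):
#   an incoming text "画一只猫" becomes ContextType.IMAGE_CREATE with content "一只猫"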
class FeishuController:
"""
HTTP server controller, used in webhook mode
"""
# Class constants
FAILED_MSG = '{"success": false}'
SUCCESS_MSG = '{"success": true}'
@@ -175,80 +340,10 @@ class FeishuController:
# Handle the message event
event = request.get("event")
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
if not event.get("message") or not event.get("sender"):
logger.warning(f"[FeiShu] invalid message, msg={request}")
return self.FAILED_MSG
msg = event.get("message")
channel._handle_message_event(event)
# Idempotency check
if channel.receivedMsgs.get(msg.get("message_id")):
logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
return self.SUCCESS_MSG
channel.receivedMsgs[msg.get("message_id")] = True
is_group = False
chat_type = msg.get("chat_type")
if chat_type == "group":
if not msg.get("mentions") and msg.get("message_type") == "text":
# In a group chat, do not respond unless the bot is mentioned
return self.SUCCESS_MSG
if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
# The mention is not for this bot, do not respond
return self.SUCCESS_MSG
# Group chat
is_group = True
receive_id_type = "chat_id"
elif chat_type == "p2p":
receive_id_type = "open_id"
else:
logger.warning("[FeiShu] message ignore")
return self.SUCCESS_MSG
# Build the Feishu message object
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
if not feishu_msg:
return self.SUCCESS_MSG
context = self._compose_context(
feishu_msg.ctype,
feishu_msg.content,
isgroup=is_group,
msg=feishu_msg,
receive_id_type=receive_id_type,
no_need_at=True
)
if context:
channel.produce(context)
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
return self.SUCCESS_MSG
except Exception as e:
logger.error(e)
return self.FAILED_MSG
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
# 1. Text request
# Image generation handling
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
elif context.type == ContextType.VOICE:
# 2. Voice request
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
context["desire_rtype"] = ReplyType.VOICE
return context

View File

@@ -150,9 +150,6 @@ class WebChannel(ChatChannel):
Poll for responses using the session_id.
"""
try:
# Do not log polling requests
web.ctx.log_request = False
data = web.data()
json_data = json.loads(data)
session_id = json_data.get('session_id')
@@ -198,7 +195,7 @@ class WebChannel(ChatChannel):
5. wechatcom_app: 企微自建应用
6. dingtalk: 钉钉
7. feishu: 飞书""")
logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat本地运行或 http://ip:{port}/chat服务器运行 ")
logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat (本地运行) 或 http://ip:{port}/chat (服务器运行)")
# 确保静态文件目录存在
static_dir = os.path.join(os.path.dirname(__file__), 'static')
@@ -215,19 +212,20 @@ class WebChannel(ChatChannel):
)
app = web.application(urls, globals(), autoreload=False)
# Disable web.py's default log output
import io
from contextlib import redirect_stdout
# Completely disable web.py's HTTP log output
# Create an empty log handler function
def null_log_function(status, environ):
pass
# Set web.py's log level to ERROR (errors only)
# Replace web.py's log function
web.httpserver.LogMiddleware.log = lambda self, status, environ: None
# Set web.py's log level to ERROR
logging.getLogger("web").setLevel(logging.ERROR)
# Disable web.httpserver logging
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
# Temporarily redirect stdout to capture web.py's startup message
with redirect_stdout(io.StringIO()):
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
# Start the server
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
class RootHandler:

View File

@@ -17,7 +17,7 @@
"@bot"
],
"group_name_white_list": [
"ChatGPT测试群",
"Agent测试群",
"ChatGPT测试群2"
],
"image_create_prefix": [
@@ -30,8 +30,9 @@
"expires_in_seconds": 3600,
"character_desc": "你是基于大语言模型的AI智能助手旨在回答并解决人们的任何问题并且可以使用多种语言与人交流。",
"temperature": 0.7,
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
"use_linkai": false,
"linkai_api_key": "",
"linkai_app_code": ""
"linkai_app_code": "",
"agent": false
}

View File

@@ -1,10 +1,10 @@
# encoding:utf-8
import copy
import json
import logging
import os
import pickle
import copy
from common.log import logger
@@ -148,6 +148,7 @@ available_setting = {
"feishu_app_secret": "", # 飞书机器人APP secret
"feishu_token": "", # 飞书 verification token
"feishu_bot_name": "", # 飞书机器人的名字
"feishu_event_mode": "websocket", # 飞书事件接收模式: webhook(HTTP服务器) 或 websocket(长连接)
# DingTalk configuration
"dingtalk_client_id": "", # 钉钉机器人Client ID
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
@@ -183,6 +184,8 @@ available_setting = {
"Minimax_group_id": "",
"Minimax_base_url": "",
"web_port": 9899,
"agent": False, # 是否开启Agent模式
"agent_workspace": "~/cow" # agent工作空间路径用于存储skills、memory等
}
@@ -197,16 +200,26 @@ class Config(dict):
self.user_datas = {}
def __getitem__(self, key):
if key not in available_setting:
# Skip comment fields that start with an underscore
if not key.startswith("_") and key not in available_setting:
raise Exception("key {} not in available_setting".format(key))
return super().__getitem__(key)
def __setitem__(self, key, value):
if key not in available_setting:
# Skip comment fields that start with an underscore
if not key.startswith("_") and key not in available_setting:
raise Exception("key {} not in available_setting".format(key))
return super().__setitem__(key, value)
def get(self, key, default=None):
# Skip comment fields that start with an underscore
if key.startswith("_"):
return super().get(key, default)
# If the key is not in available_setting, return the default directly
if key not in available_setting:
return super().get(key, default)
try:
return self[key]
except KeyError as e:
@@ -284,6 +297,9 @@ def load_config():
# Some online deployment platforms (e.g. Railway) deploy project from github directly. So you shouldn't put your secrets like api key in a config file, instead use environment variables to override the default config.
for name, value in os.environ.items():
name = name.lower()
# Skip comment fields that start with an underscore
if name.startswith("_"):
continue
if name in available_setting:
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
try:

View File

@@ -15,18 +15,14 @@ elevenlabs==1.0.3 # elevenlabs TTS
#install plugin
dulwich
# wechatmp && wechatcom
# wechatmp && wechatcom && feishu
web.py
wechatpy
# chatgpt-tool-hub plugin
chatgpt_tool_hub==0.5.0
# xunfei spark
websocket-client==1.2.0
# claude bot
curl_cffi
# claude API
anthropic==0.25.0

View File

@@ -9,3 +9,6 @@ pre-commit
web.py
linkai>=0.0.6.0
agentmesh-sdk>=0.1.3
# feishu websocket mode
lark-oapi

124
skills/README.md Normal file
View File

@@ -0,0 +1,124 @@
# Skills Directory
This directory contains skills for the COW agent system. Skills are markdown files that provide specialized instructions for specific tasks.
## What are Skills?
Skills are reusable instruction sets that help the agent perform specific tasks more effectively. Each skill:
- Provides context-specific guidance
- Documents best practices
- Includes examples and usage patterns
- Can have requirements (binaries, environment variables, etc.)
## Skill Structure
Each skill is a markdown file (`SKILL.md`) in its own directory with frontmatter:
```markdown
---
name: skill-name
description: Brief description of what the skill does
metadata: {"cow":{"emoji":"🎯","requires":{"bins":["tool"]}}}
---
# Skill Name
Detailed instructions and examples...
```
## Available Skills
- **calculator**: Mathematical calculations and expressions
- **web-search**: Search the web for current information
- **file-operations**: Read, write, and manage files
## Creating Custom Skills
To create a new skill:
1. Create a directory: `skills/my-skill/`
2. Create `SKILL.md` with frontmatter and content
3. Restart the agent to load the new skill (see the sketch below)
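A minimal command-line sketch of these steps (the skill name `my-skill` is a placeholder):
```bash
mkdir -p skills/my-skill
cat > skills/my-skill/SKILL.md <<'EOF'
---
name: my-skill
description: Describe what the skill does and when the agent should use it
---
# My Skill
Detailed instructions and examples...
EOF
```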
### Frontmatter Fields
- `name`: Skill name (must match directory name)
- `description`: Brief description (required)
- `metadata`: JSON object with additional configuration
- `emoji`: Display emoji
- `always`: Always include this skill (default: false)
- `primaryEnv`: Primary environment variable needed
- `os`: Supported operating systems (e.g., ["darwin", "linux"])
- `requires`: Requirements object
- `bins`: Required binaries
- `env`: Required environment variables
- `config`: Required config paths
- `disable-model-invocation`: If true, skill won't be shown to model (default: false)
- `user-invocable`: If false, users can't invoke directly (default: true)
### Example Skill
```markdown
---
name: my-tool
description: Use my-tool to process data
metadata: {"cow":{"emoji":"🔧","requires":{"bins":["my-tool"],"env":["MY_TOOL_API_KEY"]}}}
---
# My Tool Skill
Use this skill when you need to process data with my-tool.
## Prerequisites
- Install my-tool: `pip install my-tool`
- Set `MY_TOOL_API_KEY` environment variable
## Usage
\`\`\`python
# Example usage
my_tool_command("input data")
\`\`\`
```
## Skill Loading
Skills are loaded from multiple locations with precedence:
1. **Workspace skills** (highest): `workspace/skills/` - Project-specific skills
2. **Managed skills**: `~/.cow/skills/` - User-installed skills
3. **Bundled skills** (lowest): Built-in skills
Skills with the same name in higher-precedence locations override lower ones.
## Skill Requirements
Skills can specify requirements that determine when they're available:
- **OS requirements**: Only load on specific operating systems
- **Binary requirements**: Only load if required binaries are installed
- **Environment variables**: Only load if required env vars are set
- **Config requirements**: Only load if config values are set
## Best Practices
1. **Clear descriptions**: Write clear, concise skill descriptions
2. **Include examples**: Provide practical usage examples
3. **Document prerequisites**: List all requirements clearly
4. **Use appropriate metadata**: Set correct requirements and flags
5. **Keep skills focused**: Each skill should have a single, clear purpose
## Workspace Skills
You can create workspace-specific skills in your agent's workspace:
```
workspace/
└── skills/
    └── custom-skill/
        └── SKILL.md
```
These skills are only available when working in that specific workspace.

View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,286 @@
---
name: skill-creator
description: Create or update skills. Use when designing, structuring, or packaging skills with scripts, references, and assets. COW simplified version - skills are used locally in workspace.
license: Complete terms in LICENSE.txt
---
# Skill Creator
This skill provides guidance for creating effective skills using the existing tool system.
## About Skills
Skills are modular, self-contained packages that extend the agent's capabilities by providing specialized knowledge, workflows, and tools. They transform a general-purpose agent into a specialized agent equipped with procedural knowledge.
### What Skills Provide
1. **Specialized workflows** - Multi-step procedures for specific domains
2. **Tool integrations** - Instructions for working with specific file formats or APIs
3. **Domain expertise** - Company-specific knowledge, schemas, business logic
4. **Bundled resources** - Scripts, references, and assets for complex tasks
### Core Principle
**Concise is Key**: Only add context the agent doesn't already have. Challenge each piece of information: "Does this justify its token cost?" Prefer concise examples over verbose explanations.
## Skill Structure
Every skill consists of a required SKILL.md file and optional bundled resources:
```
skill-name/
├── SKILL.md (required)
│   ├── YAML frontmatter metadata (required)
│   │   ├── name: (required)
│   │   └── description: (required)
│   └── Markdown instructions (required)
└── Bundled Resources (optional)
    ├── scripts/ - Executable code (Python/Bash/etc.)
    ├── references/ - Documentation intended to be loaded into context as needed
    └── assets/ - Files used in output (templates, icons, fonts, etc.)
```
### SKILL.md Components
**Frontmatter (YAML)** - Required fields:
- **name**: Skill name in hyphen-case (e.g., `weather-api`, `pdf-editor`)
- **description**: **CRITICAL** - Primary triggering mechanism
- Must clearly describe what the skill does
- Must explicitly state when to use it
- Include specific trigger scenarios and keywords
- All "when to use" info goes here, NOT in body
- Example: `"PDF document processing with rotation, merging, splitting, and text extraction. Use when user needs to: (1) Rotate PDF pages, (2) Merge multiple PDFs, (3) Split PDF files, (4) Extract text from PDFs."`
**Body (Markdown)** - Loaded after skill triggers:
- Detailed usage instructions
- How to call scripts and read references
- Examples and best practices
- Use imperative/infinitive form ("Use X to do Y")
### Bundled Resources
**scripts/** - When to include:
- Code is repeatedly rewritten
- Deterministic execution needed (avoid LLM randomness)
- Examples: PDF rotation, image processing
- Must test scripts before including
**references/** - When to include:
- Documentation for agent to reference
- Database schemas, API docs, domain knowledge
- Agent reads these files into context as needed
- For large files (>10k words), include grep patterns in SKILL.md
**assets/** - When to include:
- Files used in output (not loaded to context)
- Templates, icons, boilerplate code
- Copied or modified in final output
**Important**: Most skills don't need all three. Choose based on actual needs.
### What NOT to Include
Do NOT create auxiliary documentation:
- README.md
- INSTALLATION_GUIDE.md
- CHANGELOG.md
- Other non-essential files
## Skill Creation Process
**COW Simplified Version** - Skills are used locally, no packaging/sharing needed.
1. **Understand** - Clarify use cases with concrete examples
2. **Plan** - Identify needed scripts, references, assets
3. **Initialize** - Run init_skill.py to create template
4. **Edit** - Implement SKILL.md and resources
5. **Validate** (optional) - Run quick_validate.py to check format
6. **Iterate** - Improve based on real usage
## Skill Naming
- Use lowercase letters, digits, and hyphens only; normalize user-provided titles to hyphen-case (e.g., "Plan Mode" -> `plan-mode`).
- When generating names, keep them under 64 characters (letters, digits, hyphens).
- Prefer short, verb-led phrases that describe the action.
- Namespace by tool when it improves clarity or triggering (e.g., `gh-address-comments`, `linear-address-issue`).
- Name the skill folder exactly after the skill name.
## Step-by-Step Guide
### Step 1: Understanding the Skill with Concrete Examples
Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.
To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.
For example, when building an image-editor skill, relevant questions include:
- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"
To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness.
Conclude this step when there is a clear sense of the functionality the skill should support.
### Step 2: Planning the Reusable Skill Contents
To turn concrete examples into an effective skill, analyze each example by:
1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly
Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:
1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill
Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:
1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill
Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:
1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill
To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.
### Step 3: Initialize the Skill
At this point, it is time to actually create the skill.
Skip this step only if the skill being developed already exists, and iteration is needed. In this case, continue to the next step.
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
Usage:
```bash
scripts/init_skill.py <skill-name> --path <output-directory> [--resources scripts,references,assets] [--examples]
```
Examples:
```bash
scripts/init_skill.py my-skill --path ~/cow/skills
scripts/init_skill.py my-skill --path ~/cow/skills --resources scripts,references
scripts/init_skill.py my-skill --path ~/cow/skills --resources scripts --examples
```
The script:
- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Optionally creates resource directories based on `--resources`
- Optionally adds example files when `--examples` is set
After initialization, customize the SKILL.md and add resources as needed. If you used `--examples`, replace or delete placeholder files.
**Important**: Always create skills in workspace directory (`~/cow/skills`), NOT in project directory.
### Step 4: Edit the Skill
When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of the agent to use. Include information that would be beneficial and non-obvious to the agent. Consider what procedural knowledge, domain-specific details, or reusable assets would help another agent instance execute these tasks more effectively.
#### Learn Proven Design Patterns
Consult these helpful guides based on your skill's needs:
- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns
These files contain established best practices for effective skill design.
#### Start with Reusable Skill Contents
To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.
Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion.
If you used `--examples`, delete any placeholder files that are not needed for the skill. Only create resource directories that are actually required.
#### Update SKILL.md
**Writing Guidelines:** Always use imperative/infinitive form.
##### Frontmatter
Write the YAML frontmatter with `name` and `description`:
- `name`: The skill name
- `description`: This is the primary triggering mechanism for your skill, and helps the agent understand when to use the skill.
- Include both what the Skill does and specific triggers/contexts for when to use it.
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
Do not include any other fields in YAML frontmatter.
##### Body
Write instructions for using the skill and its bundled resources.
### Step 5: Validate (Optional)
Validate skill format:
```bash
scripts/quick_validate.py <path/to/skill-folder>
```
Example:
```bash
scripts/quick_validate.py ~/cow/skills/weather-api
```
Validation checks:
- YAML frontmatter format and required fields
- Skill naming conventions (hyphen-case, lowercase)
- Description completeness and quality
- File organization
**Note**: Validation is optional in COW. Mainly useful for troubleshooting format issues.
### Step 6: Iterate
Improve based on real usage:
1. Use skill on real tasks
2. Notice struggles or inefficiencies
3. Identify needed updates to SKILL.md or resources
4. Implement changes and test again
## Progressive Disclosure
Skills use three-level loading:
1. **Metadata** (name + description) - Always in context (~100 words)
2. **SKILL.md body** - Loaded when skill triggers (<5k words)
3. **Resources** - Loaded as needed by agent
**Best practices**:
- Keep SKILL.md under 500 lines
- Split complex content into `references/` files
- Reference these files clearly in SKILL.md
**Pattern**: For skills with multiple variants/frameworks:
- Keep core workflow in SKILL.md
- Move variant-specific details to separate reference files
- Agent loads only relevant files
Example:
```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
    ├── aws.md
    ├── gcp.md
    └── azure.md
```
When user chooses AWS, agent only reads aws.md.

View File

@@ -0,0 +1,82 @@
# Output Patterns
Use these patterns when skills need to produce consistent, high-quality output.
## Template Pattern
Provide templates for output format. Match the level of strictness to your needs.
**For strict requirements (like API responses or data formats):**
```markdown
## Report structure
ALWAYS use this exact template structure:
# [Analysis Title]
## Executive summary
[One-paragraph overview of key findings]
## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data
## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```
**For flexible guidance (when adaptation is useful):**
```markdown
## Report structure
Here is a sensible default format, but use your best judgment:
# [Analysis Title]
## Executive summary
[Overview]
## Key findings
[Adapt sections based on what you discover]
## Recommendations
[Tailor to the specific context]
Adjust sections as needed for the specific analysis type.
```
## Examples Pattern
For skills where output quality depends on seeing examples, provide input/output pairs:
````markdown
## Commit message format
Generate commit messages following these examples:
**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```
feat(auth): implement JWT-based authentication
Add login endpoint and token validation middleware
```
**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```
fix(reports): correct date formatting in timezone conversion
Use UTC timestamps consistently across report generation
```
Follow this style: type(scope): brief description, then detailed explanation.
````
Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.

View File

@@ -0,0 +1,28 @@
# Workflow Patterns
## Sequential Workflows
For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md:
```markdown
Filling a PDF form involves these steps:
1. Analyze the form (run analyze_form.py)
2. Create field mapping (edit fields.json)
3. Validate mapping (run validate_fields.py)
4. Fill the form (run fill_form.py)
5. Verify output (run verify_output.py)
```
## Conditional Workflows
For tasks with branching logic, guide Claude through decision points:
```markdown
1. Determine the modification type:
**Creating new content?** → Follow "Creation workflow" below
**Editing existing content?** → Follow "Editing workflow" below
2. Creation workflow: [steps]
3. Editing workflow: [steps]
```

View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Skill Initializer - Creates a new skill from template
Usage:
init_skill.py <skill-name> --path <path>
Examples:
init_skill.py my-new-skill --path skills/public
init_skill.py my-api-helper --path skills/private
init_skill.py custom-skill --path /custom/location
"""
import sys
from pathlib import Path
SKILL_TEMPLATE = """---
name: {skill_name}
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
---
# {skill_title}
## Overview
[TODO: 1-2 sentences explaining what this skill enables]
## Structuring This Skill
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
**1. Workflow-Based** (best for sequential processes)
- Works well when there are clear step-by-step procedures
- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing"
- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2...
**2. Task-Based** (best for tool collections)
- Works well when the skill offers different operations/capabilities
- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text"
- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2...
**3. Reference/Guidelines** (best for standards or specifications)
- Works well for brand guidelines, coding standards, or requirements
- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features"
- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage...
**4. Capabilities-Based** (best for integrated systems)
- Works well when the skill provides multiple interrelated features
- Example: Product Management with "Core Capabilities" → numbered capability list
- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature...
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
## [TODO: Replace with the first main section based on chosen structure]
[TODO: Add content here. See examples in existing skills:
- Code samples for technical skills
- Decision trees for complex workflows
- Concrete examples with realistic user requests
- References to scripts/templates/references as needed]
## Resources
This skill includes example resource directories that demonstrate how to organize different types of bundled resources:
### scripts/
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
**Examples from other skills:**
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments.
### references/
Documentation and reference material intended to be loaded into context to inform Claude's process and thinking.
**Examples from other skills:**
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
- BigQuery: API reference documentation and query examples
- Finance: Schema documentation, company policies
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working.
### assets/
Files not intended to be loaded into context, but rather used within the output Claude produces.
**Examples from other skills:**
- Brand styling: PowerPoint template files (.pptx), logo files
- Frontend builder: HTML/React boilerplate project directories
- Typography: Font files (.ttf, .woff2)
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
---
**Any unneeded directories can be deleted.** Not every skill requires all three types of resources.
"""
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}
This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.
Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""
def main():
print("This is an example script for {skill_name}")
# TODO: Add actual script logic here
# This could be data processing, file conversion, API calls, etc.
if __name__ == "__main__":
main()
'''
EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.
Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples
## When Reference Docs Are Useful
Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases
## Structure Suggestions
### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits
### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""
EXAMPLE_ASSET = """# Example Asset File
This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
Asset files are NOT intended to be loaded into context, but rather used within
the output Claude produces.
Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json
## Common Asset Types
- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml
Note: This is a text placeholder. Actual assets can be any file type.
"""
def title_case_skill_name(skill_name):
"""Convert hyphenated skill name to Title Case for display."""
return ' '.join(word.capitalize() for word in skill_name.split('-'))
def init_skill(skill_name, path):
"""
Initialize a new skill directory with template SKILL.md.
Args:
skill_name: Name of the skill
path: Path where the skill directory should be created
Returns:
Path to created skill directory, or None if error
"""
# Determine skill directory path
skill_dir = Path(path).resolve() / skill_name
# Check if directory already exists
if skill_dir.exists():
print(f"❌ Error: Skill directory already exists: {skill_dir}")
return None
# Create skill directory
try:
skill_dir.mkdir(parents=True, exist_ok=False)
print(f"✅ Created skill directory: {skill_dir}")
except Exception as e:
print(f"❌ Error creating directory: {e}")
return None
# Create SKILL.md from template
skill_title = title_case_skill_name(skill_name)
skill_content = SKILL_TEMPLATE.format(
skill_name=skill_name,
skill_title=skill_title
)
skill_md_path = skill_dir / 'SKILL.md'
try:
skill_md_path.write_text(skill_content)
print("✅ Created SKILL.md")
except Exception as e:
print(f"❌ Error creating SKILL.md: {e}")
return None
# Create resource directories with example files
try:
# Create scripts/ directory with example script
scripts_dir = skill_dir / 'scripts'
scripts_dir.mkdir(exist_ok=True)
example_script = scripts_dir / 'example.py'
example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
example_script.chmod(0o755)
print("✅ Created scripts/example.py")
# Create references/ directory with example reference doc
references_dir = skill_dir / 'references'
references_dir.mkdir(exist_ok=True)
example_reference = references_dir / 'api_reference.md'
example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
print("✅ Created references/api_reference.md")
# Create assets/ directory with example asset placeholder
assets_dir = skill_dir / 'assets'
assets_dir.mkdir(exist_ok=True)
example_asset = assets_dir / 'example_asset.txt'
example_asset.write_text(EXAMPLE_ASSET)
print("✅ Created assets/example_asset.txt")
except Exception as e:
print(f"❌ Error creating resource directories: {e}")
return None
# Print next steps
print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}")
print("\nNext steps:")
print("1. Edit SKILL.md to complete the TODO items and update the description")
print("2. Customize or delete the example files in scripts/, references/, and assets/")
print("3. Run the validator when ready to check the skill structure")
return skill_dir
def main():
if len(sys.argv) < 4 or sys.argv[2] != '--path':
print("Usage: init_skill.py <skill-name> --path <path>")
print("\nSkill name requirements:")
print(" - Hyphen-case identifier (e.g., 'data-analyzer')")
print(" - Lowercase letters, digits, and hyphens only")
print(" - Max 40 characters")
print(" - Must match directory name exactly")
print("\nExamples:")
print(" init_skill.py my-new-skill --path workspace/skills")
print(" init_skill.py my-api-helper --path /path/to/skills")
print(" init_skill.py custom-skill --path /custom/location")
sys.exit(1)
skill_name = sys.argv[1]
path = sys.argv[3]
print(f"🚀 Initializing skill: {skill_name}")
print(f" Location: {path}")
print()
result = init_skill(skill_name, path)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder
Usage:
python utils/package_skill.py <path/to/skill-folder> [output-directory]
Example:
python utils/package_skill.py skills/public/my-skill
python utils/package_skill.py skills/public/my-skill ./dist
"""
import sys
import os
import zipfile
from pathlib import Path
# Add script directory to path for imports
script_dir = Path(__file__).parent
sys.path.insert(0, str(script_dir))
from quick_validate import validate_skill
def package_skill(skill_path, output_dir=None):
"""
Package a skill folder into a .skill file.
Args:
skill_path: Path to the skill folder
output_dir: Optional output directory for the .skill file (defaults to current directory)
Returns:
Path to the created .skill file, or None if error
"""
skill_path = Path(skill_path).resolve()
# Validate skill folder exists
if not skill_path.exists():
print(f"❌ Error: Skill folder not found: {skill_path}")
return None
if not skill_path.is_dir():
print(f"❌ Error: Path is not a directory: {skill_path}")
return None
# Validate SKILL.md exists
skill_md = skill_path / "SKILL.md"
if not skill_md.exists():
print(f"❌ Error: SKILL.md not found in {skill_path}")
return None
# Run validation before packaging
print("🔍 Validating skill...")
valid, message = validate_skill(skill_path)
if not valid:
print(f"❌ Validation failed: {message}")
print(" Please fix the validation errors before packaging.")
return None
print(f"{message}\n")
# Determine output location
skill_name = skill_path.name
if output_dir:
output_path = Path(output_dir).resolve()
output_path.mkdir(parents=True, exist_ok=True)
else:
output_path = Path.cwd()
skill_filename = output_path / f"{skill_name}.skill"
# Create the .skill file (zip format)
try:
with zipfile.ZipFile(skill_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
# Walk through the skill directory
for file_path in skill_path.rglob('*'):
if file_path.is_file():
# Calculate the relative path within the zip
arcname = file_path.relative_to(skill_path.parent)
zipf.write(file_path, arcname)
print(f" Added: {arcname}")
print(f"\n✅ Successfully packaged skill to: {skill_filename}")
return skill_filename
except Exception as e:
print(f"❌ Error creating .skill file: {e}")
return None
def main():
if len(sys.argv) < 2:
print("Usage: python utils/package_skill.py <path/to/skill-folder> [output-directory]")
print("\nExample:")
print(" python utils/package_skill.py skills/public/my-skill")
print(" python utils/package_skill.py skills/public/my-skill ./dist")
sys.exit(1)
skill_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else None
print(f"📦 Packaging skill: {skill_path}")
if output_dir:
print(f" Output directory: {output_dir}")
print()
result = package_skill(skill_path, output_dir)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
Quick validation script for skills - minimal version
"""
import sys
import os
import re
import yaml
from pathlib import Path
def validate_skill(skill_path):
"""Basic validation of a skill"""
skill_path = Path(skill_path)
# Check SKILL.md exists
skill_md = skill_path / 'SKILL.md'
if not skill_md.exists():
return False, "SKILL.md not found"
# Read and validate frontmatter
content = skill_md.read_text()
if not content.startswith('---'):
return False, "No YAML frontmatter found"
# Extract frontmatter
match = re.match(r'^---\n(.*?)\n---', content, re.DOTALL)
if not match:
return False, "Invalid frontmatter format"
frontmatter_text = match.group(1)
# Parse YAML frontmatter
try:
frontmatter = yaml.safe_load(frontmatter_text)
if not isinstance(frontmatter, dict):
return False, "Frontmatter must be a YAML dictionary"
except yaml.YAMLError as e:
return False, f"Invalid YAML in frontmatter: {e}"
# Define allowed properties
ALLOWED_PROPERTIES = {'name', 'description', 'license', 'allowed-tools', 'metadata'}
# Check for unexpected properties (excluding nested keys under metadata)
unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES
if unexpected_keys:
return False, (
f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. "
f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}"
)
# Check required fields
if 'name' not in frontmatter:
return False, "Missing 'name' in frontmatter"
if 'description' not in frontmatter:
return False, "Missing 'description' in frontmatter"
# Extract name for validation
name = frontmatter.get('name', '')
if not isinstance(name, str):
return False, f"Name must be a string, got {type(name).__name__}"
name = name.strip()
if name:
# Check naming convention (hyphen-case: lowercase with hyphens)
if not re.match(r'^[a-z0-9-]+$', name):
return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)"
if name.startswith('-') or name.endswith('-') or '--' in name:
return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens"
# Check name length (max 64 characters per spec)
if len(name) > 64:
return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters."
# Extract and validate description
description = frontmatter.get('description', '')
if not isinstance(description, str):
return False, f"Description must be a string, got {type(description).__name__}"
description = description.strip()
if description:
# Check for angle brackets
if '<' in description or '>' in description:
return False, "Description cannot contain angle brackets (< or >)"
# Check description length (max 1024 characters per spec)
if len(description) > 1024:
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters."
return True, "Skill is valid!"
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python quick_validate.py <skill_directory>")
sys.exit(1)
valid, message = validate_skill(sys.argv[1])
print(message)
sys.exit(0 if valid else 1)