Mirror of https://github.com/zhayujie/chatgpt-on-wechat.git (synced 2026-02-19 17:10:03 +08:00)
Merge pull request #2648 from zhayujie/feat-cow-agent
feat: cow agent core
.gitignore (vendored): 3 lines changed
@@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__
!plugins/linkai
!plugins/agent
client_config.json
ref/
.cursor/
local/
agent/memory/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
Memory module for AgentMesh

Provides long-term memory capabilities with hybrid search (vector + keyword)
"""

from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config

__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
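Taken together, these exports are the intended entry point into the module. A minimal usage sketch (illustrative, not part of this commit; it assumes OPENAI_API_KEY is set in the environment, otherwise the manager falls back to keyword-only search, and /tmp/cow-demo is a hypothetical workspace):

import asyncio
from agent.memory import MemoryConfig, MemoryManager, set_global_memory_config

async def main():
    # Point the workspace at a custom directory before any manager is created
    set_global_memory_config(MemoryConfig(workspace_root="/tmp/cow-demo"))
    manager = MemoryManager()
    await manager.add_memory("User prefers concise answers.", scope="shared")
    results = await manager.search("user preferences")
    for r in results:
        print(r.path, round(r.score, 3), r.snippet[:60])
    manager.close()

asyncio.run(main())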
agent/memory/chunker.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""
Text chunking utilities for memory

Splits text into chunks with token limits and overlap
"""

from typing import List
from dataclasses import dataclass


@dataclass
class TextChunk:
    """Represents a text chunk with line numbers"""
    text: str
    start_line: int
    end_line: int


class TextChunker:
    """Chunks text by line count with token estimation"""

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
        """
        Initialize chunker

        Args:
            max_tokens: Maximum tokens per chunk
            overlap_tokens: Overlap tokens between chunks
        """
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        # Rough estimation: ~4 chars per token for English/Chinese mixed
        self.chars_per_token = 4

    def chunk_text(self, text: str) -> List[TextChunk]:
        """
        Chunk text into overlapping segments

        Args:
            text: Input text to chunk

        Returns:
            List of TextChunk objects
        """
        if not text.strip():
            return []

        lines = text.split('\n')
        chunks = []

        max_chars = self.max_tokens * self.chars_per_token
        overlap_chars = self.overlap_tokens * self.chars_per_token

        current_chunk = []
        current_chars = 0
        start_line = 1

        for i, line in enumerate(lines, start=1):
            line_chars = len(line)

            # If single line exceeds max, split it
            if line_chars > max_chars:
                # Save current chunk if exists
                if current_chunk:
                    chunks.append(TextChunk(
                        text='\n'.join(current_chunk),
                        start_line=start_line,
                        end_line=i - 1
                    ))
                    current_chunk = []
                    current_chars = 0

                # Split long line into multiple chunks
                for sub_chunk in self._split_long_line(line, max_chars):
                    chunks.append(TextChunk(
                        text=sub_chunk,
                        start_line=i,
                        end_line=i
                    ))

                start_line = i + 1
                continue

            # Check if adding this line would exceed limit
            if current_chars + line_chars > max_chars and current_chunk:
                # Save current chunk
                chunks.append(TextChunk(
                    text='\n'.join(current_chunk),
                    start_line=start_line,
                    end_line=i - 1
                ))

                # Start new chunk with overlap
                overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
                current_chunk = overlap_lines + [line]
                current_chars = sum(len(l) for l in current_chunk)
                start_line = i - len(overlap_lines)
            else:
                # Add line to current chunk
                current_chunk.append(line)
                current_chars += line_chars

        # Save last chunk
        if current_chunk:
            chunks.append(TextChunk(
                text='\n'.join(current_chunk),
                start_line=start_line,
                end_line=len(lines)
            ))

        return chunks

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Split a single long line into multiple chunks"""
        chunks = []
        for i in range(0, len(line), max_chars):
            chunks.append(line[i:i + max_chars])
        return chunks

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """Get last few lines that fit within target_chars for overlap"""
        overlap = []
        chars = 0

        for line in reversed(lines):
            line_chars = len(line)
            if chars + line_chars > target_chars:
                break
            overlap.insert(0, line)
            chars += line_chars

        return overlap

    def chunk_markdown(self, text: str) -> List[TextChunk]:
        """
        Chunk markdown text while respecting structure
        (For future enhancement: respect markdown sections)
        """
        return self.chunk_text(text)
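With the defaults, the character budgets work out to max_chars = 500 * 4 = 2000 and overlap_chars = 50 * 4 = 200. A small smoke test of the overlap behavior on synthetic input (illustrative, not from the commit):

from agent.memory.chunker import TextChunker

chunker = TextChunker(max_tokens=500, overlap_tokens=50)
# 100 lines of 50 chars each: expect chunks of roughly 40 lines, each new
# chunk reusing the last ~4 lines (200 chars) of the previous one
text = "\n".join("x" * 50 for _ in range(100))
for chunk in chunker.chunk_text(text):
    print(chunk.start_line, chunk.end_line, len(chunk.text))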
agent/memory/config.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""
Memory configuration module

Provides global memory configuration with simplified workspace structure
"""

import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path


@dataclass
class MemoryConfig:
    """Configuration for memory storage and search"""

    # Storage paths (default: ~/cow)
    workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))

    # Embedding config
    embedding_provider: str = "openai"  # "openai" | "local"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking config
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search config
    max_results: int = 10
    min_score: float = 0.1

    # Hybrid search weights
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Memory sources
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Sync config
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """Get workspace root directory"""
        return Path(self.workspace_root)

    def get_memory_dir(self) -> Path:
        """Get memory files directory"""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """Get SQLite database path for long-term memory index"""
        index_dir = self.get_memory_dir() / "long-term"
        index_dir.mkdir(parents=True, exist_ok=True)
        return index_dir / "index.db"

    def get_skills_dir(self) -> Path:
        """Get skills directory"""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Get workspace directory for an agent

        Args:
            agent_name: Optional agent name (not used in current implementation)

        Returns:
            Path to workspace directory
        """
        workspace = self.get_workspace()
        # Ensure workspace directory exists
        workspace.mkdir(parents=True, exist_ok=True)
        return workspace


# Global memory configuration
_global_memory_config: Optional[MemoryConfig] = None


def get_default_memory_config() -> MemoryConfig:
    """
    Get the global memory configuration.
    If not set, returns a default configuration.

    Returns:
        MemoryConfig instance
    """
    global _global_memory_config
    if _global_memory_config is None:
        _global_memory_config = MemoryConfig()
    return _global_memory_config


def set_global_memory_config(config: MemoryConfig):
    """
    Set the global memory configuration.
    This should be called before creating any MemoryManager instances.

    Args:
        config: MemoryConfig instance to use globally

    Example:
        >>> from agent.memory import MemoryConfig, set_global_memory_config
        >>> config = MemoryConfig(
        ...     workspace_root="~/my_agents",
        ...     embedding_provider="openai",
        ...     vector_weight=0.8
        ... )
        >>> set_global_memory_config(config)
    """
    global _global_memory_config
    _global_memory_config = config
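With the default workspace_root of ~/cow, the path helpers above derive the whole on-disk layout. A quick sketch of the resolved paths (assuming the default config; note that get_db_path also creates the index directory as a side effect):

from agent.memory.config import MemoryConfig

cfg = MemoryConfig()
print(cfg.get_workspace())   # ~/cow (expanded to the user's home directory)
print(cfg.get_memory_dir())  # ~/cow/memory
print(cfg.get_db_path())     # ~/cow/memory/long-term/index.db (directory created eagerly)
print(cfg.get_skills_dir())  # ~/cow/skills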
agent/memory/embedding.py (new file, 175 lines)
@@ -0,0 +1,175 @@
"""
Embedding providers for memory

Supports OpenAI and local embedding models
"""

from typing import List, Optional
from abc import ABC, abstractmethod
import hashlib


class EmbeddingProvider(ABC):
    """Base class for embedding providers"""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts"""
        pass

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Get embedding dimensions"""
        pass


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider"""

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize OpenAI embedding provider

        Args:
            model: Model name (text-embedding-3-small or text-embedding-3-large)
            api_key: OpenAI API key
            api_base: Optional API base URL
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"

        # Lazy import to avoid dependency issues
        try:
            from openai import OpenAI
            self.client = OpenAI(api_key=api_key, base_url=api_base)
        except ImportError:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")

        # Set dimensions based on model
        self._dimensions = 1536 if "small" in model else 3072

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return response.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []

        response = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        return [item.embedding for item in response.data]

    @property
    def dimensions(self) -> int:
        return self._dimensions


class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider using sentence-transformers"""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Initialize local embedding provider

        Args:
            model: Model name from sentence-transformers
        """
        self.model_name = model

        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model)
            self._dimensions = self.model.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []

        embeddings = self.model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()

    @property
    def dimensions(self) -> int:
        return self._dimensions


class EmbeddingCache:
    """Cache for embeddings to avoid recomputation"""

    def __init__(self):
        self.cache = {}

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Get cached embedding"""
        key = self._compute_key(text, provider, model)
        return self.cache.get(key)

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Cache embedding"""
        key = self._compute_key(text, provider, model)
        self.cache[key] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Compute cache key"""
        content = f"{provider}:{model}:{text}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def clear(self):
        """Clear cache"""
        self.cache.clear()


def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None
) -> EmbeddingProvider:
    """
    Factory function to create embedding provider

    Args:
        provider: Provider name ("openai" or "local")
        model: Model name (provider-specific)
        api_key: API key for remote providers
        api_base: API base URL for remote providers

    Returns:
        EmbeddingProvider instance
    """
    if provider == "openai":
        model = model or "text-embedding-3-small"
        return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
    elif provider == "local":
        model = model or "all-MiniLM-L6-v2"
        return LocalEmbeddingProvider(model=model)
    else:
        raise ValueError(f"Unknown embedding provider: {provider}")
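Since neither provider caches results itself, a caller can pair the factory with EmbeddingCache. A minimal sketch, assuming sentence-transformers is installed (the embed_cached wrapper is illustrative, not part of this commit):

from agent.memory.embedding import EmbeddingCache, create_embedding_provider

provider = create_embedding_provider(provider="local", model="all-MiniLM-L6-v2")
cache = EmbeddingCache()

def embed_cached(text: str):
    # Serve repeated texts from the in-memory cache
    cached = cache.get(text, "local", "all-MiniLM-L6-v2")
    if cached is not None:
        return cached
    embedding = provider.embed(text)
    cache.put(text, "local", "all-MiniLM-L6-v2", embedding)
    return embedding

print(len(embed_cached("hello world")))  # 384 dimensions for all-MiniLM-L6-v2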
agent/memory/manager.py (new file, 613 lines)
@@ -0,0 +1,613 @@
"""
Memory manager for AgentMesh

Provides high-level interface for memory operations
"""

import os
from typing import List, Optional, Dict, Any
from pathlib import Path
import hashlib
from datetime import datetime

from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed


class MemoryManager:
    """
    Memory manager with hybrid search capabilities

    Provides long-term memory for agents with vector and keyword search
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_model: Optional[Any] = None
    ):
        """
        Initialize memory manager

        Args:
            config: Memory configuration (uses global config if not provided)
            embedding_provider: Custom embedding provider (optional)
            llm_model: LLM model for summarization (optional)
        """
        self.config = config or get_default_memory_config()

        # Initialize storage
        db_path = self.config.get_db_path()
        self.storage = MemoryStorage(db_path)

        # Initialize chunker
        self.chunker = TextChunker(
            max_tokens=self.config.chunk_max_tokens,
            overlap_tokens=self.config.chunk_overlap_tokens
        )

        # Initialize embedding provider (optional)
        self.embedding_provider = None
        if embedding_provider:
            self.embedding_provider = embedding_provider
        else:
            # Try to create embedding provider, but allow failure
            try:
                # Get API key from environment or config
                api_key = os.environ.get('OPENAI_API_KEY')
                api_base = os.environ.get('OPENAI_API_BASE')

                self.embedding_provider = create_embedding_provider(
                    provider=self.config.embedding_provider,
                    model=self.config.embedding_model,
                    api_key=api_key,
                    api_base=api_base
                )
            except Exception as e:
                # Embedding provider failed, but that's OK:
                # we can still use keyword search and file operations
                print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
                print("ℹ️ Memory will work with keyword search only (no semantic search)")

        # Initialize memory flush manager
        workspace_dir = self.config.get_workspace()
        self.flush_manager = MemoryFlushManager(
            workspace_dir=workspace_dir,
            llm_model=llm_model
        )

        # Ensure workspace directories exist
        self._init_workspace()

        self._dirty = False

    def _init_workspace(self):
        """Initialize workspace directories"""
        memory_dir = self.config.get_memory_dir()
        memory_dir.mkdir(parents=True, exist_ok=True)

        # Create default memory files
        workspace_dir = self.config.get_workspace()
        create_memory_files_if_needed(workspace_dir)

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        include_shared: bool = True
    ) -> List[SearchResult]:
        """
        Search memory with hybrid search (vector + keyword)

        Args:
            query: Search query
            user_id: User ID for scoped search
            max_results: Maximum results to return
            min_score: Minimum score threshold
            include_shared: Include shared memories

        Returns:
            List of search results sorted by relevance
        """
        max_results = max_results or self.config.max_results
        min_score = min_score or self.config.min_score

        # Determine scopes
        scopes = []
        if include_shared:
            scopes.append("shared")
        if user_id:
            scopes.append("user")

        if not scopes:
            return []

        # Sync if needed
        if self.config.sync_on_search and self._dirty:
            await self.sync()

        # Perform vector search (if embedding provider available)
        vector_results = []
        if self.embedding_provider:
            query_embedding = self.embedding_provider.embed(query)
            vector_results = self.storage.search_vector(
                query_embedding=query_embedding,
                user_id=user_id,
                scopes=scopes,
                limit=max_results * 2  # Get more candidates for merging
            )

        # Perform keyword search
        keyword_results = self.storage.search_keyword(
            query=query,
            user_id=user_id,
            scopes=scopes,
            limit=max_results * 2
        )

        # Merge results
        merged = self._merge_results(
            vector_results,
            keyword_results,
            self.config.vector_weight,
            self.config.keyword_weight
        )

        # Filter by min score and limit
        filtered = [r for r in merged if r.score >= min_score]
        return filtered[:max_results]

    async def add_memory(
        self,
        content: str,
        user_id: Optional[str] = None,
        scope: str = "shared",
        source: str = "memory",
        path: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Add new memory content

        Args:
            content: Memory content
            user_id: User ID for user-scoped memory
            scope: Memory scope ("shared", "user", "session")
            source: Memory source ("memory" or "session")
            path: File path (auto-generated if not provided)
            metadata: Additional metadata
        """
        if not content.strip():
            return

        # Generate path if not provided
        if not path:
            content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
            if user_id and scope == "user":
                path = f"memory/users/{user_id}/memory_{content_hash}.md"
            else:
                path = f"memory/shared/memory_{content_hash}.md"

        # Chunk content
        chunks = self.chunker.chunk_text(content)

        # Generate embeddings (if provider available)
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            # No embeddings, just use None
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
            chunk_hash = MemoryStorage.compute_hash(chunk.text)

            memory_chunks.append(MemoryChunk(
                id=chunk_id,
                user_id=user_id,
                scope=scope,
                source=source,
                path=path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=chunk_hash,
                metadata=metadata
            ))

        # Save to storage
        self.storage.save_chunks_batch(memory_chunks)
        # Update file metadata
        file_hash = MemoryStorage.compute_hash(content)
        self.storage.update_file_metadata(
            path=path,
            source=source,
            file_hash=file_hash,
            mtime=int(datetime.now().timestamp()),  # Use current time
            size=len(content)
        )
    async def sync(self, force: bool = False):
        """
        Synchronize memory from files

        Args:
            force: Force full reindex
        """
        memory_dir = self.config.get_memory_dir()
        workspace_dir = self.config.get_workspace()

        # Scan MEMORY.md (workspace root)
        memory_file = Path(workspace_dir) / "MEMORY.md"
        if memory_file.exists():
            await self._sync_file(memory_file, "memory", "shared", None)

        # Scan memory directory (including daily summaries)
        if memory_dir.exists():
            for file_path in memory_dir.rglob("*.md"):
                # Determine scope and user_id from path
                rel_path = file_path.relative_to(workspace_dir)
                parts = rel_path.parts

                # Check if it's in a daily summary directory
                if "daily" in parts:
                    # Daily summary files
                    if "users" in parts or len(parts) > 3:
                        # User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
                        user_idx = parts.index("daily") + 1
                        user_id = parts[user_idx] if user_idx < len(parts) else None
                        scope = "user"
                    else:
                        # Shared daily summary: memory/daily/2024-01-29.md
                        user_id = None
                        scope = "shared"
                elif "users" in parts:
                    # User-scoped memory
                    user_idx = parts.index("users") + 1
                    user_id = parts[user_idx] if user_idx < len(parts) else None
                    scope = "user"
                else:
                    # Shared memory
                    user_id = None
                    scope = "shared"

                await self._sync_file(file_path, "memory", scope, user_id)

        self._dirty = False

    async def _sync_file(
        self,
        file_path: Path,
        source: str,
        scope: str,
        user_id: Optional[str]
    ):
        """Sync a single file"""
        # Compute file hash
        content = file_path.read_text(encoding='utf-8')
        file_hash = MemoryStorage.compute_hash(content)

        # Get relative path
        workspace_dir = self.config.get_workspace()
        rel_path = str(file_path.relative_to(workspace_dir))

        # Check if file changed
        stored_hash = self.storage.get_file_hash(rel_path)
        if stored_hash == file_hash:
            return  # No changes

        # Delete old chunks
        self.storage.delete_by_path(rel_path)

        # Chunk and embed
        chunks = self.chunker.chunk_text(content)
        if not chunks:
            return

        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
            chunk_hash = MemoryStorage.compute_hash(chunk.text)

            memory_chunks.append(MemoryChunk(
                id=chunk_id,
                user_id=user_id,
                scope=scope,
                source=source,
                path=rel_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=chunk_hash,
                metadata=None
            ))

        # Save
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata
        stat = file_path.stat()
        self.storage.update_file_metadata(
            path=rel_path,
            source=source,
            file_hash=file_hash,
            mtime=int(stat.st_mtime),
            size=stat.st_size
        )

    def should_flush_memory(
        self,
        current_tokens: int,
        context_window: int = 128000,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Check if memory flush should be triggered

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size (default: 128K)
            reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
            soft_threshold: Trigger N tokens before threshold (default: 4K)

        Returns:
            True if memory flush should run
        """
        return self.flush_manager.should_flush(
            current_tokens=current_tokens,
            context_window=context_window,
            reserve_tokens=reserve_tokens,
            soft_threshold=soft_threshold
        )

    async def execute_memory_flush(
        self,
        agent_executor,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Execute memory flush before compaction

        This runs a silent agent turn to write durable memories to disk.
        Similar to clawdbot's pre-compaction memory flush.

        Args:
            agent_executor: Async function to execute agent with prompt
            current_tokens: Current session token count
            user_id: Optional user ID
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if flush completed successfully

        Example:
            >>> async def run_agent(prompt, system_prompt, silent=False):
            ...     # Your agent execution logic
            ...     pass
            >>>
            >>> if manager.should_flush_memory(current_tokens=100000):
            ...     await manager.execute_memory_flush(
            ...         agent_executor=run_agent,
            ...         current_tokens=100000
            ...     )
        """
        success = await self.flush_manager.execute_flush(
            agent_executor=agent_executor,
            current_tokens=current_tokens,
            user_id=user_id,
            **executor_kwargs
        )

        if success:
            # Mark dirty so the next search will sync the new memories
            self._dirty = True

        return success

    def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
        """
        Build natural memory guidance for the agent system prompt

        Following clawdbot's approach:
        1. Load MEMORY.md as bootstrap context (blends into background)
        2. Load daily files on demand via the memory_search tool
        3. The agent should NOT proactively mention memories unless the user asks

        Args:
            lang: Language for guidance ("en" or "zh")
            include_context: Whether to include bootstrap memory context (default: True).
                MEMORY.md is loaded as background context (like clawdbot);
                daily files are accessed via the memory_search tool.

        Returns:
            Memory guidance text (and optionally context) for the system prompt
        """
        today_file = self.flush_manager.get_today_memory_file().name

        if lang == "zh":
            guidance = f"""## 记忆系统

**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。

**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
- 长期信息 → MEMORY.md
- 当天笔记 → memory/{today_file}
- 静默存储,仅在明确要求时确认

**使用原则**: 自然使用记忆,就像你本来就知道。不要主动提起或列举记忆,除非用户明确询问。"""
        else:
            guidance = f"""## Memory System

**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).

**Store Memories**: When the user shares important info (preferences, decisions, facts), proactively write:
- Durable info → MEMORY.md
- Daily notes → memory/{today_file}
- Store silently; confirm only when explicitly requested

**Usage**: Use memories naturally as if you always knew them. Don't mention or list them unless the user explicitly asks."""

        if include_context:
            # Load bootstrap context (MEMORY.md only, like clawdbot)
            bootstrap_context = self.load_bootstrap_memories()
            if bootstrap_context:
                guidance += f"\n\n## Background Context\n\n{bootstrap_context}"

        return guidance

    def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
        """
        Load bootstrap memory files for session start

        Following clawdbot's design:
        - Only loads MEMORY.md from the workspace root (long-term curated memory)
        - Daily files (memory/YYYY-MM-DD.md) are accessed via the memory_search tool, not bootstrap
        - A user-specific MEMORY.md is also loaded if user_id is provided

        Returns memory content WITHOUT obvious headers so it blends naturally
        into the context as background knowledge.

        Args:
            user_id: Optional user ID for user-specific memories

        Returns:
            Memory content to inject into the system prompt (blends naturally as background context)
        """
        workspace_dir = self.config.get_workspace()
        memory_dir = self.config.get_memory_dir()

        sections = []

        # 1. Load MEMORY.md from workspace root (long-term curated memory)
        # Following clawdbot: only MEMORY.md is bootstrap; daily files use memory_search
        memory_file = Path(workspace_dir) / "MEMORY.md"
        if memory_file.exists():
            try:
                content = memory_file.read_text(encoding='utf-8').strip()
                if content:
                    sections.append(content)
            except Exception as e:
                print(f"Warning: Failed to read MEMORY.md: {e}")

        # 2. Load user-specific MEMORY.md if user_id provided
        if user_id:
            user_memory_dir = memory_dir / "users" / user_id
            user_memory_file = user_memory_dir / "MEMORY.md"
            if user_memory_file.exists():
                try:
                    content = user_memory_file.read_text(encoding='utf-8').strip()
                    if content:
                        sections.append(content)
                except Exception as e:
                    print(f"Warning: Failed to read user memory: {e}")

        if not sections:
            return ""

        # Join sections without obvious headers - let memories blend naturally.
        # This makes the agent feel like it "just knows" rather than "checking memory files".
        return "\n\n".join(sections)

    def get_status(self) -> Dict[str, Any]:
        """Get memory status"""
        stats = self.storage.get_stats()
        return {
            'chunks': stats['chunks'],
            'files': stats['files'],
            'workspace': str(self.config.get_workspace()),
            'dirty': self._dirty,
            'embedding_enabled': self.embedding_provider is not None,
            'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
            'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
            'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
        }

    def mark_dirty(self):
        """Mark memory as dirty (needs sync)"""
        self._dirty = True

    def close(self):
        """Close memory manager and release resources"""
        self.storage.close()

    # Helper methods

    def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
        """Generate unique chunk ID"""
        content = f"{path}:{start_line}:{end_line}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def _merge_results(
        self,
        vector_results: List[SearchResult],
        keyword_results: List[SearchResult],
        vector_weight: float,
        keyword_weight: float
    ) -> List[SearchResult]:
        """Merge vector and keyword search results"""
        # Create a map keyed by (path, start_line, end_line)
        merged_map = {}

        for result in vector_results:
            key = (result.path, result.start_line, result.end_line)
            merged_map[key] = {
                'result': result,
                'vector_score': result.score,
                'keyword_score': 0.0
            }

        for result in keyword_results:
            key = (result.path, result.start_line, result.end_line)
            if key in merged_map:
                merged_map[key]['keyword_score'] = result.score
            else:
                merged_map[key] = {
                    'result': result,
                    'vector_score': 0.0,
                    'keyword_score': result.score
                }

        # Calculate combined scores
        merged_results = []
        for entry in merged_map.values():
            combined_score = (
                vector_weight * entry['vector_score'] +
                keyword_weight * entry['keyword_score']
            )

            result = entry['result']
            merged_results.append(SearchResult(
                path=result.path,
                start_line=result.start_line,
                end_line=result.end_line,
                score=combined_score,
                snippet=result.snippet,
                source=result.source,
                user_id=result.user_id
            ))

        # Sort by score
        merged_results.sort(key=lambda r: r.score, reverse=True)
        return merged_results
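To make _merge_results concrete: with the default weights (vector_weight=0.7, keyword_weight=0.3), a chunk surfaced by both searches outranks one surfaced by either alone. Worked numbers:

vector_weight, keyword_weight = 0.7, 0.3
# Found by both searches (vector 0.80, keyword 0.50):
print(vector_weight * 0.80 + keyword_weight * 0.50)  # ~0.71
# Keyword-only hit (vector score defaults to 0.0):
print(vector_weight * 0.00 + keyword_weight * 0.50)  # 0.15, still above the default min_score of 0.1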
agent/memory/storage.py (new file, 568 lines)
@@ -0,0 +1,568 @@
"""
Storage layer for memory using SQLite + FTS5

Provides vector and keyword search capabilities
"""

import sqlite3
import json
import hashlib
import re
from typing import List, Dict, Optional, Any
from pathlib import Path
from dataclasses import dataclass


@dataclass
class MemoryChunk:
    """Represents a memory chunk with text and embedding"""
    id: str
    user_id: Optional[str]
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory" | "session"
    path: str
    start_line: int
    end_line: int
    text: str
    embedding: Optional[List[float]]
    hash: str
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class SearchResult:
    """Search result with score and snippet"""
    path: str
    start_line: int
    end_line: int
    score: float
    snippet: str
    source: str
    user_id: Optional[str] = None


class MemoryStorage:
    """SQLite-based storage with FTS5 for keyword search"""

    def __init__(self, db_path: Path):
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self):
        """Initialize database with schema"""
        try:
            self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            self.conn.row_factory = sqlite3.Row

            # Check database integrity
            try:
                result = self.conn.execute("PRAGMA integrity_check").fetchone()
                if result[0] != 'ok':
                    print(f"⚠️ Database integrity check failed: {result[0]}")
                    print("   Recreating database...")
                    self.conn.close()
                    self.conn = None
                    # Remove corrupted database
                    self.db_path.unlink(missing_ok=True)
                    # Remove WAL files
                    Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
                    Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
                    # Reconnect to create a new database
                    self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
                    self.conn.row_factory = sqlite3.Row
            except sqlite3.DatabaseError:
                # Database is corrupted, recreate it
                print("⚠️ Database is corrupted, recreating...")
                if self.conn:
                    self.conn.close()
                    self.conn = None
                self.db_path.unlink(missing_ok=True)
                Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
                Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
                self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
                self.conn.row_factory = sqlite3.Row

            # Enable WAL mode for better concurrency
            self.conn.execute("PRAGMA journal_mode=WAL")
            # Set busy timeout to avoid "database is locked" errors
            self.conn.execute("PRAGMA busy_timeout=5000")
        except Exception as e:
            print(f"⚠️ Unexpected error during database initialization: {e}")
            raise

        # Create chunks table with embeddings
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                scope TEXT NOT NULL DEFAULT 'shared',
                source TEXT NOT NULL DEFAULT 'memory',
                path TEXT NOT NULL,
                start_line INTEGER NOT NULL,
                end_line INTEGER NOT NULL,
                text TEXT NOT NULL,
                embedding TEXT,
                hash TEXT NOT NULL,
                metadata TEXT,
                created_at INTEGER DEFAULT (strftime('%s', 'now')),
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        # Create indexes
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user
            ON chunks(user_id)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_scope
            ON chunks(scope)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_hash
            ON chunks(path, hash)
        """)

        # Create FTS5 virtual table for keyword search.
        # Use the default unicode61 tokenizer (stable and compatible);
        # for CJK support, LIKE queries are used as a fallback.
        self.conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                text,
                id UNINDEXED,
                user_id UNINDEXED,
                path UNINDEXED,
                source UNINDEXED,
                scope UNINDEXED,
                content='chunks',
                content_rowid='rowid'
            )
        """)
        # Create triggers to keep FTS in sync.
        # For external-content FTS5 tables, deletes and updates must go
        # through the special 'delete' command rather than plain
        # DELETE/UPDATE statements on the virtual table.
        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
            END
        """)

        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)
        # Create files metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS files (
                path TEXT PRIMARY KEY,
                source TEXT NOT NULL DEFAULT 'memory',
                hash TEXT NOT NULL,
                mtime INTEGER NOT NULL,
                size INTEGER NOT NULL,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        self.conn.commit()

    def save_chunk(self, chunk: MemoryChunk):
        """Save a memory chunk"""
        self.conn.execute("""
            INSERT OR REPLACE INTO chunks
            (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (
            chunk.id,
            chunk.user_id,
            chunk.scope,
            chunk.source,
            chunk.path,
            chunk.start_line,
            chunk.end_line,
            chunk.text,
            json.dumps(chunk.embedding) if chunk.embedding else None,
            chunk.hash,
            json.dumps(chunk.metadata) if chunk.metadata else None
        ))
        self.conn.commit()

    def save_chunks_batch(self, chunks: List[MemoryChunk]):
        """Save multiple chunks in a batch"""
        self.conn.executemany("""
            INSERT OR REPLACE INTO chunks
            (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, [
            (
                c.id, c.user_id, c.scope, c.source, c.path,
                c.start_line, c.end_line, c.text,
                json.dumps(c.embedding) if c.embedding else None,
                c.hash,
                json.dumps(c.metadata) if c.metadata else None
            )
            for c in chunks
        ])
        self.conn.commit()

    def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
        """Get a chunk by ID"""
        row = self.conn.execute("""
            SELECT * FROM chunks WHERE id = ?
        """, (chunk_id,)).fetchone()

        if not row:
            return None

        return self._row_to_chunk(row)

    def search_vector(
        self,
        query_embedding: List[float],
        user_id: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Vector similarity search using in-memory cosine similarity
        (sqlite-vec can be added later for better performance)
        """
        if scopes is None:
            scopes = ["shared"]
            if user_id:
                scopes.append("user")

        # Build query (copy scopes so the caller's list is not mutated)
        scope_placeholders = ','.join('?' * len(scopes))
        params = list(scopes)

        if user_id:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND (scope = 'shared' OR user_id = ?)
                AND embedding IS NOT NULL
            """
            params.append(user_id)
        else:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND embedding IS NOT NULL
            """

        rows = self.conn.execute(query, params).fetchall()

        # Calculate cosine similarity
        results = []
        for row in rows:
            embedding = json.loads(row['embedding'])
            similarity = self._cosine_similarity(query_embedding, embedding)

            if similarity > 0:
                results.append((similarity, row))

        # Sort by similarity and limit
        results.sort(key=lambda x: x[0], reverse=True)
        results = results[:limit]

        return [
            SearchResult(
                path=row['path'],
                start_line=row['start_line'],
                end_line=row['end_line'],
                score=score,
                snippet=self._truncate_text(row['text'], 500),
                source=row['source'],
                user_id=row['user_id']
            )
            for score, row in results
        ]

    def search_keyword(
        self,
        query: str,
        user_id: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Keyword search using FTS5 + LIKE fallback

        Strategy:
        1. Try FTS5 search first (good for English and word-based languages)
        2. If no results and the query contains CJK characters, use LIKE search
        """
        if scopes is None:
            scopes = ["shared"]
            if user_id:
                scopes.append("user")

        # Try FTS5 search first
        fts_results = self._search_fts5(query, user_id, scopes, limit)
        if fts_results:
            return fts_results

        # Fall back to LIKE search for CJK characters
        if MemoryStorage._contains_cjk(query):
            return self._search_like(query, user_id, scopes, limit)

        return []

    def _search_fts5(
        self,
        query: str,
        user_id: Optional[str],
        scopes: List[str],
        limit: int
    ) -> List[SearchResult]:
        """FTS5 full-text search"""
        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []

        scope_placeholders = ','.join('?' * len(scopes))
        params = [fts_query] + scopes

        if user_id:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.id = chunks_fts.id
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                AND (chunks.scope = 'shared' OR chunks.user_id = ?)
                ORDER BY rank
                LIMIT ?
            """
            params.extend([user_id, limit])
        else:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.id = chunks_fts.id
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                ORDER BY rank
                LIMIT ?
            """
            params.append(limit)

        try:
            rows = self.conn.execute(sql_query, params).fetchall()
            return [
                SearchResult(
                    path=row['path'],
                    start_line=row['start_line'],
                    end_line=row['end_line'],
                    score=self._bm25_rank_to_score(row['rank']),
                    snippet=self._truncate_text(row['text'], 500),
                    source=row['source'],
                    user_id=row['user_id']
                )
                for row in rows
            ]
        except Exception:
            return []

    def _search_like(
        self,
        query: str,
        user_id: Optional[str],
        scopes: List[str],
        limit: int
    ) -> List[SearchResult]:
        """LIKE-based search for CJK characters"""
        # Extract CJK words (2+ characters)
        cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
        if not cjk_words:
            return []

        scope_placeholders = ','.join('?' * len(scopes))

        # Build LIKE conditions for each word
        like_conditions = []
        params = []
        for word in cjk_words:
            like_conditions.append("text LIKE ?")
            params.append(f'%{word}%')

        where_clause = ' OR '.join(like_conditions)
        params.extend(scopes)

        if user_id:
            sql_query = f"""
                SELECT * FROM chunks
                WHERE ({where_clause})
                AND scope IN ({scope_placeholders})
                AND (scope = 'shared' OR user_id = ?)
                LIMIT ?
            """
            params.extend([user_id, limit])
        else:
            sql_query = f"""
                SELECT * FROM chunks
                WHERE ({where_clause})
                AND scope IN ({scope_placeholders})
                LIMIT ?
            """
            params.append(limit)

        try:
            rows = self.conn.execute(sql_query, params).fetchall()
            return [
                SearchResult(
                    path=row['path'],
                    start_line=row['start_line'],
                    end_line=row['end_line'],
                    score=0.5,  # Fixed score for LIKE search
                    snippet=self._truncate_text(row['text'], 500),
                    source=row['source'],
                    user_id=row['user_id']
                )
                for row in rows
            ]
        except Exception:
            return []

    def delete_by_path(self, path: str):
        """Delete all chunks from a file"""
        self.conn.execute("""
            DELETE FROM chunks WHERE path = ?
        """, (path,))
        self.conn.commit()

    def get_file_hash(self, path: str) -> Optional[str]:
        """Get stored file hash"""
        row = self.conn.execute("""
            SELECT hash FROM files WHERE path = ?
        """, (path,)).fetchone()
        return row['hash'] if row else None

    def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
        """Update file metadata"""
        self.conn.execute("""
            INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
            VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (path, source, file_hash, mtime, size))
        self.conn.commit()

    def get_stats(self) -> Dict[str, int]:
        """Get storage statistics"""
        chunks_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM chunks
        """).fetchone()['cnt']

        files_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM files
        """).fetchone()['cnt']

        return {
            'chunks': chunks_count,
            'files': files_count
        }

    def close(self):
        """Close database connection"""
        if self.conn:
            try:
                self.conn.commit()  # Ensure all changes are committed
                self.conn.close()
                self.conn = None  # Mark as closed
            except Exception as e:
                print(f"⚠️ Error closing database connection: {e}")

    def __del__(self):
        """Destructor to ensure the connection is closed"""
        try:
            self.close()
        except Exception:
            pass  # Ignore errors during cleanup

    # Helper methods

    def _row_to_chunk(self, row) -> MemoryChunk:
        """Convert a database row to a MemoryChunk"""
        return MemoryChunk(
            id=row['id'],
            user_id=row['user_id'],
            scope=row['scope'],
            source=row['source'],
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            text=row['text'],
            embedding=json.loads(row['embedding']) if row['embedding'] else None,
            hash=row['hash'],
            metadata=json.loads(row['metadata']) if row['metadata'] else None
        )

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors"""
        if len(vec1) != len(vec2):
            return 0.0

        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = sum(a * a for a in vec1) ** 0.5
        norm2 = sum(b * b for b in vec2) ** 0.5

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    @staticmethod
    def _contains_cjk(text: str) -> bool:
        """Check if text contains CJK (Chinese/Japanese/Korean) characters"""
        return bool(re.search(r'[\u4e00-\u9fff]', text))

    @staticmethod
    def _build_fts_query(raw_query: str) -> Optional[str]:
        """
        Build an FTS5 query from raw text

        Works best for English and word-based languages.
        For CJK characters, LIKE search is used as a fallback.
        """
        # Extract words (primarily English words and numbers)
        tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
        if not tokens:
            return None

        # Quote tokens for exact matching
        quoted = [f'"{t}"' for t in tokens]
        # Use OR for more flexible matching
        return ' OR '.join(quoted)

    @staticmethod
    def _bm25_rank_to_score(rank: float) -> float:
        """
        Convert a BM25 rank to a 0-1 score.

        SQLite's bm25() returns lower (more negative) values for better
        matches, so negate before normalizing.
        """
        if rank is None:
            return 0.0
        relevance = max(0.0, -rank)
        return relevance / (1 + relevance)

    @staticmethod
    def compute_hash(content: str) -> str:
        """Compute SHA256 hash of content"""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()
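For intuition on the two scoring paths: _cosine_similarity is the usual dot(v1, v2) / (|v1| * |v2|), and compute_hash drives file change detection in the sync path. A quick check with toy vectors (illustrative only; real embeddings have hundreds of dimensions):

from agent.memory.storage import MemoryStorage

print(MemoryStorage._cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0, same direction
print(MemoryStorage._cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0, orthogonal (dropped by search_vector)
print(MemoryStorage._cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # ~0.707
print(MemoryStorage.compute_hash("hello") == MemoryStorage.compute_hash("hello"))  # True, stable across runs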
agent/memory/summarizer.py (new file, 236 lines)
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
|
||||
Triggers memory flush before context compaction (similar to clawdbot)
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Any
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations before context compaction
|
||||
|
||||
Similar to clawdbot's memory flush mechanism:
|
||||
- Triggers when context approaches token limit
|
||||
- Runs a silent agent turn to write memories to disk
|
||||
- Uses memory/YYYY-MM-DD.md for daily notes
|
||||
- Uses MEMORY.md (workspace root) for long-term curated memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Initialize memory flush manager
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
llm_model: LLM model for agent execution (optional)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int,
|
||||
context_window: int,
|
||||
reserve_tokens: int = 20000,
|
||||
soft_threshold: int = 4000
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
Similar to clawdbot's shouldRunMemoryFlush logic:
|
||||
threshold = contextWindow - reserveTokens - softThreshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
context_window: Model's context window size
|
||||
reserve_tokens: Reserve tokens for compaction overhead
|
||||
soft_threshold: Trigger flush N tokens before threshold
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
if current_tokens <= 0:
|
||||
return False
|
||||
|
||||
threshold = max(0, context_window - reserve_tokens - soft_threshold)
|
||||
if threshold <= 0:
|
||||
return False
|
||||
|
||||
# Check if we've crossed the threshold
|
||||
if current_tokens < threshold:
|
||||
return False
|
||||
|
||||
# Avoid duplicate flush in same compaction cycle
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + soft_threshold:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get today's memory file path: memory/YYYY-MM-DD.md
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / f"{today}.md"
|
||||
else:
|
||||
return self.memory_dir / f"{today}.md"
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get main memory file path: MEMORY.md (workspace root)
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to main memory file
|
||||
"""
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / "MEMORY.md"
|
||||
else:
|
||||
# Return workspace root MEMORY.md
|
||||
return Path(self.workspace_root) / "MEMORY.md"
|
||||
|
||||

    def create_flush_prompt(self) -> str:
        """
        Create prompt for memory flush turn

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
        """
        today = datetime.now().strftime("%Y-%m-%d")
        return (
            f"Pre-compaction memory flush. "
            f"Store durable memories now (use memory/{today}.md for daily notes; "
            f"create memory/ if needed). "
            f"If nothing to store, reply with NO_REPLY."
        )

    def create_flush_system_prompt(self) -> str:
        """
        Create system prompt for memory flush turn

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
        """
        return (
            "Pre-compaction memory flush turn. "
            "The session is near auto-compaction; capture durable memories to disk. "
            "You may reply, but usually NO_REPLY is correct."
        )

    async def execute_flush(
        self,
        agent_executor: Callable,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Execute memory flush by running a silent agent turn

        Args:
            agent_executor: Function to execute an agent turn with the given prompt
            current_tokens: Current token count
            user_id: Optional user ID
            **executor_kwargs: Additional kwargs for the agent executor

        Returns:
            True if flush completed successfully
        """
        try:
            # Create flush prompts
            prompt = self.create_flush_prompt()
            system_prompt = self.create_flush_system_prompt()

            # Execute agent turn (silent, no user-visible reply expected)
            await agent_executor(
                prompt=prompt,
                system_prompt=system_prompt,
                silent=True,  # NO_REPLY expected
                **executor_kwargs
            )

            # Track the flush
            self.last_flush_token_count = current_tokens
            self.last_flush_timestamp = datetime.now()

            return True

        except Exception as e:
            print(f"Memory flush failed: {e}")
            return False
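
    # Illustrative wiring (a sketch; `run_agent_turn` is a hypothetical executor
    # matching the keyword arguments execute_flush passes through):
    #   async def run_agent_turn(prompt, system_prompt, silent=False, **kwargs):
    #       ...  # run one agent turn that writes memories to disk
    #   if manager.should_flush(current_tokens, context_window):
    #       await manager.execute_flush(run_agent_turn, current_tokens)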

    def get_status(self) -> dict:
        """Get memory flush status"""
        return {
            'last_flush_tokens': self.last_flush_token_count,
            'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
            'today_file': str(self.get_today_memory_file()),
            'main_file': str(self.get_main_memory_file())
        }


def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """
    Create default memory files if they don't exist

    Args:
        workspace_dir: Workspace directory
        user_id: Optional user ID for user-specific files
    """
    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    # Create the main MEMORY.md (per-user, or in the workspace root)
    if user_id:
        user_dir = memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        main_memory = user_dir / "MEMORY.md"
    else:
        main_memory = workspace_dir / "MEMORY.md"

    if not main_memory.exists():
        # Create an empty file with no obvious "Memory" header.
        # Following clawdbot's approach: memories should blend naturally into context.
        main_memory.write_text("")

    # Create today's memory file
    today = datetime.now().strftime("%Y-%m-%d")
    if user_id:
        user_dir = memory_dir / "users" / user_id
        today_memory = user_dir / f"{today}.md"
    else:
        today_memory = memory_dir / f"{today}.md"

    if not today_memory.exists():
        today_memory.write_text(
            f"# Daily Memory: {today}\n\n"
            f"Day-to-day notes and running context.\n\n"
        )
13
agent/prompt/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""
Agent Prompt Module - system prompt construction
"""

from .builder import PromptBuilder, build_agent_system_prompt
from .workspace import ensure_workspace, load_context_files

__all__ = [
    'PromptBuilder',
    'build_agent_system_prompt',
    'ensure_workspace',
    'load_context_files',
]
445
agent/prompt/builder.py
Normal file
@@ -0,0 +1,445 @@
"""
System Prompt Builder

Modeled on clawdbot's system-prompt.ts; builds the Chinese-language system prompt from modular sections
"""

import os
from typing import List, Dict, Optional, Any
from dataclasses import dataclass

from common.log import logger


@dataclass
class ContextFile:
    """A context file injected into the prompt"""
    path: str
    content: str


class PromptBuilder:
    """System prompt builder"""

    def __init__(self, workspace_dir: str, language: str = "zh"):
        """
        Initialize the prompt builder

        Args:
            workspace_dir: Workspace directory
            language: Language ("zh" or "en")
        """
        self.workspace_dir = workspace_dir
        self.language = language

    def build(
        self,
        base_persona: Optional[str] = None,
        user_identity: Optional[Dict[str, str]] = None,
        tools: Optional[List[Any]] = None,
        context_files: Optional[List[ContextFile]] = None,
        skill_manager: Any = None,
        memory_manager: Any = None,
        runtime_info: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> str:
        """
        Build the complete system prompt

        Args:
            base_persona: Base persona description (overridden by SOUL.md in context_files)
            user_identity: User identity information
            tools: List of tools
            context_files: Context files (SOUL.md, USER.md, README.md, etc.)
            skill_manager: Skill manager
            memory_manager: Memory manager
            runtime_info: Runtime information
            **kwargs: Additional parameters

        Returns:
            The complete system prompt
        """
        return build_agent_system_prompt(
            workspace_dir=self.workspace_dir,
            language=self.language,
            base_persona=base_persona,
            user_identity=user_identity,
            tools=tools,
            context_files=context_files,
            skill_manager=skill_manager,
            memory_manager=memory_manager,
            runtime_info=runtime_info,
            **kwargs
        )


def build_agent_system_prompt(
    workspace_dir: str,
    language: str = "zh",
    base_persona: Optional[str] = None,
    user_identity: Optional[Dict[str, str]] = None,
    tools: Optional[List[Any]] = None,
    context_files: Optional[List[ContextFile]] = None,
    skill_manager: Any = None,
    memory_manager: Any = None,
    runtime_info: Optional[Dict[str, Any]] = None,
    **kwargs
) -> str:
    """
    Build the agent system prompt (compact version, Chinese)

    Included sections:
    1. Base identity
    2. Tooling
    3. Skills system
    4. Memory system
    5. User identity
    6. Workspace
    7. Project context files
    8. Runtime information

    Args:
        workspace_dir: Workspace directory
        language: Language ("zh" or "en")
        base_persona: Base persona description
        user_identity: User identity information
        tools: List of tools
        context_files: List of context files
        skill_manager: Skill manager
        memory_manager: Memory manager
        runtime_info: Runtime information
        **kwargs: Additional parameters

    Returns:
        The complete system prompt
    """
    sections = []

    # 1. Base identity
    sections.extend(_build_identity_section(base_persona, language))

    # 2. Tooling
    if tools:
        sections.extend(_build_tooling_section(tools, language))

    # 3. Skills system
    if skill_manager:
        sections.extend(_build_skills_section(skill_manager, tools, language))

    # 4. Memory system
    if memory_manager:
        sections.extend(_build_memory_section(memory_manager, tools, language))

    # 5. User identity
    if user_identity:
        sections.extend(_build_user_identity_section(user_identity, language))

    # 6. Workspace
    sections.extend(_build_workspace_section(workspace_dir, language))

    # 7. Project context files (SOUL.md, USER.md, etc.)
    if context_files:
        sections.extend(_build_context_files_section(context_files, language))

    # 8. Runtime information (if any)
    if runtime_info:
        sections.extend(_build_runtime_section(runtime_info, language))

    return "\n".join(sections)
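
# Illustrative usage (a sketch):
#   builder = PromptBuilder(workspace_dir="/path/to/workspace", language="zh")
#   prompt = builder.build(
#       tools=agent.tools,
#       context_files=load_context_files("/path/to/workspace"),
#       runtime_info={"model": "claude-sonnet", "channel": "wechat"},
#   )
# Sections whose inputs are absent (e.g. no tools, no memory_manager) are
# simply omitted from the returned prompt.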


def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
    """Build the base identity section - no longer needed; identity is defined by SOUL.md"""
    # No base identity section is generated; identity comes entirely from SOUL.md
    return []


def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
    """Build the tooling section"""
    lines = [
        "## 工具系统",
        "",
        "你可以使用以下工具来完成任务。工具名称是大小写敏感的,请严格按照列表中的名称调用。",
        "",
        "### 可用工具",
        "",
    ]

    # Tool categories and ordering
    tool_categories = {
        "文件操作": ["read", "write", "edit", "ls", "grep", "find"],
        "命令执行": ["bash", "terminal"],
        "网络搜索": ["web_search", "web_fetch", "browser"],
        "记忆系统": ["memory_search", "memory_get"],
        "其他": []
    }

    # Build the tool name -> description map
    tool_map = {}
    tool_descriptions = {
        "read": "读取文件内容",
        "write": "创建或覆盖文件",
        "edit": "精确编辑文件内容",
        "ls": "列出目录内容",
        "grep": "在文件中搜索内容",
        "find": "按照模式查找文件",
        "bash": "执行shell命令",
        "terminal": "管理后台进程",
        "web_search": "网络搜索(使用搜索引擎)",
        "web_fetch": "获取URL内容",
        "browser": "控制浏览器",
        "memory_search": "搜索记忆文件",
        "memory_get": "获取记忆文件内容",
        "calculator": "计算器",
        "current_time": "获取当前时间",
    }

    for tool in tools:
        tool_name = tool.name if hasattr(tool, 'name') else str(tool)
        tool_desc = tool.description if hasattr(tool, 'description') else tool_descriptions.get(tool_name, "")
        tool_map[tool_name] = tool_desc

    # Add tools by category
    for category, tool_names in tool_categories.items():
        category_tools = [(name, tool_map.get(name, "")) for name in tool_names if name in tool_map]
        if category_tools:
            lines.append(f"**{category}**:")
            for name, desc in category_tools:
                if desc:
                    lines.append(f"- `{name}`: {desc}")
                else:
                    lines.append(f"- `{name}`")
                del tool_map[name]  # Remove tools that have already been listed
            lines.append("")

    # Add any remaining uncategorized tools
    if tool_map:
        lines.append("**其他工具**:")
        for name, desc in sorted(tool_map.items()):
            if desc:
                lines.append(f"- `{name}`: {desc}")
            else:
                lines.append(f"- `{name}`")
        lines.append("")

    # Tool-usage guidance
    lines.extend([
        "### 工具调用风格",
        "",
        "**默认规则**: 对于常规、低风险的工具调用,无需叙述,直接调用即可。",
        "",
        "**需要叙述的情况**:",
        "- 多步骤、复杂的任务",
        "- 敏感操作(如删除文件)",
        "- 用户明确要求解释过程",
        "",
        "**完成后**: 工具调用完成后,给用户一个简短、自然的确认或回复,不要直接结束对话。",
        "",
    ])

    return lines
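
# Example rendered fragment (illustrative, for tools named "read" and "bash"):
#   **文件操作**:
#   - `read`: 读取文件内容
#
#   **命令执行**:
#   - `bash`: 执行shell命令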


def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
    """Build the skills-system section"""
    if not skill_manager:
        return []

    # Resolve the name of the read tool
    read_tool_name = "read"
    if tools:
        for tool in tools:
            tool_name = tool.name if hasattr(tool, 'name') else str(tool)
            if tool_name.lower() == "read":
                read_tool_name = tool_name
                break

    lines = [
        "## 技能系统",
        "",
        "在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
        "",
        f"- 如果恰好有一个技能明确适用:使用 `{read_tool_name}` 工具读取其 <location> 路径下的 SKILL.md 文件,然后遵循它。",
        "- 如果多个技能都适用:选择最具体的一个,然后读取并遵循。",
        "- 如果没有明确适用的:不要读取任何 SKILL.md。",
        "",
        "**约束**: 永远不要一次性读取多个技能;只在选择后再读取。",
        "",
    ]

    # Append the skill list (provided by the skill_manager)
    try:
        skills_prompt = skill_manager.build_skills_prompt()
        if skills_prompt:
            lines.append(skills_prompt.strip())
            lines.append("")
    except Exception as e:
        logger.warning(f"Failed to build skills prompt: {e}")

    return lines


def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
    """Build the memory-system section"""
    if not memory_manager:
        return []

    # Check whether memory tools are available
    has_memory_tools = False
    if tools:
        tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
        has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)

    if not has_memory_tools:
        return []

    lines = [
        "## 记忆系统",
        "",
        "在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
        "",
        "1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
        "2. 然后使用 `memory_get` 只拉取需要的行",
        "3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
        "",
        "**记忆文件结构**:",
        "- `MEMORY.md`: 长期记忆,包含重要的背景信息",
        "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
        "",
        "**使用原则**:",
        "- 自然使用记忆,就像你本来就知道",
        "- 不要主动提起或列举记忆,除非用户明确询问",
        "",
    ]

    return lines


def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
    """Build the user-identity section"""
    if not user_identity:
        return []

    lines = [
        "## 用户身份",
        "",
    ]

    if user_identity.get("name"):
        lines.append(f"**用户姓名**: {user_identity['name']}")
    if user_identity.get("nickname"):
        lines.append(f"**称呼**: {user_identity['nickname']}")
    if user_identity.get("timezone"):
        lines.append(f"**时区**: {user_identity['timezone']}")
    if user_identity.get("notes"):
        lines.append(f"**备注**: {user_identity['notes']}")

    lines.append("")

    return lines


def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
    """Build the docs-path section - removed, no longer needed"""
    # The docs section is no longer generated
    return []


def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
    """Build the workspace section"""
    lines = [
        "## 工作空间",
        "",
        f"你的工作目录是: `{workspace_dir}`",
        "",
        "除非用户明确指示,否则将此目录视为文件操作的全局工作空间。",
        "",
        "**重要说明 - 文件已自动加载**:",
        "",
        "以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
        "",
        "- ✅ `SOUL.md`: 已加载 - Agent的人格设定",
        "- ✅ `USER.md`: 已加载 - 用户的身份信息",
        "- ✅ `AGENTS.md`: 已加载 - 工作空间使用指南",
        "",
        "**首次对话**:",
        "",
        "如果这是你与用户的首次对话,并且你的人格设定和用户信息还是空白或初始状态:",
        "",
        "1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
        "2. **简短打招呼后,分点询问三个核心问题**:",
        " - 你希望我叫什么名字?",
        " - 你希望我怎么称呼你?",
        " - 你希望我们是什么样的交流风格?(这里需要举例,如:专业严谨、轻松幽默、温暖友好等)",
        "3. **语言风格**:温暖但不过度诗意,带点科技感,保持清晰",
        "4. **问题格式**:用分点或换行,让问题清晰易读;前两个问题不需要额外说明,只有交流风格需要举例",
        "5. 收到回复后,用 `write` 工具保存到 USER.md 和 SOUL.md",
        "",
        "**重要**: ",
        "- 在所有对话中,无需提及技术细节(如 SOUL.md、USER.md 等文件名,工具名称,配置等),除非用户明确询问。用自然表达如「我已记住」而非「已更新 SOUL.md」",
        "- 不要问太多其他信息(职业、时区等可以后续自然了解)",
        "- 保持简洁,避免过度抒情",
        "",
    ]

    return lines


def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
    """Build the project-context-files section"""
    if not context_files:
        return []

    # Check whether SOUL.md is among the loaded files
    has_soul = any(
        f.path.lower().endswith('soul.md') or 'soul.md' in f.path.lower()
        for f in context_files
    )

    lines = [
        "# 项目上下文",
        "",
        "以下项目上下文文件已被加载:",
        "",
    ]

    if has_soul:
        lines.append("如果存在 `SOUL.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
        lines.append("")

    # Append each file's content
    for file in context_files:
        lines.append(f"## {file.path}")
        lines.append("")
        lines.append(file.content)
        lines.append("")

    return lines


def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
    """Build the runtime-information section"""
    if not runtime_info:
        return []

    # Only include the section if there is actual runtime info to display
    runtime_parts = []
    if runtime_info.get("model"):
        runtime_parts.append(f"模型={runtime_info['model']}")
    if runtime_info.get("workspace"):
        runtime_parts.append(f"工作空间={runtime_info['workspace']}")
    # Only add the channel if it is not the default "web"
    if runtime_info.get("channel") and runtime_info.get("channel") != "web":
        runtime_parts.append(f"渠道={runtime_info['channel']}")

    if not runtime_parts:
        return []

    lines = [
        "## 运行时信息",
        "",
        "运行时: " + " | ".join(runtime_parts),
        ""
    ]

    return lines
314
agent/prompt/workspace.py
Normal file
@@ -0,0 +1,314 @@
"""
Workspace Management

Initializes the workspace, creates template files, and loads context files
"""

import os
from typing import List, Optional, Dict
from dataclasses import dataclass

from common.log import logger
from .builder import ContextFile


# Default filename constants
DEFAULT_SOUL_FILENAME = "SOUL.md"
DEFAULT_USER_FILENAME = "USER.md"
DEFAULT_AGENTS_FILENAME = "AGENTS.md"
DEFAULT_MEMORY_FILENAME = "MEMORY.md"


@dataclass
class WorkspaceFiles:
    """Workspace file paths"""
    soul_path: str
    user_path: str
    agents_path: str
    memory_path: str
    memory_dir: str


def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
    """
    Ensure the workspace exists and create the necessary template files

    Args:
        workspace_dir: Workspace directory path
        create_templates: Whether to create template files (on first run)

    Returns:
        A WorkspaceFiles object containing all file paths
    """
    # Ensure the directory exists
    os.makedirs(workspace_dir, exist_ok=True)

    # Define file paths
    soul_path = os.path.join(workspace_dir, DEFAULT_SOUL_FILENAME)
    user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
    agents_path = os.path.join(workspace_dir, DEFAULT_AGENTS_FILENAME)
    memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME)  # MEMORY.md lives in the root
    memory_dir = os.path.join(workspace_dir, "memory")  # Daily-memory subdirectory

    # Create the memory subdirectory
    os.makedirs(memory_dir, exist_ok=True)

    # Create template files if requested
    if create_templates:
        _create_template_if_missing(soul_path, _get_soul_template())
        _create_template_if_missing(user_path, _get_user_template())
        _create_template_if_missing(agents_path, _get_agents_template())
        _create_template_if_missing(memory_path, _get_memory_template())

    logger.info(f"[Workspace] Initialized workspace at: {workspace_dir}")

    return WorkspaceFiles(
        soul_path=soul_path,
        user_path=user_path,
        agents_path=agents_path,
        memory_path=memory_path,
        memory_dir=memory_dir
    )
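
# Illustrative usage (a sketch):
#   files = ensure_workspace("/path/to/workspace")
#   context = load_context_files("/path/to/workspace")  # SOUL.md, USER.md, AGENTS.md
# On first run this creates the four template files; on later runs it only
# loads files that have real (non-placeholder) content.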


def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
    """
    Load the workspace's context files

    Args:
        workspace_dir: Workspace directory
        files_to_load: Files to load (relative paths); if None, load all standard files

    Returns:
        List of ContextFile objects
    """
    if files_to_load is None:
        # Default files to load (in priority order)
        files_to_load = [
            DEFAULT_SOUL_FILENAME,
            DEFAULT_USER_FILENAME,
            DEFAULT_AGENTS_FILENAME,
        ]

    context_files = []

    for filename in files_to_load:
        filepath = os.path.join(workspace_dir, filename)

        if not os.path.exists(filepath):
            continue

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Skip files that are empty or contain only template placeholders
            if not content or _is_template_placeholder(content):
                continue

            context_files.append(ContextFile(
                path=filename,
                content=content
            ))

            logger.debug(f"[Workspace] Loaded context file: {filename}")

        except Exception as e:
            logger.warning(f"[Workspace] Failed to load {filename}: {e}")

    return context_files


def _create_template_if_missing(filepath: str, template_content: str):
    """Create a template file if it does not already exist"""
    if not os.path.exists(filepath):
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(template_content)
            logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
        except Exception as e:
            logger.error(f"[Workspace] Failed to create template {filepath}: {e}")


def _is_template_placeholder(content: str) -> bool:
    """Check whether the content is still just the template placeholder"""
    # Common placeholder patterns
    placeholders = [
        "*(填写",
        "*(在首次对话时填写",
        "*(可选)",
        "*(根据需要添加",
    ]

    lines = content.split('\n')
    non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]

    # No real content (only headings and placeholders)
    if len(non_empty_lines) <= 3:
        for placeholder in placeholders:
            if any(placeholder in line for line in non_empty_lines):
                return True

    return False
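
# Example (illustrative): a freshly created SOUL.md whose only non-heading lines
# still contain "*(在首次对话时填写" is treated as placeholder-only and skipped by
# load_context_files; once the user fills in a real name, it is loaded as context.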


# ============= Template content =============

def _get_soul_template() -> str:
    """Agent persona template (SOUL.md)"""
    return """# SOUL.md - 我是谁?

*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*

## 基本信息

- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
- **角色**: *(AI助理、智能管家、技术顾问等)*
- **性格**: *(友好、专业、幽默、严谨等)*

## 交流风格

*(描述你如何与用户交流:)*
- 使用什么样的语言风格?(正式/轻松/幽默)
- 回复长度偏好?(简洁/详细)
- 是否使用表情符号?

## 核心能力

*(你擅长什么?)*
- 文件管理和代码编辑
- 网络搜索和信息查询
- 记忆管理和上下文理解
- 任务规划和执行

## 行为准则

*(你遵循的基本原则:)*
1. 始终在执行破坏性操作前确认
2. 优先使用工具而不是猜测
3. 主动记录重要信息到记忆文件
4. 定期整理和总结对话内容

---

**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
"""


def _get_user_template() -> str:
    """User identity template (USER.md)"""
    return """# USER.md - 用户基本信息

*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*

## 基本信息

- **姓名**: *(在首次对话时询问)*
- **称呼**: *(用户希望被如何称呼)*
- **职业**: *(可选)*
- **时区**: *(例如: Asia/Shanghai)*

## 联系方式

- **微信**:
- **邮箱**:
- **其他**:

## 重要日期

- **生日**:
- **纪念日**:

---

**注意**: 这个文件存放静态的身份信息
"""


def _get_agents_template() -> str:
    """Workspace guide template (AGENTS.md)"""
    return """# AGENTS.md - 工作空间指南

这个文件夹是你的家。好好对待它。

## 系统自动加载

以下文件在每次会话启动时**已经自动加载**到系统提示词中,你无需再次读取:

- ✅ `SOUL.md` - 你的人格设定(已加载)
- ✅ `USER.md` - 用户信息(已加载)
- ✅ `AGENTS.md` - 本文件(已加载)

## 按需读取

以下文件**不会自动加载**,需要时使用相应工具读取:

- 📝 `memory/YYYY-MM-DD.md` - 每日记忆(用 memory_search 检索)
- 🧠 `MEMORY.md` - 长期记忆(用 memory_search 检索)

## 记忆系统

你每次会话都是全新的。这些文件是你的连续性:

### 📝 每日记忆:`memory/YYYY-MM-DD.md`
- 原始的对话日志
- 记录当天发生的事情
- 如果 `memory/` 目录不存在,创建它

### 🧠 长期记忆:`MEMORY.md`
- 你精选的记忆,就像人类的长期记忆
- **仅在主会话中加载**(与用户的直接聊天)
- **不要在共享上下文中加载**(群聊、与其他人的会话)
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
- 记录重要事件、想法、决定、观点、经验教训
- 这是你精选的记忆 - 精华,而不是原始日志
- 用 `edit` 工具追加新的记忆内容

### 📝 写下来 - 不要"记在心里"!
- **记忆是有限的** - 如果你想记住某事,写入文件
- "记在心里"不会在会话重启后保留,文件才会
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
- 当你学到教训 → 更新 AGENTS.md 或相关技能
- 当你犯错 → 记录下来,这样未来的你不会重复
- **文字 > 大脑** 📝

### 存储规则

当用户分享信息时,根据类型选择存储位置:

1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)

**重要**:
- 爱好(唱歌、篮球等)→ MEMORY.md,不是 USER.md
- 近期计划(下周要做什么)→ MEMORY.md,不是 USER.md
- USER.md 只存放不会变的基本信息

## 安全

- 永远不要泄露私人数据
- 不要在未经询问的情况下运行破坏性命令
- 当有疑问时,先问

## 工具使用

技能提供你的工具。当你需要一个时,查看它的 `SKILL.md`。

## 让它成为你的

这只是一个起点。随着你弄清楚什么有效,添加你自己的约定、风格和规则。
"""


def _get_memory_template() -> str:
    """Long-term memory template (MEMORY.md) - a minimal skeleton the agent fills in itself"""
    return """# MEMORY.md - 长期记忆

*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*

---

"""
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
from .agent import Agent
from .agent_stream import AgentStreamExecutor
from .task import Task, TaskType, TaskStatus
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
from .models import LLMModel, LLMRequest, ModelFactory

__all__ = [
    'Agent',
    'AgentStreamExecutor',
    'Task',
    'TaskType',
    'TaskStatus',
    'AgentResult',
    'AgentAction',
    'AgentActionType',
    'ToolResult',
    'LLMModel',
    'LLMRequest',
    'ModelFactory'
]
371
agent/protocol/agent.py
Normal file
@@ -0,0 +1,371 @@
import json
import time

from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.agent_stream import AgentStreamExecutor
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
from agent.tools.base_tool import BaseTool, ToolStage


class Agent:
    def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
                 tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
                 context_reserve_tokens=None, memory_manager=None, name: str = None,
                 workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
        """
        Initialize the Agent with a system prompt, model, and description.

        :param system_prompt: The system prompt for the agent.
        :param description: A description of the agent.
        :param model: An instance of LLMModel to be used by the agent.
        :param tools: Optional list of tools for the agent to use.
        :param output_mode: Control how execution progress is displayed:
                            "print" for console output or "logger" for using the logger
        :param max_steps: Maximum number of steps the agent can take (default: 100)
        :param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
        :param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
        :param memory_manager: Optional MemoryManager instance for memory operations
        :param name: [Deprecated] The name of the agent (no longer used in the single-agent system)
        :param workspace_dir: Optional workspace directory for workspace-specific skills
        :param skill_manager: Optional SkillManager instance (created automatically if None and enable_skills=True)
        :param enable_skills: Whether to enable skills support (default: True)
        """
        self.name = name or "Agent"
        self.system_prompt = system_prompt
        self.model: LLMModel = model  # Instance of LLMModel
        self.description = description
        self.tools: list = []
        self.max_steps = max_steps  # Max tool-call steps, default 100
        self.max_context_tokens = max_context_tokens  # Max tokens in context
        self.context_reserve_tokens = context_reserve_tokens  # Reserve tokens for new requests
        self.captured_actions = []  # Initialize captured actions list
        self.output_mode = output_mode
        self.last_usage = None  # Store last API response usage info
        self.messages = []  # Unified message history for stream mode
        self.memory_manager = memory_manager  # Memory manager for auto memory flush
        self.workspace_dir = workspace_dir  # Workspace directory
        self.enable_skills = enable_skills  # Skills enabled flag

        # Initialize the skill manager
        self.skill_manager = None
        if enable_skills:
            if skill_manager:
                self.skill_manager = skill_manager
            else:
                # Auto-create the skill manager
                try:
                    from agent.skills import SkillManager
                    self.skill_manager = SkillManager(workspace_dir=workspace_dir)
                    logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
                except Exception as e:
                    logger.warning(f"Failed to initialize SkillManager: {e}")

        if tools:
            for tool in tools:
                self.add_tool(tool)

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent.

        :param tool: The tool instance to add
        """
        # The tool is already an instance; bind the model and register it
        tool.model = self.model
        self.tools.append(tool)

    def get_skills_prompt(self, skill_filter=None) -> str:
        """
        Get the skills prompt to append to the system prompt.

        :param skill_filter: Optional list of skill names to include
        :return: Formatted skills prompt or empty string
        """
        if not self.skill_manager:
            return ""

        try:
            return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
        except Exception as e:
            logger.warning(f"Failed to build skills prompt: {e}")
            return ""

    def get_full_system_prompt(self, skill_filter=None) -> str:
        """
        Get the full system prompt including skills.

        :param skill_filter: Optional list of skill names to include
        :return: Complete system prompt with skills appended
        """
        base_prompt = self.system_prompt
        skills_prompt = self.get_skills_prompt(skill_filter=skill_filter)

        if skills_prompt:
            return base_prompt + "\n" + skills_prompt
        return base_prompt

    def refresh_skills(self):
        """Refresh the loaded skills."""
        if self.skill_manager:
            self.skill_manager.refresh_skills()
            logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")

    def list_skills(self):
        """
        List all loaded skills.

        :return: List of skill entries or empty list
        """
        if not self.skill_manager:
            return []
        return self.skill_manager.list_skills()

    def _get_model_context_window(self) -> int:
        """
        Get the model's context window size in tokens.
        Auto-detected based on the model name.

        Model context windows:
        - Claude 3.5/3.7 Sonnet: 200K tokens
        - Claude 3 Opus: 200K tokens
        - GPT-4 Turbo/128K: 128K tokens
        - GPT-4: 8K-32K tokens
        - GPT-3.5: 16K tokens
        - DeepSeek: 64K tokens

        :return: Context window size in tokens
        """
        if self.model and hasattr(self.model, 'model'):
            model_name = self.model.model.lower()

            # Claude models - 200K context
            if 'claude-3' in model_name or 'claude-sonnet' in model_name:
                return 200000

            # GPT-4 models
            elif 'gpt-4' in model_name:
                if 'turbo' in model_name or '128k' in model_name:
                    return 128000
                elif '32k' in model_name:
                    return 32000
                else:
                    return 8000

            # GPT-3.5
            elif 'gpt-3.5' in model_name:
                if '16k' in model_name:
                    return 16000
                else:
                    return 4000

            # DeepSeek
            elif 'deepseek' in model_name:
                return 64000

            # Gemini models
            elif 'gemini' in model_name:
                if '2.0' in model_name or 'exp' in model_name:
                    return 2000000  # Gemini 2.0: 2M tokens
                else:
                    return 1000000  # Gemini 1.5: 1M tokens

        # Default conservative value
        return 128000

    def _get_context_reserve_tokens(self) -> int:
        """
        Get the number of tokens to reserve for new requests.
        This prevents context overflow by keeping a buffer.

        :return: Number of tokens to reserve
        """
        if self.context_reserve_tokens is not None:
            return self.context_reserve_tokens

        # Reserve ~10% of the context window, with a 10K floor and a 200K cap
        context_window = self._get_model_context_window()
        reserve = int(context_window * 0.1)
        return max(10000, min(200000, reserve))
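
    # Example (illustrative): a 200K Claude window reserves 20K tokens (10%);
    # a 4K GPT-3.5 window computes 400 but is clamped up to the 10K floor, which
    # exceeds the window itself; pass context_reserve_tokens explicitly for
    # small-context models.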

    def _estimate_message_tokens(self, message: dict) -> int:
        """
        Estimate the token count for a message using a chars/4 heuristic.
        This is a conservative estimate (it tends to overestimate).

        :param message: Message dict with 'role' and 'content'
        :return: Estimated token count
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return max(1, len(content) // 4)
        elif isinstance(content, list):
            # Handle multi-part content (text + images)
            total_chars = 0
            for part in content:
                if isinstance(part, dict) and part.get('type') == 'text':
                    total_chars += len(part.get('text', ''))
                elif isinstance(part, dict) and part.get('type') == 'image':
                    # Estimate each image as ~1200 tokens (4800 chars / 4)
                    total_chars += 4800
            return max(1, total_chars // 4)
        return 1
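
    # Example (illustrative): {"role": "user", "content": "hello world"} is
    # 11 chars -> max(1, 11 // 4) = 2 tokens; a content list with one 400-char
    # text block and one image block estimates (400 + 4800) // 4 = 1300 tokens.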

    def _find_tool(self, tool_name: str):
        """Find and return the tool with the specified name"""
        for tool in self.tools:
            if tool.name == tool_name:
                # Only pre-process stage tools can be actively called
                if tool.stage == ToolStage.PRE_PROCESS:
                    tool.model = self.model
                    tool.context = self  # Set tool context
                    return tool
                else:
                    # Post-process tools cannot be called directly
                    logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
                    return None
        return None

    # Output function based on mode
    def output(self, message="", end="\n"):
        if self.output_mode == "print":
            print(message, end=end)
        elif message:
            logger.info(message)

    def _execute_post_process_tools(self):
        """Execute all post-process stage tools"""
        # Get all post-process stage tools
        post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]

        # Execute each tool
        for tool in post_process_tools:
            # Set tool context
            tool.context = self

            # Record start time for execution timing
            start_time = time.time()

            # Execute the tool (with empty parameters; the tool extracts what it needs from context)
            result = tool.execute({})

            # Calculate execution time
            execution_time = time.time() - start_time

            # Capture the tool use for tracking
            self.capture_tool_use(
                tool_name=tool.name,
                input_params={},  # Post-process tools typically don't take parameters
                output=result.result,
                status=result.status,
                error_message=str(result.result) if result.status == "error" else None,
                execution_time=execution_time
            )

            # Log the result
            if result.status == "success":
                # Print the tool execution result in the desired format
                self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
            else:
                # Print the failure
                self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")

    def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
                         execution_time=0.0):
        """
        Capture a tool use action.

        :param tool_name: Name of the tool used
        :param input_params: Parameters passed to the tool
        :param output: Output from the tool
        :param status: Status of the tool execution
        :param thought: Thought content, if any
        :param error_message: Error message if the tool execution failed
        :param execution_time: Time taken to execute the tool
        """
        tool_result = ToolResult(
            tool_name=tool_name,
            input_params=input_params,
            output=output,
            status=status,
            error_message=error_message,
            execution_time=execution_time
        )

        action = AgentAction(
            agent_id=self.id if hasattr(self, 'id') else str(id(self)),
            agent_name=self.name,
            action_type=AgentActionType.TOOL_USE,
            tool_result=tool_result,
            thought=thought
        )

        self.captured_actions.append(action)

        return action

    def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
        """
        Execute a single agent task with streaming (based on tool calls)

        This method supports:
        - Streaming output
        - Multi-turn reasoning based on tool calls
        - Event callbacks
        - Persistent conversation history across calls

        Args:
            user_message: User message
            on_event: Event callback function callback(event: dict)
                      event = {"type": str, "timestamp": float, "data": dict}
            clear_history: If True, clear conversation history before this call (default: False)
            skill_filter: Optional list of skill names to include in this run

        Returns:
            Final response text

        Example:
            # Multi-turn conversation with memory
            response1 = agent.run_stream("My name is Alice")
            response2 = agent.run_stream("What's my name?")  # Will remember Alice

            # Single-turn without memory
            response = agent.run_stream("Hello", clear_history=True)
        """
        # Clear history if requested
        if clear_history:
            self.messages = []

        # Ensure a model is available
        if not self.model:
            raise ValueError("No model available for agent")

        # Get the full system prompt with skills
        full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)

        # Create a stream executor with the agent's message history
        executor = AgentStreamExecutor(
            agent=self,
            model=self.model,
            system_prompt=full_system_prompt,
            tools=self.tools,
            max_turns=self.max_steps,
            on_event=on_event,
            messages=self.messages  # Pass the agent's message history
        )

        # Execute
        response = executor.run_stream(user_message)

        # Update the agent's message history from the executor
        self.messages = executor.messages

        # Execute all post-process tools
        self._execute_post_process_tools()

        return response

    def clear_history(self):
        """Clear conversation history and captured actions"""
        self.messages = []
        self.captured_actions = []
478
agent/protocol/agent_stream.py
Normal file
@@ -0,0 +1,478 @@
"""
Agent Stream Execution Module - multi-turn reasoning based on tool calls

Provides streaming output, an event system, and the complete tool-call loop
"""
import json
import time
from typing import List, Dict, Any, Optional, Callable

from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult


class AgentStreamExecutor:
    """
    Agent Stream Executor

    Handles the multi-turn reasoning loop based on tool calls:
    1. The LLM generates a response (which may include tool calls)
    2. Execute the tools
    3. Return the results to the LLM
    4. Repeat until there are no more tool calls
    """

    def __init__(
        self,
        agent,  # Agent instance
        model: LLMModel,
        system_prompt: str,
        tools: List[BaseTool],
        max_turns: int = 50,
        on_event: Optional[Callable] = None,
        messages: Optional[List[Dict]] = None
    ):
        """
        Initialize the stream executor

        Args:
            agent: Agent instance (for accessing context)
            model: LLM model
            system_prompt: System prompt
            tools: List of available tools
            max_turns: Maximum number of turns
            on_event: Event callback function
            messages: Optional existing message history (for persistent conversations)
        """
        self.agent = agent
        self.model = model
        self.system_prompt = system_prompt
        # Convert the tools list to a dict keyed by tool name
        self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
        self.max_turns = max_turns
        self.on_event = on_event

        # Message history - use the provided messages or create a new list
        self.messages = messages if messages is not None else []

    def _emit_event(self, event_type: str, data: dict = None):
        """Emit an event to the registered callback"""
        if self.on_event:
            try:
                self.on_event({
                    "type": event_type,
                    "timestamp": time.time(),
                    "data": data or {}
                })
            except Exception as e:
                logger.error(f"Event callback error: {e}")
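
    # Illustrative callback (a sketch): print streamed text deltas as they arrive.
    #   def on_event(event):
    #       if event["type"] == "message_update":
    #           print(event["data"]["delta"], end="", flush=True)
    #   executor = AgentStreamExecutor(agent, model, prompt, tools, on_event=on_event)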

    def run_stream(self, user_message: str) -> str:
        """
        Execute the streaming reasoning loop

        Args:
            user_message: User message

        Returns:
            Final response text
        """
        # Log the user message
        logger.info(f"\n{'='*50}")
        logger.info(f"👤 User: {user_message}")
        logger.info(f"{'='*50}")

        # Add the user message (Claude format - use content blocks for consistency)
        self.messages.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": user_message
                }
            ]
        })

        self._emit_event("agent_start")

        final_response = ""
        turn = 0

        try:
            while turn < self.max_turns:
                turn += 1
                logger.info(f"\n{'='*50} Turn {turn} {'='*50}")
                self._emit_event("turn_start", {"turn": turn})

                # Check whether a memory flush is needed (before calling the LLM)
                if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
                    usage = self.agent.last_usage
                    if usage and 'input_tokens' in usage:
                        current_tokens = usage.get('input_tokens', 0)
                        context_window = self.agent._get_model_context_window()
                        # Use the configured reserve_tokens or calculate from the context window
                        reserve_tokens = self.agent._get_context_reserve_tokens()
                        # Use a larger soft_threshold so the flush triggers earlier
                        soft_threshold = 10000  # Trigger 10K tokens before the limit

                        if self.agent.memory_manager.should_flush_memory(
                            current_tokens=current_tokens,
                            context_window=context_window,
                            reserve_tokens=reserve_tokens,
                            soft_threshold=soft_threshold
                        ):
                            self._emit_event("memory_flush_start", {
                                "current_tokens": current_tokens,
                                "threshold": context_window - reserve_tokens - soft_threshold
                            })

                            # TODO: Execute the memory flush in the background
                            # (this would require async support)
                            logger.info(f"Memory flush recommended at {current_tokens} tokens")

                # Call the LLM
                assistant_msg, tool_calls = self._call_llm_stream()
                final_response = assistant_msg

                # No tool calls - end the loop
                if not tool_calls:
                    if assistant_msg:
                        logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                    logger.info(f"✅ Done (no tool calls)")
                    self._emit_event("turn_end", {
                        "turn": turn,
                        "has_tool_calls": False
                    })
                    break

                # Log the tool calls in a compact format
                tool_names = [tc['name'] for tc in tool_calls]
                logger.info(f"🔧 Calling tools: {', '.join(tool_names)}")

                # Execute the tools
                tool_results = []
                tool_result_blocks = []

                for tool_call in tool_calls:
                    result = self._execute_tool(tool_call)
                    tool_results.append(result)

                    # Log the tool result in a compact format
                    status_emoji = "✅" if result.get("status") == "success" else "❌"
                    result_data = result.get('result', '')
                    # Format the result string with proper Chinese character support
                    if isinstance(result_data, (dict, list)):
                        result_str = json.dumps(result_data, ensure_ascii=False)
                    else:
                        result_str = str(result_data)
                    logger.info(f"  {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

                    # Build the tool result block (Claude format)
                    # The content should be a string representation of the result
                    result_content = json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
                    tool_result_blocks.append({
                        "type": "tool_result",
                        "tool_use_id": tool_call["id"],
                        "content": result_content
                    })

                # Add the tool results to the message history as a user message (Claude format)
                self.messages.append({
                    "role": "user",
                    "content": tool_result_blocks
                })

                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": True,
                    "tool_count": len(tool_calls)
                })

            if turn >= self.max_turns:
                logger.warning(f"⚠️ Reached the maximum turn limit: {self.max_turns}")

        except Exception as e:
            logger.error(f"❌ Agent execution error: {e}")
            self._emit_event("error", {"error": str(e)})
            raise

        finally:
            logger.info(f"{'='*50} Done ({turn} turns) {'='*50}\n")
            self._emit_event("agent_end", {"final_response": final_response})

        return final_response

    def _call_llm_stream(self) -> tuple[str, List[Dict]]:
        """
        Call the LLM with streaming

        Returns:
            (response_text, tool_calls)
        """
        # Trim messages if needed (using the agent's context management)
        self._trim_messages()

        # Prepare messages
        messages = self._prepare_messages()

        # Debug: log the message structure
        logger.debug(f"Sending {len(messages)} messages to LLM")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                content_types = [c.get("type") for c in content if isinstance(c, dict)]
                logger.debug(f"  Message {i}: role={role}, content_blocks={content_types}")
            else:
                logger.debug(f"  Message {i}: role={role}, content_length={len(str(content))}")

        # Prepare tool definitions (OpenAI/Claude format)
        tools_schema = None
        if self.tools:
            tools_schema = []
            for tool in self.tools.values():
                tools_schema.append({
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.params  # Claude uses input_schema
                })

        # Create the request
        request = LLMRequest(
            messages=messages,
            temperature=0,
            stream=True,
            tools=tools_schema,
            system=self.system_prompt  # Pass the system prompt separately for the Claude API
        )

        self._emit_event("message_start", {"role": "assistant"})

        # Streaming response
        full_content = ""
        tool_calls_buffer = {}  # {index: {id, name, arguments}}

        try:
            stream = self.model.call_stream(request)

            for chunk in stream:
                # Check for errors
                if isinstance(chunk, dict) and chunk.get("error"):
                    error_msg = chunk.get("message", "Unknown error")
                    status_code = chunk.get("status_code", "N/A")
                    logger.error(f"API Error: {error_msg} (Status: {status_code})")
                    logger.error(f"Full error chunk: {chunk}")
                    raise Exception(f"{error_msg} (Status: {status_code})")

                # Parse the chunk
                if isinstance(chunk, dict) and "choices" in chunk:
                    choice = chunk["choices"][0]
                    delta = choice.get("delta", {})

                    # Handle text content
                    if "content" in delta and delta["content"]:
                        content_delta = delta["content"]
                        full_content += content_delta
                        self._emit_event("message_update", {"delta": content_delta})

                    # Handle tool calls
                    if "tool_calls" in delta:
                        for tc_delta in delta["tool_calls"]:
                            index = tc_delta.get("index", 0)

                            if index not in tool_calls_buffer:
                                tool_calls_buffer[index] = {
                                    "id": "",
                                    "name": "",
                                    "arguments": ""
                                }

                            if "id" in tc_delta:
                                tool_calls_buffer[index]["id"] = tc_delta["id"]

                            if "function" in tc_delta:
                                func = tc_delta["function"]
                                if "name" in func:
                                    tool_calls_buffer[index]["name"] = func["name"]
                                if "arguments" in func:
                                    tool_calls_buffer[index]["arguments"] += func["arguments"]

        except Exception as e:
            logger.error(f"LLM call error: {e}")
            raise

        # Parse the buffered tool calls
        tool_calls = []
        for idx in sorted(tool_calls_buffer.keys()):
            tc = tool_calls_buffer[idx]
            try:
                arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse tool arguments: {tc['arguments']} ({e})")
                arguments = {}

            tool_calls.append({
                "id": tc["id"],
                "name": tc["name"],
                "arguments": arguments
            })

        # Add the assistant message to history (the Claude format uses content blocks)
        assistant_msg = {"role": "assistant", "content": []}

        # Add a text content block if present
        if full_content:
            assistant_msg["content"].append({
                "type": "text",
                "text": full_content
            })

        # Add tool_use blocks if present
        if tool_calls:
            for tc in tool_calls:
                assistant_msg["content"].append({
                    "type": "tool_use",
                    "id": tc["id"],
                    "name": tc["name"],
                    "input": tc["arguments"]
                })

        # Only append if the content is not empty
        if assistant_msg["content"]:
            self.messages.append(assistant_msg)

        self._emit_event("message_end", {
            "content": full_content,
            "tool_calls": tool_calls
        })

        return full_content, tool_calls

    def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
        """
        Execute a tool

        Args:
            tool_call: {"id": str, "name": str, "arguments": dict}

        Returns:
            Tool execution result
        """
        tool_name = tool_call["name"]
        tool_id = tool_call["id"]
        arguments = tool_call["arguments"]

        self._emit_event("tool_execution_start", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            "arguments": arguments
        })

        try:
            tool = self.tools.get(tool_name)
            if not tool:
                raise ValueError(f"Tool '{tool_name}' not found")

            # Set the tool context
            tool.model = self.model
            tool.context = self.agent

            # Execute the tool
            start_time = time.time()
            result: ToolResult = tool.execute_tool(arguments)
            execution_time = time.time() - start_time

            result_dict = {
                "status": result.status,
                "result": result.result,
                "execution_time": execution_time
            }

            # Auto-refresh skills after a skill is created
            if tool_name == "bash" and result.status == "success":
                command = arguments.get("command", "")
                if "init_skill.py" in command and self.agent.skill_manager:
                    logger.info("🔄 Detected skill creation, refreshing skills...")
                    self.agent.refresh_skills()
                    logger.info(f"✅ Skills refreshed! Now have {len(self.agent.skill_manager.skills)} skills")

            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **result_dict
            })

            return result_dict

        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            error_result = {
                "status": "error",
                "result": str(e),
                "execution_time": 0
            }
            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **error_result
            })
            return error_result

    def _trim_messages(self):
        """
        Trim the message history to stay within context limits.
        Uses the agent's context management configuration.
        """
        if not self.messages or not self.agent:
            return

        # Get the context window and reserve tokens from the agent
        context_window = self.agent._get_model_context_window()
        reserve_tokens = self.agent._get_context_reserve_tokens()
        max_tokens = context_window - reserve_tokens

        # Estimate the current token count
        current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)

        # Add the system prompt tokens
        system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
        current_tokens += system_tokens

        # If under the limit, no trimming is needed
        if current_tokens <= max_tokens:
            return

        # Keep messages from the newest, accumulating tokens
        available_tokens = max_tokens - system_tokens
        kept_messages = []
        accumulated_tokens = 0

        for msg in reversed(self.messages):
            msg_tokens = self.agent._estimate_message_tokens(msg)
            if accumulated_tokens + msg_tokens <= available_tokens:
                kept_messages.insert(0, msg)
                accumulated_tokens += msg_tokens
            else:
                break

        old_count = len(self.messages)
        self.messages = kept_messages
        new_count = len(self.messages)

        if old_count > new_count:
            logger.info(
                f"Context trimmed: {old_count} -> {new_count} messages "
                f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
                f"limit: {max_tokens})"
            )
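
    # Example (illustrative): with a 200K window and a 20K reserve, max_tokens is
    # 180K; a history estimated at 195K tokens keeps only the newest messages whose
    # combined estimate (plus the system prompt) still fits under 180K.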

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """
        Prepare the messages to send to the LLM

        Note: For the Claude API, the system prompt should be passed separately via
        the system parameter, not as a message. The AgentLLMModel handles this.
        """
        # Don't add a system message here - it is handled separately by the LLM adapter
        return self.messages
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
class TeamContext:
    def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
        """
        Initialize the TeamContext with a name, description, rules, and a list of agents.

        :param name: The name of the group context.
        :param description: A description of the group context.
        :param rule: The rules governing the group context.
        :param agents: A list of agents in the context.
        :param max_steps: Maximum number of steps across the team (default: 100).
        """
        self.name = name
        self.description = description
        self.rule = rule
        self.agents = agents
        self.user_task = ""  # For backward compatibility
        self.task = None  # Will be a Task instance
        self.model = None  # Will be an instance of LLMModel
        self.task_short_name = None  # Store the task directory name
        # List of agents that have been executed
        self.agent_outputs: list = []
        self.current_steps = 0
        self.max_steps = max_steps


class AgentOutput:
    def __init__(self, agent_name: str, output: str):
        self.agent_name = agent_name
        self.output = output
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
"""
Models module for the agent system.
Provides the basic model classes needed by tools and bridge integration.
"""

from typing import Any, Dict, List, Optional


class LLMRequest:
    """Request model for LLM operations"""

    def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        self.messages = messages or []
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.tools = tools
        # Allow extra attributes
        for key, value in kwargs.items():
            setattr(self, key, value)
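
# Illustrative construction (a sketch):
#   request = LLMRequest(
#       messages=[{"role": "user", "content": "hi"}],
#       temperature=0,
#       stream=True,
#       system="You are a helpful assistant",  # extra kwargs become attributes
#   )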
class LLMModel:
|
||||
"""Base class for LLM models"""
|
||||
|
||||
def __init__(self, model: str = None, **kwargs):
|
||||
self.model = model
|
||||
self.config = kwargs
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with a request.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call not implemented in this context")
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
"""Factory for creating model instances"""
|
||||
|
||||
@staticmethod
|
||||
def create_model(model_type: str, **kwargs):
|
||||
"""
|
||||
Create a model instance based on type.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("ModelFactory.create_model not implemented in this context")
|
||||
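A minimal usage sketch of the base class above; EchoModel is illustrative only, not a model shipped by this module:

# Sketch: a concrete LLMModel subclass.
from agent.protocol.models import LLMModel, LLMRequest

class EchoModel(LLMModel):
    def call(self, request: LLMRequest):
        # Echo the last user message; a real adapter would call an API here.
        user_msgs = [m["content"] for m in request.messages if m["role"] == "user"]
        return user_msgs[-1] if user_msgs else ""

request = LLMRequest(messages=[{"role": "user", "content": "ping"}], temperature=0.0)
print(EchoModel(model="echo-1").call(request))  # -> "ping"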
96
agent/protocol/result.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
|
||||
"""Enum representing different types of agent actions."""
|
||||
TOOL_USE = "tool_use"
|
||||
THINKING = "thinking"
|
||||
FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""
|
||||
Represents the result of a tool use.
|
||||
|
||||
Attributes:
|
||||
tool_name: Name of the tool used
|
||||
input_params: Parameters passed to the tool
|
||||
output: Output from the tool
|
||||
status: Status of the tool execution (success/error)
|
||||
error_message: Error message if the tool execution failed
|
||||
execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_name: str
|
||||
input_params: Dict[str, Any]
|
||||
output: Any
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""
|
||||
Represents an action taken by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the action
|
||||
agent_id: ID of the agent that performed the action
|
||||
agent_name: Name of the agent that performed the action
|
||||
action_type: Type of action (tool use, thinking, final answer)
|
||||
content: Content of the action (thought content, final answer content)
|
||||
tool_result: Tool use details if action_type is TOOL_USE
|
||||
        thought: Optional reasoning text recorded alongside the action
        timestamp: When the action was performed
|
||||
"""
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
action_type: AgentActionType
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
tool_result: Optional[ToolResult] = None
|
||||
thought: Optional[str] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""
|
||||
Represents the result of an agent's execution.
|
||||
|
||||
Attributes:
|
||||
final_answer: The final answer provided by the agent
|
||||
step_count: Number of steps taken by the agent
|
||||
status: Status of the execution (success/error)
|
||||
error_message: Error message if execution failed
|
||||
"""
|
||||
final_answer: str
|
||||
step_count: int
|
||||
status: str = "success"
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
|
||||
"""Create a successful result"""
|
||||
return cls(final_answer=final_answer, step_count=step_count)
|
||||
|
||||
@classmethod
|
||||
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
|
||||
"""Create an error result"""
|
||||
return cls(
|
||||
final_answer=f"Error: {error_message}",
|
||||
step_count=step_count,
|
||||
status="error",
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
@property
|
||||
def is_error(self) -> bool:
|
||||
"""Check if the result represents an error"""
|
||||
return self.status == "error"
|
||||
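A short usage sketch of the factory methods and the is_error property defined above (values are illustrative):

# Sketch: constructing and inspecting AgentResult values.
from agent.protocol.result import AgentResult

ok = AgentResult.success(final_answer="42", step_count=3)
bad = AgentResult.error("tool timeout", step_count=5)
assert not ok.is_error and bad.is_error
print(bad.final_answer)  # -> "Error: tool timeout"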
95
agent/protocol/task.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Enum representing different types of tasks."""
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
FILE = "file"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Enum representing the status of a task."""
|
||||
INIT = "init" # Initial state
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
Represents a task to be processed by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the task
|
||||
content: The primary text content of the task
|
||||
type: Type of the task
|
||||
status: Current status of the task
|
||||
created_at: Timestamp when the task was created
|
||||
updated_at: Timestamp when the task was last updated
|
||||
metadata: Additional metadata for the task
|
||||
images: List of image URLs or base64 encoded images
|
||||
videos: List of video URLs
|
||||
audios: List of audio URLs or base64 encoded audios
|
||||
files: List of file URLs or paths
|
||||
"""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
type: TaskType = TaskType.TEXT
|
||||
status: TaskStatus = TaskStatus.INIT
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Media content
|
||||
images: List[str] = field(default_factory=list)
|
||||
videos: List[str] = field(default_factory=list)
|
||||
audios: List[str] = field(default_factory=list)
|
||||
files: List[str] = field(default_factory=list)
|
||||
|
||||
def __init__(self, content: str = "", **kwargs):
|
||||
"""
|
||||
Initialize a Task with content and optional keyword arguments.
|
||||
|
||||
Args:
|
||||
content: The text content of the task
|
||||
**kwargs: Additional attributes to set
|
||||
"""
|
||||
self.id = kwargs.get('id', str(uuid.uuid4()))
|
||||
self.content = content
|
||||
self.type = kwargs.get('type', TaskType.TEXT)
|
||||
self.status = kwargs.get('status', TaskStatus.INIT)
|
||||
self.created_at = kwargs.get('created_at', time.time())
|
||||
self.updated_at = kwargs.get('updated_at', time.time())
|
||||
self.metadata = kwargs.get('metadata', {})
|
||||
self.images = kwargs.get('images', [])
|
||||
self.videos = kwargs.get('videos', [])
|
||||
self.audios = kwargs.get('audios', [])
|
||||
self.files = kwargs.get('files', [])
|
||||
|
||||
def get_text(self) -> str:
|
||||
"""
|
||||
Get the text content of the task.
|
||||
|
||||
Returns:
|
||||
The text content
|
||||
"""
|
||||
return self.content
|
||||
|
||||
def update_status(self, status: TaskStatus) -> None:
|
||||
"""
|
||||
Update the status of the task.
|
||||
|
||||
Args:
|
||||
status: The new status
|
||||
"""
|
||||
self.status = status
|
||||
self.updated_at = time.time()
|
||||
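A brief sketch of the Task lifecycle the class above supports (content and metadata are illustrative):

# Sketch: a typical Task lifecycle.
from agent.protocol.task import Task, TaskStatus

task = Task(content="Summarize the report", metadata={"lang": "en"})
task.update_status(TaskStatus.PROCESSING)
# ... agent does its work ...
task.update_status(TaskStatus.COMPLETED)
assert task.status is TaskStatus.COMPLETED and task.updated_at >= task.created_at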
29
agent/skills/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Skills module for agent system.
|
||||
|
||||
This module provides the framework for loading, managing, and executing skills.
|
||||
Skills are markdown files with frontmatter that provide specialized instructions
|
||||
for specific tasks.
|
||||
"""
|
||||
|
||||
from agent.skills.types import (
|
||||
Skill,
|
||||
SkillEntry,
|
||||
SkillMetadata,
|
||||
SkillInstallSpec,
|
||||
LoadSkillsResult,
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
"Skill",
|
||||
"SkillEntry",
|
||||
"SkillMetadata",
|
||||
"SkillInstallSpec",
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
211
agent/skills/config.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Configuration support for skills.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Dict, Optional, List
|
||||
from agent.skills.types import SkillEntry
|
||||
|
||||
|
||||
def resolve_runtime_platform() -> str:
|
||||
"""Get the current runtime platform."""
|
||||
return platform.system().lower()
|
||||
|
||||
|
||||
def has_binary(bin_name: str) -> bool:
|
||||
"""
|
||||
Check if a binary is available in PATH.
|
||||
|
||||
:param bin_name: Binary name to check
|
||||
:return: True if binary is available
|
||||
"""
|
||||
import shutil
|
||||
return shutil.which(bin_name) is not None
|
||||
|
||||
|
||||
def has_any_binary(bin_names: List[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given binaries is available.
|
||||
|
||||
:param bin_names: List of binary names to check
|
||||
:return: True if at least one binary is available
|
||||
"""
|
||||
return any(has_binary(bin_name) for bin_name in bin_names)
|
||||
|
||||
|
||||
def has_env_var(env_name: str) -> bool:
|
||||
"""
|
||||
Check if an environment variable is set.
|
||||
|
||||
:param env_name: Environment variable name
|
||||
:return: True if environment variable is set
|
||||
"""
|
||||
return env_name in os.environ and bool(os.environ[env_name].strip())
|
||||
|
||||
|
||||
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get skill-specific configuration.
|
||||
|
||||
:param config: Global configuration dictionary
|
||||
:param skill_name: Name of the skill
|
||||
:return: Skill configuration or None
|
||||
"""
|
||||
if not config:
|
||||
return None
|
||||
|
||||
skills_config = config.get('skills', {})
|
||||
if not isinstance(skills_config, dict):
|
||||
return None
|
||||
|
||||
entries = skills_config.get('entries', {})
|
||||
if not isinstance(entries, dict):
|
||||
return None
|
||||
|
||||
return entries.get(skill_name)
|
||||
|
||||
|
||||
def should_include_skill(
|
||||
entry: SkillEntry,
|
||||
config: Optional[Dict] = None,
|
||||
current_platform: Optional[str] = None,
|
||||
lenient: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a skill should be included based on requirements.
|
||||
|
||||
Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode:
|
||||
- In lenient mode (default): Only check explicit disable and platform, ignore missing requirements
|
||||
- In strict mode: Check all requirements (binary, env vars, config)
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param config: Configuration dictionary
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:param lenient: If True, ignore missing requirements and load all skills (default: True)
|
||||
:return: True if skill should be included
|
||||
"""
|
||||
metadata = entry.metadata
|
||||
skill_name = entry.skill.name
|
||||
skill_config = get_skill_config(config, skill_name)
|
||||
|
||||
# Always check if skill is explicitly disabled in config
|
||||
if skill_config and skill_config.get('enabled') is False:
|
||||
return False
|
||||
|
||||
if not metadata:
|
||||
return True
|
||||
|
||||
# Always check platform requirements (can't work on wrong platform)
|
||||
if metadata.os:
|
||||
platform_name = current_platform or resolve_runtime_platform()
|
||||
# Map common platform names
|
||||
platform_map = {
|
||||
'darwin': 'darwin',
|
||||
'linux': 'linux',
|
||||
'windows': 'win32',
|
||||
}
|
||||
normalized_platform = platform_map.get(platform_name, platform_name)
|
||||
|
||||
if normalized_platform not in metadata.os:
|
||||
return False
|
||||
|
||||
# If skill has 'always: true', include it regardless of other requirements
|
||||
if metadata.always:
|
||||
return True
|
||||
|
||||
# In lenient mode, skip requirement checks and load all skills
|
||||
# Skills will fail gracefully at runtime if requirements are missing
|
||||
if lenient:
|
||||
return True
|
||||
|
||||
# Strict mode: Check all requirements
|
||||
if metadata.requires:
|
||||
# Check required binaries (all must be present)
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
if not all(has_binary(bin_name) for bin_name in required_bins):
|
||||
return False
|
||||
|
||||
# Check anyBins (at least one must be present)
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins:
|
||||
if not has_any_binary(any_bins):
|
||||
return False
|
||||
|
||||
# Check environment variables (with config fallback)
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
# Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv)
|
||||
if has_env_var(env_name):
|
||||
continue
|
||||
if skill_config:
|
||||
# Check skill config env dict
|
||||
skill_env = skill_config.get('env', {})
|
||||
if isinstance(skill_env, dict) and env_name in skill_env:
|
||||
continue
|
||||
# Check skill config apiKey (if this is the primaryEnv)
|
||||
if metadata.primary_env == env_name and skill_config.get('apiKey'):
|
||||
continue
|
||||
# Requirement not satisfied
|
||||
return False
|
||||
|
||||
# Check config paths
|
||||
required_config = metadata.requires.get('config', [])
|
||||
if required_config and config:
|
||||
for config_path in required_config:
|
||||
if not is_config_path_truthy(config, config_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path (e.g., 'skills.enabled')
|
||||
:return: True if path resolves to truthy value
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return False
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return False
|
||||
|
||||
# Check if value is truthy
|
||||
if isinstance(current, bool):
|
||||
return current
|
||||
if isinstance(current, (int, float)):
|
||||
return current != 0
|
||||
if isinstance(current, str):
|
||||
return bool(current.strip())
|
||||
|
||||
return bool(current)
|
||||
|
||||
|
||||
def resolve_config_path(config: Dict, path: str):
|
||||
"""
|
||||
Resolve a dot-separated config path to its value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path
|
||||
:return: Value at path or None
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return None
|
||||
|
||||
return current
|
||||
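A small sketch of the dot-path helpers defined above; the config dict and key values are illustrative:

# Sketch: dot-path lookups over a nested config dict.
from agent.skills.config import is_config_path_truthy, resolve_config_path

cfg = {"skills": {"enabled": True, "entries": {"pdf": {"apiKey": "sk-demo"}}}}
assert is_config_path_truthy(cfg, "skills.enabled")
assert resolve_config_path(cfg, "skills.entries.pdf.apiKey") == "sk-demo"
assert resolve_config_path(cfg, "skills.missing") is None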
62
agent/skills/formatter.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
"""
|
||||
Format skills for inclusion in a system prompt.
|
||||
|
||||
Uses XML format per Agent Skills standard.
|
||||
Skills with disable_model_invocation=True are excluded.
|
||||
|
||||
:param skills: List of skills to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
# Filter out skills that should not be invoked by the model
|
||||
visible_skills = [s for s in skills if not s.disable_model_invocation]
|
||||
|
||||
if not visible_skills:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n\nThe following skills provide specialized instructions for specific tasks.",
|
||||
"Use the read tool to load a skill's file when the task matches its description.",
|
||||
"",
|
||||
"<available_skills>",
|
||||
]
|
||||
|
||||
for skill in visible_skills:
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</available_skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
"""
|
||||
Format skill entries for inclusion in a system prompt.
|
||||
|
||||
:param entries: List of skill entries to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
skills = [entry.skill for entry in entries]
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
    return (text
            .replace('&', '&amp;')
            .replace('<', '&lt;')
            .replace('>', '&gt;')
            .replace('"', '&quot;')
            .replace("'", '&#39;'))
|
||||
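A usage sketch of the formatter above; the Skill field values are illustrative:

# Sketch: formatting one skill for the system prompt.
from agent.skills.types import Skill
from agent.skills.formatter import format_skills_for_prompt

skill = Skill(
    name="pdf", description="Extract text from PDF files",
    file_path="skills/pdf/SKILL.md", base_dir="skills/pdf",
    source="managed", content="...",
)
# Emits an <available_skills> block with XML-escaped <name>, <description>, <location>.
print(format_skills_for_prompt([skill]))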
159
agent/skills/frontmatter.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Frontmatter parsing for skills.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from agent.skills.types import SkillMetadata, SkillInstallSpec
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse YAML-style frontmatter from markdown content.
|
||||
|
||||
Returns a dictionary of frontmatter fields.
|
||||
"""
|
||||
frontmatter = {}
|
||||
|
||||
# Match frontmatter block between --- markers
|
||||
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
|
||||
if not match:
|
||||
return frontmatter
|
||||
|
||||
frontmatter_text = match.group(1)
|
||||
|
||||
# Simple YAML-like parsing (supports key: value format)
|
||||
for line in frontmatter_text.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Try to parse as JSON if it looks like JSON
|
||||
if value.startswith('{') or value.startswith('['):
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Parse boolean values
|
||||
elif value.lower() in ('true', 'false'):
|
||||
value = value.lower() == 'true'
|
||||
# Parse numbers
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
|
||||
frontmatter[key] = value
|
||||
|
||||
return frontmatter
|
||||
|
||||
|
||||
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
"""
|
||||
Parse skill metadata from frontmatter.
|
||||
|
||||
Looks for 'metadata' field containing JSON with skill configuration.
|
||||
"""
|
||||
metadata_raw = frontmatter.get('metadata')
|
||||
if not metadata_raw:
|
||||
return None
|
||||
|
||||
# If it's a string, try to parse as JSON
|
||||
if isinstance(metadata_raw, str):
|
||||
try:
|
||||
metadata_raw = json.loads(metadata_raw)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Support both 'moltbot' and 'cow' keys for compatibility
|
||||
meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow')
|
||||
if not meta_obj or not isinstance(meta_obj, dict):
|
||||
return None
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
install_raw = meta_obj.get('install', [])
|
||||
if isinstance(install_raw, list):
|
||||
for spec_raw in install_raw:
|
||||
if not isinstance(spec_raw, dict):
|
||||
continue
|
||||
|
||||
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
|
||||
if not kind:
|
||||
continue
|
||||
|
||||
spec = SkillInstallSpec(
|
||||
kind=kind,
|
||||
id=spec_raw.get('id'),
|
||||
label=spec_raw.get('label'),
|
||||
bins=_normalize_string_list(spec_raw.get('bins')),
|
||||
os=_normalize_string_list(spec_raw.get('os')),
|
||||
formula=spec_raw.get('formula'),
|
||||
package=spec_raw.get('package'),
|
||||
module=spec_raw.get('module'),
|
||||
url=spec_raw.get('url'),
|
||||
archive=spec_raw.get('archive'),
|
||||
extract=spec_raw.get('extract', False),
|
||||
strip_components=spec_raw.get('stripComponents'),
|
||||
target_dir=spec_raw.get('targetDir'),
|
||||
)
|
||||
install_specs.append(spec)
|
||||
|
||||
# Parse requires
|
||||
requires = {}
|
||||
requires_raw = meta_obj.get('requires', {})
|
||||
if isinstance(requires_raw, dict):
|
||||
for key, value in requires_raw.items():
|
||||
requires[key] = _normalize_string_list(value)
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
homepage=meta_obj.get('homepage'),
|
||||
os=_normalize_string_list(meta_obj.get('os')),
|
||||
requires=requires,
|
||||
install=install_specs,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
if isinstance(value, list):
|
||||
return [str(v).strip() for v in value if v]
|
||||
|
||||
if isinstance(value, str):
|
||||
return [v.strip() for v in value.split(',') if v.strip()]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
|
||||
"""Parse a boolean value from frontmatter."""
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
|
||||
"""Get a frontmatter value as a string."""
|
||||
value = frontmatter.get(key)
|
||||
return str(value) if value is not None else None
|
||||
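A sketch of parsing a SKILL.md header with the functions above; the skill name and metadata JSON are illustrative, with the 'cow' key shape following parse_metadata:

# Sketch: parsing frontmatter and metadata from a skill file.
from agent.skills.frontmatter import parse_frontmatter, parse_metadata

doc = """---
name: pdf
description: Extract text from PDF files
metadata: {"cow": {"os": ["darwin", "linux"], "requires": {"bins": ["pdftotext"]}}}
---
# Instructions ...
"""
fm = parse_frontmatter(doc)
meta = parse_metadata(fm)
assert fm["name"] == "pdf" and meta.requires["bins"] == ["pdftotext"]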
242
agent/skills/loader.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Skill loader for discovering and loading skills from directories.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
|
||||
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
|
||||
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self, workspace_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the skill loader.
|
||||
|
||||
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
if not os.path.exists(dir_path):
|
||||
diagnostics.append(f"Directory does not exist: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
if not os.path.isdir(dir_path):
|
||||
diagnostics.append(f"Path is not a directory: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# Load skills from root-level .md files and subdirectories
|
||||
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
|
||||
|
||||
return result
|
||||
|
||||
def _load_skills_recursive(
|
||||
self,
|
||||
dir_path: str,
|
||||
source: str,
|
||||
include_root_files: bool = False
|
||||
) -> LoadSkillsResult:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
entries = os.listdir(dir_path)
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load a single skill from a markdown file.
|
||||
|
||||
:param file_path: Path to the skill markdown file
|
||||
:param source: Source identifier
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse frontmatter
|
||||
frontmatter = parse_frontmatter(content)
|
||||
|
||||
# Get skill name and description
|
||||
skill_dir = os.path.dirname(file_path)
|
||||
parent_dir_name = os.path.basename(skill_dir)
|
||||
|
||||
name = frontmatter.get('name', parent_dir_name)
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
if not description or not description.strip():
|
||||
diagnostics.append(f"Skill {name} has no description: {file_path}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse disable-model-invocation flag
|
||||
disable_model_invocation = parse_boolean_value(
|
||||
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
|
||||
default=False
|
||||
)
|
||||
|
||||
# Create skill object
|
||||
skill = Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
file_path=file_path,
|
||||
base_dir=skill_dir,
|
||||
source=source,
|
||||
content=content,
|
||||
disable_model_invocation=disable_model_invocation,
|
||||
frontmatter=frontmatter,
|
||||
)
|
||||
|
||||
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
managed_dir: Optional[str] = None,
|
||||
workspace_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from all configured locations with precedence.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Extra directories
|
||||
2. Managed skills directory
|
||||
3. Workspace skills directory
|
||||
|
||||
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
|
||||
:param extra_dirs: Additional directories to load skills from
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load from extra directories (lowest precedence)
|
||||
if extra_dirs:
|
||||
for extra_dir in extra_dirs:
|
||||
if not os.path.exists(extra_dir):
|
||||
continue
|
||||
result = self.load_skills_from_dir(extra_dir, source='extra')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from managed directory
|
||||
if managed_dir and os.path.exists(managed_dir):
|
||||
result = self.load_skills_from_dir(managed_dir, source='managed')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from workspace directory (highest precedence)
|
||||
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
|
||||
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]: # Log first 5
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.info(f"Loaded {len(skill_map)} skills from all sources")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
"""
|
||||
Create a SkillEntry from a Skill with parsed metadata.
|
||||
|
||||
:param skill: The skill to create an entry for
|
||||
:return: SkillEntry with metadata
|
||||
"""
|
||||
metadata = parse_metadata(skill.frontmatter)
|
||||
|
||||
# Parse user-invocable flag
|
||||
user_invocable = parse_boolean_value(
|
||||
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
|
||||
default=True
|
||||
)
|
||||
|
||||
return SkillEntry(
|
||||
skill=skill,
|
||||
metadata=metadata,
|
||||
user_invocable=user_invocable,
|
||||
)
|
||||
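A sketch of the precedence described in load_all_skills; all paths below are illustrative:

# Sketch: loading skills with precedence (lowest: extra, highest: workspace).
import os
from agent.skills.loader import SkillLoader

loader = SkillLoader(workspace_dir="/tmp/agent-ws")
skill_map = loader.load_all_skills(
    managed_dir=os.path.expanduser("~/.cow/skills"),
    workspace_skills_dir="/tmp/agent-ws/skills",  # wins on name clashes
    extra_dirs=["/opt/shared-skills"],            # lowest precedence
)
for name, entry in skill_map.items():
    print(name, entry.skill.source)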
214
agent/skills/manager.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Optional[str] = None,
|
||||
managed_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param workspace_dir: Agent workspace directory
|
||||
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param extra_dirs: Additional skill directories
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
|
||||
self.extra_dirs = extra_dirs or []
|
||||
self.config = config or {}
|
||||
|
||||
self.loader = SkillLoader(workspace_dir=workspace_dir)
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def _get_default_managed_dir(self) -> str:
|
||||
"""Get the default managed skills directory."""
|
||||
        # Use the project-root skills directory as the default
        # (os is already imported at module level)
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        return os.path.join(project_root, 'skills')
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from configured directories."""
|
||||
workspace_skills_dir = None
|
||||
if self.workspace_dir:
|
||||
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
|
||||
|
||||
self.skills = self.loader.load_all_skills(
|
||||
managed_dir=self.managed_skills_dir,
|
||||
workspace_skills_dir=workspace_skills_dir,
|
||||
extra_dirs=self.extra_dirs,
|
||||
)
|
||||
|
||||
logger.info(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by name.
|
||||
|
||||
:param name: Skill name
|
||||
:return: SkillEntry or None if not found
|
||||
"""
|
||||
return self.skills.get(name)
|
||||
|
||||
def list_skills(self) -> List[SkillEntry]:
|
||||
"""
|
||||
Get all loaded skills.
|
||||
|
||||
:return: List of all skill entries
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
check_requirements: bool = False, # Changed default to False for lenient loading
|
||||
lenient: bool = True, # New parameter for lenient mode
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
By default (lenient=True), all skills are loaded regardless of missing requirements.
|
||||
Skills will fail gracefully at runtime if requirements are not met.
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include skills with disable_model_invocation=True
|
||||
:param check_requirements: Whether to check skill requirements (default: False)
|
||||
:param lenient: If True, ignore missing requirements (default: True)
|
||||
:return: Filtered list of skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
        # Check requirements (platform, explicit disable, etc.);
        # in lenient mode should_include_skill only checks platform support
        # and explicit disable, so a single filter covers both branches.
        entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = [name.strip() for name in skill_filter if name.strip()]
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills unless explicitly requested
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if not e.skill.disable_model_invocation]
|
||||
|
||||
return entries
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
return format_skill_entries_for_prompt(entries)
|
||||
|
||||
def build_skill_snapshot(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> SkillSnapshot:
|
||||
"""
|
||||
Build a snapshot of skills for a specific run.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:param version: Optional version number for the snapshot
|
||||
:return: SkillSnapshot
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
prompt = format_skill_entries_for_prompt(entries)
|
||||
|
||||
skills_info = []
|
||||
resolved_skills = []
|
||||
|
||||
for entry in entries:
|
||||
skills_info.append({
|
||||
'name': entry.skill.name,
|
||||
'primary_env': entry.metadata.primary_env if entry.metadata else None,
|
||||
})
|
||||
resolved_skills.append(entry.skill)
|
||||
|
||||
return SkillSnapshot(
|
||||
prompt=prompt,
|
||||
skills=skills_info,
|
||||
resolved_skills=resolved_skills,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def sync_skills_to_workspace(self, target_workspace_dir: str):
|
||||
"""
|
||||
Sync all loaded skills to a target workspace directory.
|
||||
|
||||
This is useful for sandbox environments where skills need to be copied.
|
||||
|
||||
:param target_workspace_dir: Target workspace directory
|
||||
"""
|
||||
import shutil
|
||||
|
||||
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
|
||||
|
||||
# Remove existing skills directory
|
||||
if os.path.exists(target_skills_dir):
|
||||
shutil.rmtree(target_skills_dir)
|
||||
|
||||
# Create new skills directory
|
||||
os.makedirs(target_skills_dir, exist_ok=True)
|
||||
|
||||
# Copy each skill
|
||||
for entry in self.skills.values():
|
||||
skill_name = entry.skill.name
|
||||
source_dir = entry.skill.base_dir
|
||||
target_dir = os.path.join(target_skills_dir, skill_name)
|
||||
|
||||
try:
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
|
||||
|
||||
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
|
||||
|
||||
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by its skill key (which may differ from name).
|
||||
|
||||
:param skill_key: Skill key to look up
|
||||
:return: SkillEntry or None
|
||||
"""
|
||||
for entry in self.skills.values():
|
||||
if entry.metadata and entry.metadata.skill_key == skill_key:
|
||||
return entry
|
||||
if entry.skill.name == skill_key:
|
||||
return entry
|
||||
return None
|
||||
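A usage sketch of the manager above; the workspace path and skill names are illustrative:

# Sketch: building a skills prompt snapshot for a run.
from agent.skills.manager import SkillManager

manager = SkillManager(workspace_dir="/tmp/agent-ws")
snapshot = manager.build_skill_snapshot(skill_filter=["pdf", "search"])
print(f"{len(snapshot.resolved_skills)} skills included")
print(snapshot.prompt)  # ready to append to the agent's system prompt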
74
agent/skills/types.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillInstallSpec:
|
||||
"""Specification for installing skill dependencies."""
|
||||
kind: str # brew, pip, npm, download, etc.
|
||||
id: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
bins: List[str] = field(default_factory=list)
|
||||
os: List[str] = field(default_factory=list)
|
||||
formula: Optional[str] = None # for brew
|
||||
package: Optional[str] = None # for pip/npm
|
||||
module: Optional[str] = None
|
||||
url: Optional[str] = None # for download
|
||||
archive: Optional[str] = None
|
||||
extract: bool = False
|
||||
strip_components: Optional[int] = None
|
||||
target_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
homepage: Optional[str] = None
|
||||
os: List[str] = field(default_factory=list) # Supported OS platforms
|
||||
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
|
||||
install: List[SkillInstallSpec] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
"""Represents a skill loaded from a markdown file."""
|
||||
name: str
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # managed, workspace, bundled, etc.
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""A skill with parsed metadata."""
|
||||
skill: Skill
|
||||
metadata: Optional[SkillMetadata] = None
|
||||
user_invocable: bool = True # Can users invoke this skill directly
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadSkillsResult:
|
||||
"""Result of loading skills from a directory."""
|
||||
skills: List[Skill]
|
||||
diagnostics: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillSnapshot:
|
||||
"""Snapshot of skills for a specific run."""
|
||||
prompt: str # Formatted prompt text
|
||||
    skills: List[Dict[str, Optional[str]]]  # Skill info entries (name, primary_env; primary_env may be None)
|
||||
resolved_skills: List[Skill] = field(default_factory=list)
|
||||
version: Optional[int] = None
|
||||
105
agent/tools/__init__.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Import base tool
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import basic tools (no external dependencies)
|
||||
from agent.tools.calculator.calculator import Calculator
|
||||
from agent.tools.current_time.current_time import CurrentTime
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.grep.grep import Grep
|
||||
from agent.tools.find.find import Find
|
||||
from agent.tools.ls.ls import Ls
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import web tools
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
|
||||
"""Import tools that have optional dependencies"""
|
||||
tools = {}
|
||||
|
||||
# Google Search (requires requests)
|
||||
try:
|
||||
from agent.tools.google_search.google_search import GoogleSearch
|
||||
tools['GoogleSearch'] = GoogleSearch
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# File Save (may have dependencies)
|
||||
try:
|
||||
from agent.tools.file_save.file_save import FileSave
|
||||
tools['FileSave'] = FileSave
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Terminal (basic, should work)
|
||||
try:
|
||||
from agent.tools.terminal.terminal import Terminal
|
||||
tools['Terminal'] = Terminal
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
def _import_browser_tool():
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Calculator',
|
||||
'CurrentTime',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Grep',
|
||||
'Find',
|
||||
'Ls',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'WebFetch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
'GoogleSearch',
|
||||
'FileSave',
|
||||
'Terminal',
|
||||
'BrowserTool'
|
||||
]
|
||||
|
||||
"""
|
||||
Tools module for Agent.
|
||||
"""
|
||||
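A small sketch of consuming the optional imports above (tools may be None when their dependencies are missing):

# Sketch: registering only the tools whose dependencies resolved.
from agent.tools import Calculator, GoogleSearch

available = [cls for cls in (Calculator, GoogleSearch) if cls is not None]
print([cls.name for cls in available])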
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from common.log import logger
|
||||
import copy
|
||||
|
||||
|
||||
class ToolStage(Enum):
|
||||
"""Enum representing tool decision stages"""
|
||||
PRE_PROCESS = "pre_process" # Tools that need to be actively selected by the agent
|
||||
POST_PROCESS = "post_process" # Tools that automatically execute after final_answer
|
||||
|
||||
|
||||
class ToolResult:
|
||||
"""Tool execution result"""
|
||||
|
||||
    def __init__(self, status: Optional[str] = None, result: Any = None, ext_data: Any = None):
|
||||
self.status = status
|
||||
self.result = result
|
||||
self.ext_data = ext_data
|
||||
|
||||
@staticmethod
|
||||
def success(result, ext_data: Any = None):
|
||||
return ToolResult(status="success", result=result, ext_data=ext_data)
|
||||
|
||||
@staticmethod
|
||||
def fail(result, ext_data: Any = None):
|
||||
return ToolResult(status="error", result=result, ext_data=ext_data)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all tools."""
|
||||
|
||||
# Default decision stage is pre-process
|
||||
stage = ToolStage.PRE_PROCESS
|
||||
|
||||
# Class attributes must be inherited
|
||||
name: str = "base_tool"
|
||||
description: str = "Base tool"
|
||||
params: dict = {} # Store JSON Schema
|
||||
model: Optional[Any] = None # LLM model instance, type depends on bot implementation
|
||||
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Get the standard description of the tool"""
|
||||
return {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": cls.params
|
||||
}
|
||||
|
||||
    def execute_tool(self, params: dict) -> ToolResult:
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            # Return an explicit failure instead of implicitly returning None
            return ToolResult.fail(f"Tool execution failed: {e}")
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""Specific logic to be implemented by subclasses"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _parse_schema(cls) -> dict:
|
||||
"""Convert JSON Schema to Pydantic fields"""
|
||||
fields = {}
|
||||
for name, prop in cls.params["properties"].items():
|
||||
# Convert JSON Schema types to Python types
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
fields[name] = (
|
||||
type_map[prop["type"]],
|
||||
prop.get("default", ...)
|
||||
)
|
||||
return fields
|
||||
|
||||
def should_auto_execute(self, context) -> bool:
|
||||
"""
|
||||
Determine if this tool should be automatically executed based on context.
|
||||
|
||||
:param context: The agent context
|
||||
:return: True if the tool should be executed, False otherwise
|
||||
"""
|
||||
# Only tools in post-process stage will be automatically executed
|
||||
return self.stage == ToolStage.POST_PROCESS
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close any resources used by the tool.
|
||||
This method should be overridden by tools that need to clean up resources
|
||||
such as browser connections, file handles, etc.
|
||||
|
||||
By default, this method does nothing.
|
||||
"""
|
||||
pass
|
||||
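A minimal sketch of subclassing the base class above; EchoTool is illustrative, not a shipped tool:

# Sketch: a minimal BaseTool subclass.
from agent.tools.base_tool import BaseTool, ToolResult

class EchoTool(BaseTool):
    name = "echo"
    description = "Echo the given text back to the caller."
    params = {
        "type": "object",
        "properties": {"text": {"type": "string", "description": "Text to echo"}},
        "required": ["text"],
    }

    def execute(self, params: dict) -> ToolResult:
        return ToolResult.success({"echo": params["text"]})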
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .bash import Bash
|
||||
|
||||
__all__ = ['Bash']
|
||||
186
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
|
||||
IMPORTANT SAFETY GUIDELINES:
|
||||
- You can freely create, modify, and delete files within the current workspace
|
||||
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
|
||||
- When in doubt, describe the command's purpose and ask for permission before executing"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (optional, default: 30)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
# Ensure working directory exists
|
||||
if not os.path.exists(self.cwd):
|
||||
os.makedirs(self.cwd, exist_ok=True)
|
||||
self.default_timeout = self.config.get("timeout", 30)
|
||||
# Enable safety mode by default (can be disabled in config)
|
||||
self.safety_mode = self.config.get("safety_mode", True)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute a bash command
|
||||
|
||||
:param args: Dictionary containing the command and optional timeout
|
||||
:return: Command output or error
|
||||
"""
|
||||
command = args.get("command", "").strip()
|
||||
timeout = args.get("timeout", self.default_timeout)
|
||||
|
||||
if not command:
|
||||
return ToolResult.fail("Error: command parameter is required")
|
||||
|
||||
# Optional safety check - only warn about extremely dangerous commands
|
||||
if self.safety_mode:
|
||||
warning = self._get_safety_warning(command)
|
||||
if warning:
|
||||
return ToolResult.fail(
|
||||
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
|
||||
|
||||
try:
|
||||
# Execute command
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
cwd=self.cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
total_bytes = len(output.encode('utf-8'))
|
||||
|
||||
if total_bytes > DEFAULT_MAX_BYTES:
|
||||
# Save full output to temp file
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f:
|
||||
f.write(output)
|
||||
temp_file_path = f.name
|
||||
|
||||
# Apply tail truncation
|
||||
truncation = truncate_tail(output)
|
||||
output_text = truncation.content or "(no output)"
|
||||
|
||||
# Build result
|
||||
details = {}
|
||||
|
||||
if truncation.truncated:
|
||||
details["truncation"] = truncation.to_dict()
|
||||
if temp_file_path:
|
||||
details["full_output_path"] = temp_file_path
|
||||
|
||||
            # Build truncation notice (only when output was actually truncated,
            # so untruncated output is never suffixed with a bogus notice)
            if truncation.truncated:
                start_line = truncation.total_lines - truncation.output_lines + 1
                end_line = truncation.total_lines

                if truncation.last_line_partial:
                    # Edge case: the last line alone exceeds the byte limit
                    last_line = output.split('\n')[-1] if output else ""
                    last_line_size = format_size(len(last_line.encode('utf-8')))
                    output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]"
                elif truncation.truncated_by == "lines":
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]"
                else:
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {temp_file_path}]"
|
||||
|
||||
# Check exit code
|
||||
if result.returncode != 0:
|
||||
output_text += f"\n\nCommand exited with code {result.returncode}"
|
||||
return ToolResult.fail({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error executing command: {str(e)}")
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
|
||||
return "" # No warning needed
|
||||
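A usage sketch of the Bash tool above; the working directory and command are illustrative:

# Sketch: running a command through the Bash tool.
from agent.tools.bash.bash import Bash

bash = Bash(config={"cwd": "/tmp", "timeout": 10})
result = bash.execute({"command": "echo hello"})
print(result.status)            # "success" or "error"
print(result.result["output"])  # combined stdout/stderr, tail-truncated if large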
18
agent/tools/browser_tool.py
Normal file
@@ -0,0 +1,18 @@
|
||||
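# NOTE: this method takes `self` and evidently belongs on a browser tool class
# (presumably BrowserTool); as written it is a bare module-level function and
# is not callable on its own.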
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
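A hedged sketch of the design intent: copies are cheap tool objects that all drive one live browser session. The BrowserTool class name and the placeholder attribute values are assumptions; the framework normally injects real ones.

# Hedged sketch: shared browser, separate tool objects.
original = BrowserTool()          # class name assumed from browser_tool.py
original.model = None             # placeholder; the framework injects a real model
original.browser = object()       # placeholder for the live browser instance
clone = original.copy()
assert clone.browser is original.browser   # one browser session, two tool handles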
58
agent/tools/calculator/calculator.py
Normal file
@@ -0,0 +1,58 @@
import math

from agent.tools.base_tool import BaseTool, ToolResult


class Calculator(BaseTool):
    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression; it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        try:
            # Get the expression
            expression = args["expression"]

            # Create a safe local environment containing only basic math functions
            safe_locals = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "pi": math.pi,
                "e": math.e,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "floor": math.floor,
                "ceil": math.ceil
            }

            # Safely evaluate the expression with builtins disabled
            result = eval(expression, {"__builtins__": {}}, safe_locals)

            return ToolResult.success({
                "result": result,
                "expression": expression
            })
        except Exception as e:
            # Evaluation errors are returned as a success payload so the
            # calling agent can see the message and correct the expression
            return ToolResult.success({
                "error": str(e),
                "expression": args.get("expression", "")
            })
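A minimal usage sketch of the tool as defined above; how the payload is read back depends on ToolResult's interface in base_tool.py, which this diff does not show.

# Hedged sketch: invoking the calculator directly.
calc = Calculator()
res = calc.execute({"expression": "sqrt(16) + pow(2, 3)"})   # result payload: 12.0
# Reading the payload depends on ToolResult's interface (not shown in this diff).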
75
agent/tools/current_time/current_time.py
Normal file
@@ -0,0 +1,75 @@
import datetime
import time

from agent.tools.base_tool import BaseTool, ToolResult


class CurrentTime(BaseTool):
    name: str = "time"
    description: str = "A tool to get current date and time information."
    params: dict = {
        "type": "object",
        "properties": {
            "format": {
                "type": "string",
                "description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'."
            },
            "timezone": {
                "type": "string",
                "description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'."
            }
        },
        "required": []
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        try:
            # Get the format and timezone parameters, with defaults
            time_format = args.get("format", "human").lower()
            timezone = args.get("timezone", "local").lower()

            # Get current time
            current_time = datetime.datetime.now()

            # Handle timezone if specified
            if timezone == "utc":
                current_time = datetime.datetime.utcnow()

            # Format the time according to the specified format
            if time_format == "iso":
                # ISO 8601 format
                formatted_time = current_time.isoformat()
            elif time_format == "unix":
                # Unix timestamp (seconds since epoch)
                formatted_time = time.time()
            else:
                # Human-readable format
                formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

            # Prepare additional time components for the response
            year = current_time.year
            month = current_time.month
            day = current_time.day
            hour = current_time.hour
            minute = current_time.minute
            second = current_time.second
            weekday = current_time.strftime("%A")  # Full weekday name

            result = {
                "current_time": formatted_time,
                "components": {
                    "year": year,
                    "month": month,
                    "day": day,
                    "hour": hour,
                    "minute": minute,
                    "second": second,
                    "weekday": weekday
                },
                "format": time_format,
                "timezone": timezone
            }
            return ToolResult.success(result=result)
        except Exception as e:
            return ToolResult.fail(result=str(e))
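For illustration, a call with explicit parameters; the payload shape matches the result dict built above.

# Hedged sketch: requesting an ISO timestamp in UTC.
clock = CurrentTime()
res = clock.execute({"format": "iso", "timezone": "UTC"})
# The payload carries "current_time" plus a "components" dict
# (year, month, day, hour, minute, second, weekday).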
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .edit import Edit

__all__ = ['Edit']
173
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,173 @@
"""
Edit tool - Precise file editing
Edit files through exact text replacement
"""

import os
from typing import Dict, Any

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.diff import (
    strip_bom,
    detect_line_ending,
    normalize_to_lf,
    restore_line_endings,
    normalize_for_fuzzy_match,
    fuzzy_find_text,
    generate_diff_string
)


class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.memory_manager = self.config.get("memory_manager", None)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation

        :param args: Contains file path, old text and new text
        :return: Operation result
        """
        path = args.get("path", "").strip()
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # Read file
            with open(absolute_path, 'r', encoding='utf-8') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending
            original_ending = detect_line_ending(content)

            # Normalize to LF
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_result.index + match_result.match_length:]
            )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings
            final_content = bom + restore_line_endings(new_content, original_ending)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(final_content)

            # Generate diff
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }

            # Notify memory manager if file is in memory directory
            if self.memory_manager and "memory/" in path:
                try:
                    self.memory_manager.mark_dirty()
                except Exception:
                    # Don't fail the edit if memory notification fails
                    pass

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
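A minimal usage sketch. Because of the uniqueness check above, oldText must occur exactly once, so callers typically include surrounding lines for context; the file path here is invented.

# Hedged sketch: a single surgical replacement.
editor = Edit(config={"cwd": "/tmp/workspace"})
res = editor.execute({
    "path": "config.py",                 # hypothetical target file
    "oldText": "DEBUG = True",
    "newText": "DEBUG = False",
})
# On success the payload includes a unified diff and the first changed line.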
3
agent/tools/file_save/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .file_save import FileSave

__all__ = ['FileSave']
770
agent/tools/file_save/file_save.py
Normal file
@@ -0,0 +1,770 @@
import os
import time
import re
import json
from pathlib import Path
from typing import Dict, Any, Optional, Tuple

from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
from agent.models import LLMRequest
from common.log import logger


class FileSave(BaseTool):
    """Tool for saving content to files in the workspace directory."""

    name = "file_save"
    description = "Save the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."

    # Set as post-process stage tool
    stage = ToolStage.POST_PROCESS

    params = {
        "type": "object",
        "properties": {
            "file_name": {
                "type": "string",
                "description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
            },
            "file_type": {
                "type": "string",
                "description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
            },
            "extract_code": {
                "type": "boolean",
                "description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
            }
        },
        "required": []  # No required fields, as everything can be extracted from context
    }

    def __init__(self):
        self.context = None
        self.config = {}
        self.workspace_dir = Path("workspace")

    def execute(self, params: Dict[str, Any]) -> ToolResult:
        """
        Save content to a file in the workspace directory.

        :param params: The parameters for the file output operation.
        :return: Result of the operation.
        """
        # Extract content from context
        if not hasattr(self, 'context') or not self.context:
            return ToolResult.fail("Error: No context available to extract content from.")

        content = self._extract_content_from_context()

        # If no content could be extracted, return error
        if not content:
            return ToolResult.fail("Error: Couldn't extract content from context.")

        # Use model to determine file parameters
        try:
            task_dir = self._get_task_dir_from_context()
            file_name, file_type, extract_code = self._get_file_params_from_model(content)
        except Exception as e:
            logger.error(f"Error determining file parameters: {str(e)}")
            # Fall back to manual parameter extraction
            task_dir = params.get("task_dir") or self._get_task_id_from_context() or f"task_{int(time.time())}"
            file_name = params.get("file_name") or self._infer_file_name(content)
            file_type = params.get("file_type") or self._infer_file_type(content)
            extract_code = params.get("extract_code", False)

        # Get team_name from context
        team_name = self._get_team_name_from_context() or "default_team"

        # Create directory structure
        task_dir_path = self.workspace_dir / team_name / task_dir
        task_dir_path.mkdir(parents=True, exist_ok=True)

        if extract_code:
            # Save the complete content as markdown
            md_file_name = f"{file_name}.md"
            md_file_path = task_dir_path / md_file_name

            # Write content to file
            with open(md_file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            return self._handle_multiple_code_blocks(content)

        # Ensure file_name has the correct extension
        if file_type and not file_name.endswith(f".{file_type}"):
            file_name = f"{file_name}.{file_type}"

        # Create the full file path
        file_path = task_dir_path / file_name

        # Get absolute path for storage in team_context
        abs_file_path = file_path.absolute()

        try:
            # Write content to file
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Update the current agent's final_answer to include file information
            if hasattr(self.context, 'team_context'):
                # Store with absolute path in team_context
                self.context.team_context.agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"

            return ToolResult.success({
                "status": "success",
                "file_path": str(file_path)  # Return relative path in result
            })

        except Exception as e:
            return ToolResult.fail(f"Error saving file: {str(e)}")
    def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
        """
        Handle content with multiple code blocks, extracting and saving each as a separate file.

        :param content: The content containing multiple code blocks
        :return: Result of the operation
        """
        # Extract code blocks with context (including potential file name information)
        code_blocks_with_context = self._extract_code_blocks_with_context(content)

        if not code_blocks_with_context:
            return ToolResult.fail("No code blocks found in the content.")

        # Get task directory and team name
        task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
        team_name = self._get_team_name_from_context() or "default_team"

        # Create directory structure
        task_dir_path = self.workspace_dir / team_name / task_dir
        task_dir_path.mkdir(parents=True, exist_ok=True)

        saved_files = []

        for block_with_context in code_blocks_with_context:
            try:
                # Use model to determine file name for this code block
                block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)

                # Clean the code block (remove md code markers)
                clean_code = self._clean_code_block(block_with_context)

                # Ensure file_name has the correct extension
                if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
                    block_file_name = f"{block_file_name}.{block_file_type}"

                # Create the full file path (no subdirectories)
                file_path = task_dir_path / block_file_name

                # Get absolute path for storage in team_context
                abs_file_path = file_path.absolute()

                # Write content to file
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(clean_code)

                saved_files.append({
                    "file_path": str(file_path),
                    "abs_file_path": str(abs_file_path),  # Store absolute path for internal use
                    "file_name": block_file_name,
                    "size": len(clean_code),
                    "status": "success",
                    "type": "code"
                })

            except Exception as e:
                logger.error(f"Error saving code block: {str(e)}")
                # Continue with the next block even if this one fails

        if not saved_files:
            return ToolResult.fail("Failed to save any code blocks.")

        # Update the current agent's final_answer to include files information
        if hasattr(self, 'context') and self.context:
            # If the agent has a final_answer attribute, append the files info to it
            if hasattr(self.context, 'team_context'):
                # Use relative paths for display
                display_info = f"\n\nSaved files to {task_dir_path}:\n" + "\n".join(
                    [f"- {f['file_path']}" for f in saved_files])

                # Check if we need to append the info
                if not self.context.team_context.agent_outputs[-1].output.endswith(display_info):
                    # Store with absolute paths in team_context
                    abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join(
                        [f"- {f['abs_file_path']}" for f in saved_files])
                    self.context.team_context.agent_outputs[-1].output += abs_info

        result = {
            "status": "success",
            "files": [{"file_path": f["file_path"]} for f in saved_files]
        }

        return ToolResult.success(result)
    def _extract_code_blocks_with_context(self, content: str) -> list:
        """
        Extract code blocks from content, including context lines before the block.

        :param content: The content to extract code blocks from
        :return: List of code blocks with context
        """
        # Check if content starts with <!DOCTYPE or <html - likely a full HTML file
        if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
            return [content]  # Return the entire content as a single block

        # Split content into lines
        lines = content.split('\n')

        blocks = []
        in_code_block = False
        current_block = []
        context_lines = []

        # Check if there are any code block markers in the content
        if not re.search(r'```\w+', content):
            # If no code block markers and content looks like code, return the entire content
            if self._is_likely_code(content):
                return [content]

        for line in lines:
            if line.strip().startswith('```'):
                if in_code_block:
                    # End of code block
                    current_block.append(line)
                    # Only add blocks that have a language specified
                    block_content = '\n'.join(current_block)
                    if re.search(r'```\w+', current_block[0]):
                        # Combine context with code block
                        blocks.append('\n'.join(context_lines + current_block))
                    current_block = []
                    context_lines = []
                    in_code_block = False
                else:
                    # Start of code block - check if it has a language specified
                    if re.search(r'```\w+', line) and not re.search(r'```language=\s*$', line):
                        # Start of code block with language
                        in_code_block = True
                        current_block = [line]
                        # Keep only the last few context lines
                        context_lines = context_lines[-5:] if context_lines else []

            elif in_code_block:
                current_block.append(line)
            else:
                # Store context lines when not in a code block
                context_lines.append(line)

        return blocks
    def _get_filename_for_code_block(self, block_with_context: str) -> Tuple[str, str]:
        """
        Determine the file name for a code block.

        :param block_with_context: The code block with context lines
        :return: Tuple of (file_name, file_type)
        """
        # Define common code file extensions
        COMMON_CODE_EXTENSIONS = {
            'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php',
            'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml',
            'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua'
        }

        # Split the block into lines to examine only the context around code block markers
        lines = block_with_context.split('\n')

        # Find the code block start marker line index
        start_marker_idx = -1
        for i, line in enumerate(lines):
            if line.strip().startswith('```') and not line.strip() == '```':
                start_marker_idx = i
                break

        if start_marker_idx == -1:
            # No code block marker found
            return "", ""

        # Extract the language from the code block marker
        code_marker = lines[start_marker_idx].strip()
        language = ""
        if len(code_marker) > 3:
            language = code_marker[3:].strip().split('=')[0].strip()

        # Define the context range (5 lines before and 2 after the marker)
        context_start = max(0, start_marker_idx - 5)
        context_end = min(len(lines), start_marker_idx + 3)

        # Extract only the relevant context lines
        context_lines = lines[context_start:context_end]

        # First, check for explicit file headers like "## filename.ext"
        for line in context_lines:
            # Match patterns like "## filename.ext" or "# filename.ext"
            header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line)
            if header_match:
                file_name = header_match.group(1)
                file_type = os.path.splitext(file_name)[1].lstrip('.')
                if file_type in COMMON_CODE_EXTENSIONS:
                    return os.path.splitext(file_name)[0], file_type

        # Simple patterns to match explicit file names in the context
        file_patterns = [
            # Match explicit file names in headers or text
            r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?',
            # Match language=filename.ext in code markers
            r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)',
            # Match standalone filenames with extensions
            r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b',
            # Match file paths in comments
            r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)'
        ]

        # Check each context line for file name patterns
        for line in context_lines:
            line = line.strip()
            for pattern in file_patterns:
                matches = re.findall(pattern, line)
                if matches:
                    for match in matches:
                        if isinstance(match, tuple):
                            # If the match is a tuple (filename, extension)
                            file_name = match[0]
                            file_type = match[1]
                            # Verify it's not a code reference like Direction.DOWN
                            if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
                                return os.path.splitext(file_name)[0], file_type
                        else:
                            # If the match is a string (full filename)
                            file_name = match
                            file_type = os.path.splitext(file_name)[1].lstrip('.')
                            # Verify it's not a code reference
                            if file_type in COMMON_CODE_EXTENSIONS and not any(
                                    keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
                                return os.path.splitext(file_name)[0], file_type

        # If no explicit file name found, use LLM to infer from code content
        # Extract the code content
        code_content = block_with_context

        # Get the first 20 lines of code for LLM analysis
        code_lines = code_content.split('\n')
        code_preview = '\n'.join(code_lines[:20])

        # Get the model to use
        model_to_use = None
        if hasattr(self, 'context') and self.context:
            if hasattr(self.context, 'model') and self.context.model:
                model_to_use = self.context.model
            elif hasattr(self.context, 'team_context') and self.context.team_context:
                if hasattr(self.context.team_context, 'model') and self.context.team_context.model:
                    model_to_use = self.context.team_context.model

        # If no model is available in context, use the tool's model
        if not model_to_use and hasattr(self, 'model') and self.model:
            model_to_use = self.model

        if model_to_use:
            # Prepare a prompt for the model
            prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension.
The file name should be descriptive but concise, using snake_case (lowercase with underscores).
The file type should be a standard file extension (e.g., py, js, html, css, java).

Code preview (first 20 lines):
{code_preview}

Return your answer in JSON format with these fields:
- file_name: The suggested file name (without extension)
- file_type: The suggested file extension

JSON response:"""

            # Create a request to the model
            request = LLMRequest(
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                json_format=True
            )

            try:
                response = model_to_use.call(request)

                if not response.is_error:
                    # Clean the JSON response
                    json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"])
                    result = json.loads(json_content)

                    file_name = result.get("file_name", "")
                    file_type = result.get("file_type", "")

                    if file_name and file_type:
                        return file_name, file_type
            except Exception as e:
                logger.error(f"Error using model to determine file name: {str(e)}")

        # If we still don't have a file name, use the language as file type
        if language and language in COMMON_CODE_EXTENSIONS:
            timestamp = int(time.time())
            return f"code_{timestamp}", language

        # If all else fails, return empty strings
        return "", ""
    def _clean_json_response(self, text: str) -> str:
        """
        Clean JSON response from LLM by removing markdown code block markers.

        :param text: The text containing JSON possibly wrapped in markdown code blocks
        :return: Clean JSON string
        """
        # Remove markdown code block markers if present
        if text.startswith("```json"):
            text = text[7:]
        elif text.startswith("```"):
            # Find the first newline to skip the language identifier line
            first_newline = text.find('\n')
            if first_newline != -1:
                text = text[first_newline + 1:]

        if text.endswith("```"):
            text = text[:-3]

        return text.strip()

    def _clean_code_block(self, block_with_context: str) -> str:
        """
        Clean a code block by removing markdown code markers and context lines.

        :param block_with_context: Code block with context lines
        :return: Clean code ready for execution
        """
        # Check if this is a full HTML or XML document
        if block_with_context.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
            return block_with_context

        # Find the code block
        code_block_match = re.search(r'```(?:\w+)?(?:[:=][^\n]+)?\n([\s\S]*?)\n```', block_with_context)

        if code_block_match:
            return code_block_match.group(1)

        # If no match found, try to extract anything between ``` markers
        lines = block_with_context.split('\n')
        start_idx = None
        end_idx = None

        for i, line in enumerate(lines):
            if line.strip().startswith('```'):
                if start_idx is None:
                    start_idx = i
                else:
                    end_idx = i
                    break

        if start_idx is not None and end_idx is not None:
            # Extract the code between the markers, excluding the markers themselves
            code_lines = lines[start_idx + 1:end_idx]
            return '\n'.join(code_lines)

        # If all else fails, return the original content
        return block_with_context
    def _get_file_params_from_model(self, content, model=None):
        """
        Use LLM to determine if the content is code and suggest appropriate file parameters.

        :param content: The content to analyze
        :param model: Optional model to use for the analysis
        :return: Tuple of (file_name, file_type, extract_code) for backward compatibility
        """
        if model is None:
            # getattr guards against instances that never had a model injected
            model = getattr(self, 'model', None)

        if not model:
            # Default fallback if no model is available
            return "output", "txt", False

        prompt = f"""
Analyze the following content and determine:
1. Is this primarily code implementation (where most of the content consists of code blocks)?
2. What would be an appropriate filename and file extension?

Content to analyze: ```
{content[:500]} # Only show first 500 chars to avoid token limits ```

{"..." if len(content) > 500 else ""}

Respond in JSON format only with the following structure:
{{
    "is_code": true/false,  # Whether this is primarily code implementation
    "filename": "suggested_filename",  # Don't include extension, english words
    "extension": "appropriate_extension"  # Don't include the dot, e.g., "md", "py", "js"
}}
"""

        try:
            # Create a request to the model
            request = LLMRequest(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                json_format=True
            )

            # Call the model using the standard interface
            response = model.call(request)

            if response.is_error:
                logger.warning(f"Error from model: {response.error_message}")
                raise Exception(f"Model error: {response.error_message}")

            # Extract JSON from response
            result = response.data["choices"][0]["message"]["content"]

            # Clean the JSON response
            result = self._clean_json_response(result)

            # Parse the JSON
            params = json.loads(result)

            # For backward compatibility, return tuple format
            file_name = params.get("filename", "output")
            # Remove dot from extension if present
            file_type = params.get("extension", "md").lstrip(".")
            extract_code = params.get("is_code", False)

            return file_name, file_type, extract_code
        except Exception as e:
            logger.warning(f"Error getting file parameters from model: {e}")
            # Default fallback
            return "output", "md", False

    def _get_team_name_from_context(self) -> Optional[str]:
        """
        Get team name from the agent's context.

        :return: Team name or None if not found
        """
        if hasattr(self, 'context') and self.context:
            # Try to get team name from team_context
            if hasattr(self.context, 'team_context') and self.context.team_context:
                return self.context.team_context.name

            # Try direct team_name attribute
            if hasattr(self.context, 'name'):
                return self.context.name

        return None
    def _get_task_id_from_context(self) -> Optional[str]:
        """
        Get task ID from the agent's context.

        :return: Task ID or None if not found
        """
        if hasattr(self, 'context') and self.context:
            # Try to get task ID from task object
            if hasattr(self.context, 'task') and self.context.task:
                return self.context.task.id

            # Try team_context's task
            if hasattr(self.context, 'team_context') and self.context.team_context:
                if hasattr(self.context.team_context, 'task') and self.context.team_context.task:
                    return self.context.team_context.task.id

        return None

    def _get_task_dir_from_context(self) -> Optional[str]:
        """
        Get task directory name from the team context.

        :return: Task directory name or None if not found
        """
        if hasattr(self, 'context') and self.context:
            # Try to get from team_context
            if hasattr(self.context, 'team_context') and self.context.team_context:
                if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name:
                    return self.context.team_context.task_short_name

        # Fall back to task ID if available
        return self._get_task_id_from_context()

    def _extract_content_from_context(self) -> str:
        """
        Extract content from the agent's context.

        :return: Extracted content
        """
        # Check if we have access to the agent's context
        if not hasattr(self, 'context') or not self.context:
            return ""

        # Try to get the most recent final answer from the agent
        if hasattr(self.context, 'final_answer') and self.context.final_answer:
            return self.context.final_answer

        # Try to get the most recent final answer from team context
        if hasattr(self.context, 'team_context') and self.context.team_context:
            if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs:
                latest_output = self.context.team_context.agent_outputs[-1].output
                return latest_output

        # If we have action history, try to get the most recent final answer
        if hasattr(self.context, 'action_history') and self.context.action_history:
            for action in reversed(self.context.action_history):
                if "final_answer" in action and action["final_answer"]:
                    return action["final_answer"]

        return ""
    def _extract_code_blocks(self, content: str) -> str:
        """
        Extract code blocks from markdown content.

        :param content: The content to extract code blocks from
        :return: Extracted code blocks
        """
        # Pattern to match markdown code blocks
        code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```'

        # Find all code blocks
        code_blocks = re.findall(code_block_pattern, content)

        if code_blocks:
            # Join all code blocks with newlines
            return '\n\n'.join(code_blocks)

        return content  # Return original content if no code blocks found

    def _infer_file_name(self, content: str) -> str:
        """
        Infer a file name from the content.

        :param content: The content to analyze.
        :return: A suggested file name.
        """
        # Check for title patterns in markdown
        title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
        if title_match:
            # Convert title to a valid filename
            title = title_match.group(1).strip()
            return self._sanitize_filename(title)

        # Check for class/function definitions in code
        code_match = re.search(r'(class|def|function)\s+(\w+)', content)
        if code_match:
            return self._sanitize_filename(code_match.group(2))

        # Default name based on content type
        if self._is_likely_code(content):
            return "code"
        elif self._is_likely_markdown(content):
            return "document"
        elif self._is_likely_json(content):
            return "data"
        else:
            return "output"

    def _infer_file_type(self, content: str) -> str:
        """
        Infer the file type/extension from the content.

        :param content: The content to analyze.
        :return: A suggested file extension.
        """
        # Check for common programming language patterns
        if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content):
            return "py"  # Python
        elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content):
            return "java"  # Java
        elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content):
            return "js"  # JavaScript
        elif re.search(r'(<html|<body|<div|<p>)', content):
            return "html"  # HTML
        elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content):
            return "cpp"  # C/C++

        # Check for markdown
        if self._is_likely_markdown(content):
            return "md"

        # Check for JSON
        if self._is_likely_json(content):
            return "json"

        # Default to text
        return "txt"

    def _is_likely_code(self, content: str) -> bool:
        """Check if the content is likely code."""
        # First check for common HTML/XML patterns
        if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml", "<head", "<body")):
            return True

        code_patterns = [
            r'(class|def|function|import|from|public|private|protected|#include)',
            r'(\{\s*\n|\}\s*\n|\[\s*\n|\]\s*\n)',
            r'(if\s*\(|for\s*\(|while\s*\()',
            r'(<\w+>.*?</\w+>)',  # HTML/XML tags
            r'(var|let|const)\s+\w+\s*=',  # JavaScript variable declarations
            r'#\s*\w+',  # CSS ID selectors or Python comments
            r'\.\w+\s*\{',  # CSS class selectors
            r'@media|@import|@font-face'  # CSS at-rules
        ]
        return any(re.search(pattern, content) for pattern in code_patterns)

    def _is_likely_markdown(self, content: str) -> bool:
        """Check if the content is likely markdown."""
        md_patterns = [
            r'^#\s+.+$',  # Headers
            r'^\*\s+.+$',  # Unordered lists
            r'^\d+\.\s+.+$',  # Ordered lists
            r'\[.+\]\(.+\)',  # Links
            r'!\[.+\]\(.+\)'  # Images
        ]
        return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns)

    def _is_likely_json(self, content: str) -> bool:
        """Check if the content is likely JSON."""
        try:
            content = content.strip()
            if (content.startswith('{') and content.endswith('}')) or (
                    content.startswith('[') and content.endswith(']')):
                json.loads(content)
                return True
        except Exception:
            pass
        return False

    def _sanitize_filename(self, name: str) -> str:
        """
        Sanitize a string to be used as a filename.

        :param name: The string to sanitize.
        :return: A sanitized filename.
        """
        # Replace spaces with underscores
        name = name.replace(' ', '_')

        # Remove invalid characters
        name = re.sub(r'[^\w\-\.]', '', name)

        # Limit length
        if len(name) > 50:
            name = name[:50]

        return name.lower()

    def _process_file_path(self, file_path: str) -> Tuple[str, str]:
        """
        Process a file path to extract the file name and type, and create directories if needed.

        :param file_path: The file path to process
        :return: Tuple of (file_name, file_type)
        """
        # Get the file name and extension
        file_name = os.path.basename(file_path)
        file_type = os.path.splitext(file_name)[1].lstrip('.')

        return os.path.splitext(file_name)[0], file_type
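To make the extraction flow concrete, a hedged sketch of the round trip through _extract_code_blocks_with_context, _get_filename_for_code_block, and _clean_code_block; the sample answer text is invented for illustration.

# Hedged sketch: from a markdown agent answer to a saved source file.
answer = (
    "Here is the script.\n"
    "## main.py\n"
    "```python\n"
    "print('hello')\n"
    "```\n"
)
saver = FileSave()
blocks = saver._extract_code_blocks_with_context(answer)
# blocks[0] keeps the "## main.py" context line, so
# _get_filename_for_code_block can recover ("main", "py") without an LLM call.
code = saver._clean_code_block(blocks[0])   # -> "print('hello')"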
3
agent/tools/find/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .find import Find

__all__ = ['Find']
177
agent/tools/find/find.py
Normal file
@@ -0,0 +1,177 @@
"""
Find tool - Search for files by glob pattern
"""

import os
import glob as glob_module
from typing import Dict, Any, List

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES


DEFAULT_LIMIT = 1000


class Find(BaseTool):
    """Tool for finding files by pattern"""

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search

        :param args: Search parameters
        :return: Search results or error
        """
        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")

        try:
            # Load .gitignore patterns
            ignore_patterns = self._load_gitignore(absolute_path)

            # Search for files
            results = []
            search_pattern = os.path.join(absolute_path, pattern)

            # Use glob with recursive support
            for file_path in glob_module.glob(search_pattern, recursive=True):
                # Skip if matches ignore patterns
                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue

                # Get relative path
                relative_path = os.path.relpath(file_path, absolute_path)

                # Add trailing slash for directories
                if os.path.isdir(file_path):
                    relative_path += '/'

                results.append(relative_path)

                if len(results) >= limit:
                    break

            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})

            # Sort results
            results.sort()

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            result_limit_reached = len(results) >= limit
            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })

        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """Load .gitignore patterns from directory"""
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')

        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except Exception:
                pass

        # Add common ignore patterns
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])

        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """Check if file should be ignored based on patterns"""
        relative_path = os.path.relpath(file_path, base_path)

        for pattern in patterns:
            # Simple pattern matching
            if pattern in relative_path:
                return True

            # Check if it's a directory pattern
            if pattern.endswith('/'):
                if relative_path.startswith(pattern.rstrip('/')):
                    return True

        return False
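A quick usage sketch; the output field is the newline-joined list of relative paths built in execute() above.

# Hedged sketch: recursive search for Python test files.
finder = Find(config={"cwd": "/tmp/project"})   # hypothetical project root
res = finder.execute({"pattern": "**/test_*.py", "limit": 50})
# On success the payload holds "output" (one relative path per line),
# "file_count", and optional truncation "details".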
3
agent/tools/grep/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .grep import Grep

__all__ = ['Grep']
248
agent/tools/grep/grep.py
Normal file
@@ -0,0 +1,248 @@
"""
Grep tool - Search file contents for patterns
Uses ripgrep (rg) for fast searching
"""

import os
import re
import subprocess
import json
from typing import Dict, Any, List, Optional

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import (
    truncate_head, truncate_line, format_size,
    DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
)


DEFAULT_LIMIT = 100


class Grep(BaseTool):
    """Tool for searching file contents"""

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Find ripgrep executable"""
        try:
            result = subprocess.run(['which', 'rg'], capture_output=True, text=True)
            if result.returncode == 0:
                return result.stdout.strip()
        except Exception:
            pass
        return None

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]

        if ignore_case:
            cmd.append('--ignore-case')

        if literal:
            cmd.append('--fixed-strings')

        if glob:
            cmd.extend(['--glob', glob])

        cmd.extend([pattern, absolute_path])

        try:
            # Execute ripgrep
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            # Parse JSON output
            matches = []
            match_count = 0

            for line in result.stdout.splitlines():
                if not line.strip():
                    continue

                try:
                    event = json.loads(line)
                    if event.get('type') == 'match':
                        data = event.get('data', {})
                        file_path = data.get('path', {}).get('text')
                        line_number = data.get('line_number')

                        if file_path and line_number:
                            matches.append({
                                'file': file_path,
                                'line': line_number
                            })
                            match_count += 1

                            if match_count >= limit:
                                break
                except json.JSONDecodeError:
                    continue

            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})

            # Format output with context
            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)

            for match in matches:
                file_path = match['file']
                line_number = match['line']

                # Format file path
                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)

                # Read file and get context
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        file_lines = f.read().split('\n')

                    # Calculate context range
                    start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                    end = min(len(file_lines), line_number + context) if context > 0 else line_number

                    # Format lines with context
                    for i in range(start, end):
                        line_text = file_lines[i].replace('\r', '')

                        # Truncate long lines
                        truncated_text, was_truncated = truncate_line(line_text)
                        if was_truncated:
                            lines_truncated = True

                        # Format output
                        current_line = i + 1
                        if current_line == line_number:
                            output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                        else:
                            output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

                except Exception:
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
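A usage sketch; execute() parses ripgrep's --json event stream, so each match carries a file path and line number, with optional context lines.

# Hedged sketch: case-insensitive literal search in TypeScript files.
grepper = Grep(config={"cwd": "/tmp/project"})   # hypothetical project root
res = grepper.execute({
    "pattern": "TODO",
    "glob": "*.ts",
    "ignoreCase": True,
    "literal": True,
    "context": 1,
})
# Output lines look like "src/app.ts:42: // TODO: ..." with
# surrounding context marked "src/app.ts-41- ...".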
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .ls import Ls

__all__ = ['Ls']
125
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,125 @@
"""
Ls tool - List directory contents
"""

import os
from typing import Dict, Any

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES


DEFAULT_LIMIT = 500


class Ls(BaseTool):
    """Tool for listing directory contents"""

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing

        :param args: Listing parameters
        :return: Directory contents or error
        """
        path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)

        # Resolve path
        absolute_path = self._resolve_path(path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            # Read directory entries
            entries = os.listdir(absolute_path)

            # Sort alphabetically (case-insensitive)
            entries.sort(key=lambda x: x.lower())

            # Format entries with directory indicators
            results = []
            entry_limit_reached = False

            for entry in entries:
                if len(results) >= limit:
                    entry_limit_reached = True
                    break

                full_path = os.path.join(absolute_path, entry)

                try:
                    if os.path.isdir(full_path):
                        results.append(entry + '/')
                    else:
                        results.append(entry)
                except OSError:
                    # Skip entries we can't stat
                    continue

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
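A minimal usage sketch for the Ls tool; the argument names come from the schema above, and the result access pattern (`result.status`, `result.result`) follows the WebFetch README later in this change. The `/tmp` cwd is just an example value:

```python
from agent.tools.ls import Ls

tool = Ls(config={"cwd": "/tmp"})  # cwd value is illustrative
result = tool.execute({"path": ".", "limit": 20})

if result.status == "success":
    # "output" is the newline-joined listing; directories carry a '/' suffix
    print(result.result["output"])
```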
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""
Memory tools for Agent

Provides memory_search and memory_get tools
"""

from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool

__all__ = ['MemorySearchTool', 'MemoryGetTool']
112
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,112 @@
"""
Memory get tool

Allows agents to read specific sections from memory files
"""

from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool


class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents"""

    name: str = "memory_get"
    description: str = (
        "Read specific content from memory files. "
        "Use this to get full context from a memory file or specific line range."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
            },
            "start_line": {
                "type": "integer",
                "description": "Starting line number (optional, default: 1)",
                "default": 1
            },
            "num_lines": {
                "type": "integer",
                "description": "Number of lines to read (optional, reads all if not specified)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, memory_manager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager

    def execute(self, args: dict):
        """
        Execute memory file read

        Args:
            args: Dictionary with path, start_line, num_lines

        Returns:
            ToolResult with file content
        """
        from agent.tools.base_tool import ToolResult

        path = args.get("path")
        start_line = args.get("start_line", 1)
        num_lines = args.get("num_lines")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        try:
            workspace_dir = self.memory_manager.config.get_workspace()

            # Auto-prepend memory/ if not present and not an absolute path
            if not path.startswith('memory/') and not path.startswith('/'):
                path = f'memory/{path}'

            file_path = workspace_dir / path

            if not file_path.exists():
                return ToolResult.fail(f"Error: File not found: {path}")

            content = file_path.read_text()
            lines = content.split('\n')

            # Handle line range
            if start_line < 1:
                start_line = 1

            start_idx = start_line - 1

            if num_lines:
                end_idx = start_idx + num_lines
                selected_lines = lines[start_idx:end_idx]
            else:
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            total_lines = len(lines)
            shown_lines = len(selected_lines)

            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]

            return ToolResult.success('\n'.join(output))

        except Exception as e:
            return ToolResult.fail(f"Error reading memory file: {str(e)}")
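A sketch of MemoryGetTool in isolation. Since `execute` only needs `memory_manager.config.get_workspace()`, a stub manager pointing at a temporary workspace is enough to run it standalone; real deployments pass a configured MemoryManager instead:

```python
import tempfile
from pathlib import Path

from agent.tools.memory import MemoryGetTool

# Build a throwaway workspace with one memory file
workspace = Path(tempfile.mkdtemp())
(workspace / "memory").mkdir()
(workspace / "memory" / "MEMORY.md").write_text("# Notes\nRotate the API key monthly.\n")

class StubConfig:
    def get_workspace(self) -> Path:
        return workspace

class StubManager:
    config = StubConfig()

tool = MemoryGetTool(StubManager())
# 'MEMORY.md' is auto-prefixed to 'memory/MEMORY.md' by execute()
result = tool.execute({"path": "MEMORY.md", "start_line": 1, "num_lines": 2})
print(result.result)  # "File: ..." header plus the selected lines
```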
102
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,102 @@
"""
Memory search tool

Allows agents to search their memory using semantic and keyword search
"""

from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool


class MemorySearchTool(BaseTool):
    """Tool for searching agent memory"""

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.1)",
                "default": 0.1
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Initialize memory search tool

        Args:
            memory_manager: MemoryManager instance
            user_id: Optional user ID for scoped search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id

    def execute(self, args: dict):
        """
        Execute memory search

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult with formatted search results
        """
        from agent.tools.base_tool import ToolResult
        import asyncio

        query = args.get("query")
        max_results = args.get("max_results", 10)
        min_score = args.get("min_score", 0.1)

        if not query:
            return ToolResult.fail("Error: query parameter is required")

        try:
            # Run async search in sync context
            results = asyncio.run(self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            ))

            if not results:
                # Return a clear message that no memories exist yet;
                # this prevents infinite retry loops
                return ToolResult.success(
                    f"No memories found for '{query}'. "
                    f"This is normal if no memories have been stored yet. "
                    f"You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
                )

            # Format results
            output = [f"Found {len(results)} relevant memories:\n"]

            for i, result in enumerate(results, 1):
                output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
                output.append(f"   Score: {result.score:.3f}")
                output.append(f"   Snippet: {result.snippet}")

            return ToolResult.success("\n".join(output))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")
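The search tool wraps the manager's async `search` with `asyncio.run`. A sketch of that sync-over-async path, using a stub manager so the example runs standalone:

```python
from agent.tools.memory import MemorySearchTool

class StubManager:
    async def search(self, query, user_id=None, max_results=10, min_score=0.1, include_shared=True):
        return []  # pretend nothing has been stored yet

tool = MemorySearchTool(StubManager(), user_id="user_123")
result = tool.execute({"query": "project deadlines"})
# Empty results come back as a success with an explanatory message,
# which is what prevents agents from retrying in a loop
print(result.result)
```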
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .read import Read

__all__ = ['Read']
336
agent/tools/read/read.py
Normal file
@@ -0,0 +1,336 @@
"""
Read tool - Read file contents
Supports text files, images (jpg, png, gif, webp), and PDF files
"""

import os
from typing import Dict, Any
from pathlib import Path

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES


class Read(BaseTool):
    """Tool for reading file contents"""

    name: str = "read"
    description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to read (relative or absolute)"
            },
            "offset": {
                "type": "integer",
                "description": "Line number to start reading from (1-indexed, optional)"
            },
            "limit": {
                "type": "integer",
                "description": "Maximum number of lines to read (optional)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Supported image formats
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
        # Supported PDF format
        self.pdf_extensions = {'.pdf'}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file read operation

        :param args: Contains file path and optional offset/limit parameters
        :return: File content or error message
        """
        path = args.get("path", "").strip()
        offset = args.get("offset")
        limit = args.get("limit")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable
        if not os.access(absolute_path, os.R_OK):
            return ToolResult.fail(f"Error: File is not readable: {path}")

        # Check file type
        file_ext = Path(absolute_path).suffix.lower()

        # Check if image
        if file_ext in self.image_extensions:
            return self._read_image(absolute_path, file_ext)

        # Check if PDF
        if file_ext in self.pdf_extensions:
            return self._read_pdf(absolute_path, path, offset, limit)

        # Read text file
        return self._read_text(absolute_path, path, offset, limit)

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
        """
        Read image file

        :param absolute_path: Absolute path to the image file
        :param file_ext: File extension
        :return: Result containing image information
        """
        try:
            # Read image file
            with open(absolute_path, 'rb') as f:
                image_data = f.read()

            # Get file size
            file_size = len(image_data)

            # Base64-encode the image data for transport
            import base64
            base64_data = base64.b64encode(image_data).decode('utf-8')

            # Determine MIME type
            mime_type_map = {
                '.jpg': 'image/jpeg',
                '.jpeg': 'image/jpeg',
                '.png': 'image/png',
                '.gif': 'image/gif',
                '.webp': 'image/webp'
            }
            mime_type = mime_type_map.get(file_ext, 'image/jpeg')

            result = {
                "type": "image",
                "mime_type": mime_type,
                "size": file_size,
                "size_formatted": format_size(file_size),
                "data": base64_data  # Base64 encoded image data
            }

            return ToolResult.success(result)

        except Exception as e:
            return ToolResult.fail(f"Error reading image file: {str(e)}")

    def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
        """
        Read text file

        :param absolute_path: Absolute path to the file
        :param display_path: Path to display
        :param offset: Starting line number (1-indexed)
        :param limit: Maximum number of lines to read
        :return: File content or error message
        """
        try:
            # Read file
            with open(absolute_path, 'r', encoding='utf-8') as f:
                content = f.read()

            all_lines = content.split('\n')
            total_file_lines = len(all_lines)

            # Apply offset (if specified)
            start_line = 0
            if offset is not None:
                start_line = max(0, offset - 1)  # Convert to 0-indexed
                if start_line >= total_file_lines:
                    return ToolResult.fail(
                        f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                    )

            start_line_display = start_line + 1  # For display (1-indexed)

            # If the user specified a limit, apply it
            selected_content = content
            user_limited_lines = None
            if limit is not None:
                end_line = min(start_line + limit, total_file_lines)
                selected_content = '\n'.join(all_lines[start_line:end_line])
                user_limited_lines = end_line - start_line
            elif offset is not None:
                selected_content = '\n'.join(all_lines[start_line:])

            # Apply truncation (considering line count and byte limits)
            truncation = truncate_head(selected_content)

            output_text = ""
            details = {}

            if truncation.first_line_exceeds_limit:
                # First line alone exceeds the byte limit
                first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
                output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
                details["truncation"] = truncation.to_dict()
            elif truncation.truncated:
                # Truncation occurred
                end_line_display = start_line_display + truncation.output_lines - 1
                next_offset = end_line_display + 1

                output_text = truncation.content

                if truncation.truncated_by == "lines":
                    output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
                else:
                    output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

                details["truncation"] = truncation.to_dict()
            elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
                # User specified a limit, more content available, but no truncation
                remaining = total_file_lines - (start_line + user_limited_lines)
                next_offset = start_line + user_limited_lines + 1

                output_text = truncation.content
                output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
            else:
                # No truncation, and the user limit was not exceeded
                output_text = truncation.content

            result = {
                "content": output_text,
                "total_lines": total_file_lines,
                "start_line": start_line_display,
                "output_lines": truncation.output_lines
            }

            if details:
                result["details"] = details

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
        except Exception as e:
            return ToolResult.fail(f"Error reading file: {str(e)}")

    def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
        """
        Read PDF file content

        :param absolute_path: Absolute path to the file
        :param display_path: Path to display
        :param offset: Starting line number (1-indexed)
        :param limit: Maximum number of lines to read
        :return: PDF text content or error message
        """
        try:
            # Try to import pypdf
            try:
                from pypdf import PdfReader
            except ImportError:
                return ToolResult.fail(
                    "Error: pypdf library not installed. Install with: pip install pypdf"
                )

            # Read PDF
            reader = PdfReader(absolute_path)
            total_pages = len(reader.pages)

            # Extract text from all pages
            text_parts = []
            for page_num, page in enumerate(reader.pages, 1):
                page_text = page.extract_text()
                if page_text.strip():
                    text_parts.append(f"--- Page {page_num} ---\n{page_text}")

            if not text_parts:
                return ToolResult.success({
                    "content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
                    "total_pages": total_pages,
                    "message": "PDF may contain only images or be encrypted"
                })

            # Merge all text
            full_content = "\n\n".join(text_parts)
            all_lines = full_content.split('\n')
            total_lines = len(all_lines)

            # Apply offset and limit (same logic as text files)
            start_line = 0
            if offset is not None:
                start_line = max(0, offset - 1)
                if start_line >= total_lines:
                    return ToolResult.fail(
                        f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
                    )

            start_line_display = start_line + 1

            selected_content = full_content
            user_limited_lines = None
            if limit is not None:
                end_line = min(start_line + limit, total_lines)
                selected_content = '\n'.join(all_lines[start_line:end_line])
                user_limited_lines = end_line - start_line
            elif offset is not None:
                selected_content = '\n'.join(all_lines[start_line:])

            # Apply truncation
            truncation = truncate_head(selected_content)

            output_text = ""
            details = {}

            if truncation.truncated:
                end_line_display = start_line_display + truncation.output_lines - 1
                next_offset = end_line_display + 1

                output_text = truncation.content

                if truncation.truncated_by == "lines":
                    output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
                else:
                    output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

                details["truncation"] = truncation.to_dict()
            elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
                remaining = total_lines - (start_line + user_limited_lines)
                next_offset = start_line + user_limited_lines + 1

                output_text = truncation.content
                output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
            else:
                output_text = truncation.content

            result = {
                "content": output_text,
                "total_pages": total_pages,
                "total_lines": total_lines,
                "start_line": start_line_display,
                "output_lines": truncation.output_lines
            }

            if details:
                result["details"] = details

            return ToolResult.success(result)

        except Exception as e:
            return ToolResult.fail(f"Error reading PDF file: {str(e)}")
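A sketch of paging through a large text file with the Read tool; `app.log` is a placeholder path. The continuation offset is also embedded in the returned content, as the truncation branches above show:

```python
from agent.tools.read import Read

tool = Read()
first = tool.execute({"path": "app.log", "limit": 100})                   # lines 1-100
second = tool.execute({"path": "app.log", "offset": 101, "limit": 100})   # lines 101-200

if second.status == "success":
    print(second.result["start_line"], second.result["output_lines"])
    print(second.result["content"])
```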
248
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,248 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Dict, Any, Type
from agent.tools.base_tool import BaseTool
from common.log import logger
from config import conf


class ToolManager:
    """
    Tool manager for managing tools.
    """
    _instance = None

    def __new__(cls):
        """Singleton pattern to ensure only one instance of ToolManager exists."""
        if cls._instance is None:
            cls._instance = super(ToolManager, cls).__new__(cls)
            cls._instance.tool_classes = {}  # Store tool classes instead of instances
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Initialize only once
        if not hasattr(self, 'tool_classes'):
            self.tool_classes = {}  # Dictionary to store tool classes

    def load_tools(self, tools_dir: str = "", config_dict=None):
        """
        Load tools from both directory and configuration.

        :param tools_dir: Directory to scan for tool modules
        :param config_dict: Optional configuration dictionary used instead of conf()
        """
        if tools_dir:
            self._load_tools_from_directory(tools_dir)
            self._configure_tools_from_config()
        else:
            self._load_tools_from_init()
            self._configure_tools_from_config(config_dict)

    def _load_tools_from_init(self) -> bool:
        """
        Load tool classes from tools.__init__.__all__

        :return: True if tools were loaded, False otherwise
        """
        try:
            # Try to import the tools package
            tools_package = importlib.import_module("agent.tools")

            # Check if __all__ is defined
            if hasattr(tools_package, "__all__"):
                tool_classes = tools_package.__all__

                # Import each tool class directly from the tools package
                for class_name in tool_classes:
                    try:
                        # Skip base classes
                        if class_name in ["BaseTool", "ToolManager"]:
                            continue

                        # Get the class directly from the tools package
                        if hasattr(tools_package, class_name):
                            cls = getattr(tools_package, class_name)

                            if (
                                isinstance(cls, type)
                                and issubclass(cls, BaseTool)
                                and cls != BaseTool
                            ):
                                try:
                                    # Skip memory tools (they need special initialization with memory_manager)
                                    if class_name in ["MemorySearchTool", "MemoryGetTool"]:
                                        logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
                                        continue

                                    # Create a temporary instance to get the name
                                    temp_instance = cls()
                                    tool_name = temp_instance.name
                                    # Store the class, not the instance
                                    self.tool_classes[tool_name] = cls
                                    logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
                                except ImportError as e:
                                    # Handle missing dependencies with helpful messages
                                    error_msg = str(e)
                                    if "browser-use" in error_msg or "browser_use" in error_msg:
                                        logger.warning(
                                            f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
                                            f"  To enable browser tool, run:\n"
                                            f"    pip install browser-use markdownify playwright\n"
                                            f"    playwright install chromium"
                                        )
                                    elif "markdownify" in error_msg:
                                        logger.warning(
                                            f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
                                            f"  Install with: pip install markdownify"
                                        )
                                    else:
                                        logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
                                except Exception as e:
                                    logger.error(f"Error initializing tool class {cls.__name__}: {e}")
                    except Exception as e:
                        logger.error(f"Error importing class {class_name}: {e}")

                return len(self.tool_classes) > 0
            return False
        except ImportError:
            logger.warning("Could not import agent.tools package")
            return False
        except Exception as e:
            logger.error(f"Error loading tools from __init__.__all__: {e}")
            return False

    def _load_tools_from_directory(self, tools_dir: str):
        """Dynamically load tool classes from directory"""
        tools_path = Path(tools_dir)

        # Traverse all .py files
        for py_file in tools_path.rglob("*.py"):
            # Skip initialization files and base tool files
            if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
                continue

            # Get module name
            module_name = py_file.stem

            try:
                # Load module directly from file
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if spec and spec.loader:
                    module = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(module)

                    # Find tool classes in the module
                    for attr_name in dir(module):
                        cls = getattr(module, attr_name)
                        if (
                            isinstance(cls, type)
                            and issubclass(cls, BaseTool)
                            and cls != BaseTool
                        ):
                            try:
                                # Skip memory tools (they need special initialization with memory_manager)
                                if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
                                    logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
                                    continue

                                # Create a temporary instance to get the name
                                temp_instance = cls()
                                tool_name = temp_instance.name
                                # Store the class, not the instance
                                self.tool_classes[tool_name] = cls
                            except ImportError as e:
                                # Handle missing dependencies with helpful messages
                                error_msg = str(e)
                                if "browser-use" in error_msg or "browser_use" in error_msg:
                                    logger.warning(
                                        f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
                                        f"  To enable browser tool, run:\n"
                                        f"    pip install browser-use markdownify playwright\n"
                                        f"    playwright install chromium"
                                    )
                                elif "markdownify" in error_msg:
                                    logger.warning(
                                        f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
                                        f"  Install with: pip install markdownify"
                                    )
                                else:
                                    logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
                            except Exception as e:
                                logger.error(f"Error initializing tool class {cls.__name__}: {e}")
            except Exception as e:
                logger.error(f"Error importing module {py_file}: {e}")

    def _configure_tools_from_config(self, config_dict=None):
        """Configure tool classes based on configuration file"""
        try:
            # Get tools configuration
            tools_config = config_dict or conf().get("tools", {})

            # Record tools that are configured but not loaded
            missing_tools = []

            # Store configurations for later use when instantiating
            self.tool_configs = tools_config

            # Check which configured tools are missing
            for tool_name in tools_config:
                if tool_name not in self.tool_classes:
                    missing_tools.append(tool_name)

            # If there are missing tools, log warnings
            if missing_tools:
                for tool_name in missing_tools:
                    if tool_name == "browser":
                        logger.warning(
                            f"[ToolManager] Browser tool is configured but not loaded.\n"
                            f"  To enable browser tool, run:\n"
                            f"    pip install browser-use markdownify playwright\n"
                            f"    playwright install chromium"
                        )
                    elif tool_name == "google_search":
                        logger.warning(
                            f"[ToolManager] Google Search tool is configured but may need API key.\n"
                            f"  Get API key from: https://serper.dev\n"
                            f"  Configure in config.json: tools.google_search.api_key"
                        )
                    else:
                        logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")

        except Exception as e:
            logger.error(f"Error configuring tools from config: {e}")

    def create_tool(self, name: str) -> BaseTool:
        """
        Get a new instance of a tool by name.

        :param name: The name of the tool to get.
        :return: A new instance of the tool or None if not found.
        """
        tool_class = self.tool_classes.get(name)
        if tool_class:
            # Create a new instance
            tool_instance = tool_class()

            # Apply configuration if available
            if hasattr(self, 'tool_configs') and name in self.tool_configs:
                tool_instance.config = self.tool_configs[name]

            return tool_instance
        return None

    def list_tools(self) -> dict:
        """
        Get information about all loaded tools.

        :return: A dictionary with tool information.
        """
        result = {}
        for name, tool_class in self.tool_classes.items():
            # Create a temporary instance to get schema
            temp_instance = tool_class()
            result[name] = {
                "description": temp_instance.description,
                "parameters": temp_instance.get_json_schema()
            }
        return result
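A sketch of the manager's lifecycle: because `__new__` implements a singleton, every `ToolManager()` call returns the same instance, and per-tool config passed to `load_tools` is attached on each `create_tool` call:

```python
from agent.tools.tool_manager import ToolManager

manager = ToolManager()                  # shared singleton instance
manager.load_tools(config_dict={"web_fetch": {"timeout": 10}})

print(sorted(manager.list_tools().keys()))   # names of all loaded tools
fetch = manager.create_tool("web_fetch")     # fresh instance, config attached
assert ToolManager() is manager              # singleton check
```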
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
from .truncate import (
    truncate_head,
    truncate_tail,
    truncate_line,
    format_size,
    TruncationResult,
    DEFAULT_MAX_LINES,
    DEFAULT_MAX_BYTES,
    GREP_MAX_LINE_LENGTH
)

from .diff import (
    strip_bom,
    detect_line_ending,
    normalize_to_lf,
    restore_line_endings,
    normalize_for_fuzzy_match,
    fuzzy_find_text,
    generate_diff_string,
    FuzzyMatchResult
)

__all__ = [
    'truncate_head',
    'truncate_tail',
    'truncate_line',
    'format_size',
    'TruncationResult',
    'DEFAULT_MAX_LINES',
    'DEFAULT_MAX_BYTES',
    'GREP_MAX_LINE_LENGTH',
    'strip_bom',
    'detect_line_ending',
    'normalize_to_lf',
    'restore_line_endings',
    'normalize_for_fuzzy_match',
    'fuzzy_find_text',
    'generate_diff_string',
    'FuzzyMatchResult'
]
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
"""
Diff tools for file editing
Provides fuzzy matching and diff generation functionality
"""

import difflib
import re
from typing import Optional, Tuple


def strip_bom(text: str) -> Tuple[str, str]:
    """
    Remove BOM (Byte Order Mark)

    :param text: Original text
    :return: (BOM, text after removing BOM)
    """
    if text.startswith('\ufeff'):
        return '\ufeff', text[1:]
    return '', text


def detect_line_ending(text: str) -> str:
    """
    Detect line ending type

    :param text: Text content
    :return: Line ending type ('\\r\\n' or '\\n')
    """
    if '\r\n' in text:
        return '\r\n'
    return '\n'


def normalize_to_lf(text: str) -> str:
    """
    Normalize all line endings to LF ('\\n')

    :param text: Original text
    :return: Normalized text
    """
    return text.replace('\r\n', '\n').replace('\r', '\n')


def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Restore original line endings

    :param text: LF-normalized text
    :param original_ending: Original line ending
    :return: Text with restored line endings
    """
    if original_ending == '\r\n':
        return text.replace('\n', '\r\n')
    return text


def normalize_for_fuzzy_match(text: str) -> str:
    """
    Normalize text for fuzzy matching.
    Removes excess whitespace but preserves basic structure.

    :param text: Original text
    :return: Normalized text
    """
    # Compress runs of spaces/tabs to a single space
    text = re.sub(r'[ \t]+', ' ', text)
    # Remove trailing spaces
    text = re.sub(r' +\n', '\n', text)
    # Normalize indentation per line (preserve depth, after tab-to-space compression above)
    lines = text.split('\n')
    normalized_lines = []
    for line in lines:
        stripped = line.lstrip()
        if stripped:
            indent_count = len(line) - len(stripped)
            normalized_indent = ' ' * indent_count
            normalized_lines.append(normalized_indent + stripped)
        else:
            normalized_lines.append('')
    return '\n'.join(normalized_lines)


class FuzzyMatchResult:
    """Fuzzy match result"""

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        self.found = found
        self.index = index
        self.match_length = match_length
        self.content_for_replacement = content_for_replacement


def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Find text in content: try exact match first, then fuzzy match

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result
    """
    # First try exact match
    index = content.find(old_text)
    if index != -1:
        return FuzzyMatchResult(
            found=True,
            index=index,
            match_length=len(old_text),
            content_for_replacement=content
        )

    # Try fuzzy match
    fuzzy_content = normalize_for_fuzzy_match(content)
    fuzzy_old_text = normalize_for_fuzzy_match(old_text)

    index = fuzzy_content.find(fuzzy_old_text)
    if index != -1:
        # Fuzzy match succeeded; use the normalized content for replacement
        return FuzzyMatchResult(
            found=True,
            index=index,
            match_length=len(fuzzy_old_text),
            content_for_replacement=fuzzy_content
        )

    # Not found
    return FuzzyMatchResult(found=False)


def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Generate unified diff string

    :param old_content: Old content
    :param new_content: New content
    :return: Dictionary containing diff and first changed line number
    """
    old_lines = old_content.split('\n')
    new_lines = new_content.split('\n')

    # Generate unified diff
    diff_lines = list(difflib.unified_diff(
        old_lines,
        new_lines,
        lineterm='',
        fromfile='original',
        tofile='modified'
    ))

    # Find first changed line number
    first_changed_line = None
    for line in diff_lines:
        if line.startswith('@@'):
            # Parse '@@ -1,3 +1,3 @@' format
            match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
            if match:
                first_changed_line = int(match.group(1))
                break

    diff_string = '\n'.join(diff_lines)

    return {
        'diff': diff_string,
        'first_changed_line': first_changed_line
    }
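A sketch of the two diff helpers together: `fuzzy_find_text` tolerates whitespace drift via the normalization above, and `generate_diff_string` reports the start line of the first hunk:

```python
from agent.tools.utils.diff import fuzzy_find_text, generate_diff_string

content = "def add(a, b):\n    return  a + b\n"   # note the doubled space
match = fuzzy_find_text(content, "def add(a, b):\n    return a + b")
print(match.found)  # True - exact match fails, normalized match succeeds

result = generate_diff_string("line1\nline2\n", "line1\nline2 changed\n")
print(result["first_changed_line"])  # start line of the first diff hunk
print(result["diff"])
```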
292
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,292 @@
"""
Shared truncation utilities for tool outputs.

Truncation is based on two independent limits - whichever is hit first wins:
- Line limit (default: 2000 lines)
- Byte limit (default: 50KB)

Never returns partial lines (except the bash tail truncation edge case).
"""

from typing import Dict, Any, Optional, Literal


DEFAULT_MAX_LINES = 2000
DEFAULT_MAX_BYTES = 50 * 1024  # 50KB
GREP_MAX_LINE_LENGTH = 500  # Max chars per grep match line


class TruncationResult:
    """Truncation result"""

    def __init__(
        self,
        content: str,
        truncated: bool,
        truncated_by: Optional[Literal["lines", "bytes"]],
        total_lines: int,
        total_bytes: int,
        output_lines: int,
        output_bytes: int,
        last_line_partial: bool = False,
        first_line_exceeds_limit: bool = False,
        max_lines: int = DEFAULT_MAX_LINES,
        max_bytes: int = DEFAULT_MAX_BYTES
    ):
        self.content = content
        self.truncated = truncated
        self.truncated_by = truncated_by
        self.total_lines = total_lines
        self.total_bytes = total_bytes
        self.output_lines = output_lines
        self.output_bytes = output_bytes
        self.last_line_partial = last_line_partial
        self.first_line_exceeds_limit = first_line_exceeds_limit
        self.max_lines = max_lines
        self.max_bytes = max_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary"""
        return {
            "content": self.content,
            "truncated": self.truncated,
            "truncated_by": self.truncated_by,
            "total_lines": self.total_lines,
            "total_bytes": self.total_bytes,
            "output_lines": self.output_lines,
            "output_bytes": self.output_bytes,
            "last_line_partial": self.last_line_partial,
            "first_line_exceeds_limit": self.first_line_exceeds_limit,
            "max_lines": self.max_lines,
            "max_bytes": self.max_bytes
        }


def format_size(bytes_count: int) -> str:
    """Format bytes as human-readable size"""
    if bytes_count < 1024:
        return f"{bytes_count}B"
    elif bytes_count < 1024 * 1024:
        return f"{bytes_count / 1024:.1f}KB"
    else:
        return f"{bytes_count / (1024 * 1024):.1f}MB"


def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Truncate content from the head (keep first N lines/bytes).
    Suitable for file reads where you want to see the beginning.

    Never returns partial lines. If the first line exceeds the byte limit,
    returns empty content with first_line_exceeds_limit=True.

    :param content: Content to truncate
    :param max_lines: Maximum number of lines (default: 2000)
    :param max_bytes: Maximum number of bytes (default: 50KB)
    :return: Truncation result
    """
    if max_lines is None:
        max_lines = DEFAULT_MAX_LINES
    if max_bytes is None:
        max_bytes = DEFAULT_MAX_BYTES

    total_bytes = len(content.encode('utf-8'))
    lines = content.split('\n')
    total_lines = len(lines)

    # Check if no truncation is needed
    if total_lines <= max_lines and total_bytes <= max_bytes:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=total_lines,
            output_bytes=total_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=max_lines,
            max_bytes=max_bytes
        )

    # Check if the first line alone exceeds the byte limit
    first_line_bytes = len(lines[0].encode('utf-8'))
    if first_line_bytes > max_bytes:
        return TruncationResult(
            content="",
            truncated=True,
            truncated_by="bytes",
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=0,
            output_bytes=0,
            last_line_partial=False,
            first_line_exceeds_limit=True,
            max_lines=max_lines,
            max_bytes=max_bytes
        )

    # Collect complete lines that fit
    output_lines_arr = []
    output_bytes_count = 0
    truncated_by = "lines"

    for i, line in enumerate(lines):
        if i >= max_lines:
            break

        # Calculate line bytes (add 1 for newline if not first line)
        line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0)

        if output_bytes_count + line_bytes > max_bytes:
            truncated_by = "bytes"
            break

        output_lines_arr.append(line)
        output_bytes_count += line_bytes

    # If the loop exited due to the line limit
    if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
        truncated_by = "lines"

    output_content = '\n'.join(output_lines_arr)
    final_output_bytes = len(output_content.encode('utf-8'))

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=truncated_by,
        total_lines=total_lines,
        total_bytes=total_bytes,
        output_lines=len(output_lines_arr),
        output_bytes=final_output_bytes,
        last_line_partial=False,
        first_line_exceeds_limit=False,
        max_lines=max_lines,
        max_bytes=max_bytes
    )


def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Truncate content from the tail (keep last N lines/bytes).
    Suitable for bash output where you want to see the ending content (errors, final results).

    If the last line of the original content exceeds the byte limit, may return a partial first line.

    :param content: Content to truncate
    :param max_lines: Maximum lines (default: 2000)
    :param max_bytes: Maximum bytes (default: 50KB)
    :return: Truncation result
    """
    if max_lines is None:
        max_lines = DEFAULT_MAX_LINES
    if max_bytes is None:
        max_bytes = DEFAULT_MAX_BYTES

    total_bytes = len(content.encode('utf-8'))
    lines = content.split('\n')
    total_lines = len(lines)

    # Check if no truncation is needed
    if total_lines <= max_lines and total_bytes <= max_bytes:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=total_lines,
            output_bytes=total_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=max_lines,
            max_bytes=max_bytes
        )

    # Work backwards from the end
    output_lines_arr = []
    output_bytes_count = 0
    truncated_by = "lines"
    last_line_partial = False

    for i in range(len(lines) - 1, -1, -1):
        if len(output_lines_arr) >= max_lines:
            break

        line = lines[i]
        # Calculate line bytes (add newline if not the first added line)
        line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0)

        if output_bytes_count + line_bytes > max_bytes:
            truncated_by = "bytes"
            # Edge case: if we haven't added any lines yet and this line exceeds max_bytes,
            # take the end portion of this line
            if len(output_lines_arr) == 0:
                truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
                output_lines_arr.insert(0, truncated_line)
                output_bytes_count = len(truncated_line.encode('utf-8'))
                last_line_partial = True
            break

        output_lines_arr.insert(0, line)
        output_bytes_count += line_bytes

    # If the loop exited due to the line limit
    if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
        truncated_by = "lines"

    output_content = '\n'.join(output_lines_arr)
    final_output_bytes = len(output_content.encode('utf-8'))

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=truncated_by,
        total_lines=total_lines,
        total_bytes=total_bytes,
        output_lines=len(output_lines_arr),
        output_bytes=final_output_bytes,
        last_line_partial=last_line_partial,
        first_line_exceeds_limit=False,
        max_lines=max_lines,
        max_bytes=max_bytes
    )


def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
    """
    Truncate a string to fit a byte limit, keeping the end.
    Properly handles multi-byte UTF-8 characters.

    :param text: String to truncate
    :param max_bytes: Maximum bytes
    :return: Truncated string
    """
    encoded = text.encode('utf-8')
    if len(encoded) <= max_bytes:
        return text

    # Keep only the last max_bytes bytes
    start = len(encoded) - max_bytes

    # Advance to a valid UTF-8 boundary (skip continuation bytes)
    while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
        start += 1

    return encoded[start:].decode('utf-8', errors='ignore')


def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
    """
    Truncate a single line to max characters, adding a [truncated] suffix.
    Used for grep match lines.

    :param line: Line to truncate
    :param max_chars: Maximum characters
    :return: (truncated text, whether truncated)
    """
    if len(line) <= max_chars:
        return line, False
    return f"{line[:max_chars]}... [truncated]", True
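A sketch of the two truncation directions on synthetic input: head truncation keeps the beginning and reports which limit fired, tail truncation keeps the ending:

```python
from agent.tools.utils.truncate import truncate_head, truncate_tail

text = "\n".join(f"line {i}" for i in range(1, 1001))  # 1000 short lines

head = truncate_head(text, max_lines=100)
print(head.truncated, head.truncated_by, head.output_lines)  # True lines 100

tail = truncate_tail(text, max_lines=100)
print(tail.content.splitlines()[0])  # 'line 901' - the last 100 lines survive
```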
212
agent/tools/web_fetch/README.md
Normal file
@@ -0,0 +1,212 @@
# WebFetch Tool

A free web fetching tool: no API key required. Fetches web pages directly and extracts readable text.

## Features

- ✅ **Completely free** - no API key required
- 🌐 **Smart extraction** - automatically extracts the main page content
- 📝 **Format conversion** - supports HTML → Markdown/Text
- 🚀 **Robust fetching** - built-in request retry and timeout control
- 🎯 **Graceful degradation** - prefers Readability, falls back to basic extraction

## Installing Dependencies

### Basic functionality (required)
```bash
pip install requests
```

### Enhanced functionality (recommended)
```bash
# Install readability-lxml for better content extraction
pip install readability-lxml

# Install html2text for better Markdown conversion
pip install html2text
```

## Usage

### 1. In code

```python
from agent.tools.web_fetch import WebFetch

# Create a tool instance
tool = WebFetch()

# Fetch a page (returns Markdown by default)
result = tool.execute({
    "url": "https://example.com"
})

# Fetch and convert to plain text
result = tool.execute({
    "url": "https://example.com",
    "extract_mode": "text",
    "max_chars": 5000
})

if result.status == "success":
    data = result.result
    print(f"Title: {data['title']}")
    print(f"Content: {data['text']}")
```

### 2. In an Agent

The tool is loaded into the Agent's tool list automatically:

```python
from agent.tools import WebFetch

tools = [
    WebFetch(),
    # ... other tools
]

agent = create_agent(tools=tools)
```

### 3. Via Skills

Create a skill file `skills/web-fetch/SKILL.md`:

```markdown
---
name: web-fetch
emoji: 🌐
always: true
---

# Web Content Fetching

Use the web_fetch tool to fetch web page content.

## Use cases

- Reading the content of a given web page
- Extracting the main body of an article
- Gathering information from a page

## Example

<example>
User: Please take a look at https://example.com and tell me what it says
Assistant: <tool_use name="web_fetch">
<url>https://example.com</url>
<extract_mode>markdown</extract_mode>
</tool_use>
</example>
```

## Parameters

| Parameter | Type | Required | Default | Description |
|------|------|------|--------|------|
| `url` | string | ✅ | - | URL to fetch (http/https) |
| `extract_mode` | string | ❌ | `markdown` | Extraction mode: `markdown` or `text` |
| `max_chars` | integer | ❌ | `50000` | Maximum characters to return (minimum 100) |

## Result

```python
{
    "url": "https://example.com",      # Final URL (after redirects)
    "status": 200,                      # HTTP status code
    "content_type": "text/html",        # Content type
    "title": "Example Domain",          # Page title
    "extractor": "readability",         # Extractor: readability/basic/raw
    "extract_mode": "markdown",         # Extraction mode
    "text": "# Example Domain\n\n...",  # Extracted text content
    "length": 1234,                     # Text length
    "truncated": false,                 # Whether the text was truncated
    "warning": "..."                    # Warning message (if any)
}
```

## Comparison with Other Search Tools

| Tool | Needs API Key | Function | Cost |
|------|-------------|------|------|
| `web_fetch` | ❌ No | Fetch content from a given URL | Free |
| `web_search` (Brave) | ✅ Yes | Search engine queries | Free tier available |
| `web_search` (Perplexity) | ✅ Yes | AI search + citations | Paid |
| `browser` | ❌ No | Full browser automation | Free but resource-heavy |
| `google_search` | ✅ Yes | Google Search API | Paid |

## Technical Details

### Content extraction strategy

1. **Readability mode** (preferred)
   - Uses Mozilla's Readability algorithm
   - Automatically identifies the main article body
   - Filters out ads, navigation bars, and other noise

2. **Basic mode** (fallback)
   - Simple HTML tag cleanup
   - Regex-based text extraction
   - Suitable for simple pages

3. **Raw mode**
   - Used for non-HTML content
   - Returns the raw content directly

### Error handling

The tool automatically handles:
- ✅ HTTP redirects (up to 3)
- ✅ Request timeouts (default 30 seconds)
- ✅ Automatic retry on network errors
- ✅ Fallback when content extraction fails

## Testing

Run the test script:

```bash
cd agent/tools/web_fetch
python test_web_fetch.py
```

## Configuration Options

Configuration can be passed when creating the tool:

```python
tool = WebFetch(config={
    "timeout": 30,        # Request timeout (seconds)
    "max_redirects": 3,   # Maximum number of redirects
    "user_agent": "..."   # Custom User-Agent
})
```

## FAQ

### Q: Why is readability-lxml recommended?

A: readability-lxml provides better extraction quality. It can:
- Automatically identify the main article body
- Filter out ads and navigation bars
- Preserve article structure

The tool works without it, but extraction quality degrades.

### Q: How does this differ from clawdbot's web_fetch?

A: This implementation borrows clawdbot's design. Main differences:
- Python implementation (clawdbot is TypeScript)
- Some advanced features simplified (e.g., Firecrawl integration)
- Keeps the core free functionality
- Easier to integrate into the existing project

### Q: Can it fetch pages that require login?

A: Not in the current version. To fetch pages behind a login, use the `browser` tool.

## References

- [Mozilla Readability](https://github.com/mozilla/readability)
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
3
agent/tools/web_fetch/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .web_fetch import WebFetch

__all__ = ['WebFetch']
47
agent/tools/web_fetch/install_deps.sh
Normal file
@@ -0,0 +1,47 @@
#!/bin/bash

# Dependency installation script for the WebFetch tool

echo "=================================="
echo "WebFetch tool dependency install"
echo "=================================="
echo ""

# Check Python version
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python version: $python_version"
echo ""

# Install base dependencies
echo "📦 Installing base dependencies..."
python3 -m pip install requests

# Check for success
if [ $? -eq 0 ]; then
    echo "✅ requests installed successfully"
else
    echo "❌ requests installation failed"
    exit 1
fi

echo ""

# Install recommended dependencies
echo "📦 Installing recommended dependencies (better extraction quality)..."
python3 -m pip install readability-lxml html2text

# Check for success
if [ $? -eq 0 ]; then
    echo "✅ readability-lxml and html2text installed successfully"
else
    echo "⚠️  Recommended dependencies failed to install; basic functionality is unaffected"
fi

echo ""
echo "=================================="
echo "Installation complete!"
echo "=================================="
echo ""
echo "Run the tests:"
echo "  python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""
365
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from URLs

Supports HTML to Markdown/Text conversion using Mozilla's Readability
"""

import os
import re
from typing import Dict, Any, Optional
from urllib.parse import urlparse
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger


class WebFetch(BaseTool):
    """Tool for fetching and extracting readable content from web pages"""

    name: str = "web_fetch"
    description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."

    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP or HTTPS URL to fetch"
            },
            "extract_mode": {
                "type": "string",
                "description": "Extraction mode: 'markdown' (default) or 'text'",
                "enum": ["markdown", "text"],
                "default": "markdown"
            },
            "max_chars": {
                "type": "integer",
                "description": "Maximum characters to return (default: 50000)",
                "minimum": 100,
                "default": 50000
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.timeout = self.config.get("timeout", 30)
        self.max_redirects = self.config.get("max_redirects", 3)
        self.user_agent = self.config.get(
            "user_agent",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
        )

        # Setup session with retry strategy
        self.session = self._create_session()

        # Check if readability-lxml is available
        self.readability_available = self._check_readability()

    def _create_session(self) -> requests.Session:
        """Create a requests session with retry strategy"""
        session = requests.Session()

        # Retry strategy - handles failed requests, not redirects
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"]
        )

        # HTTPAdapter handles retries; requests handles redirects via allow_redirects
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Set max redirects on session
        session.max_redirects = self.max_redirects

        return session

    def _check_readability(self) -> bool:
        """Check if readability-lxml is available"""
        try:
            from readability import Document
            return True
        except ImportError:
            logger.warning(
                "readability-lxml not installed. Install with: pip install readability-lxml\n"
                "Falling back to basic HTML extraction."
            )
            return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute web fetch operation

        :param args: Contains url, extract_mode, and max_chars parameters
        :return: Extracted content or error message
        """
        url = args.get("url", "").strip()
        extract_mode = args.get("extract_mode", "markdown").lower()
        max_chars = args.get("max_chars", 50000)

        if not url:
            return ToolResult.fail("Error: url parameter is required")

        # Validate URL
        if not self._is_valid_url(url):
            return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")

        # Validate extract_mode
        if extract_mode not in ["markdown", "text"]:
            extract_mode = "markdown"

        # Validate max_chars
        if not isinstance(max_chars, int) or max_chars < 100:
            max_chars = 50000

        try:
            # Fetch the URL
            response = self._fetch_url(url)

            # Extract content
            result = self._extract_content(
                html=response.text,
                url=response.url,
                status_code=response.status_code,
                content_type=response.headers.get("content-type", ""),
                extract_mode=extract_mode,
                max_chars=max_chars
            )

            return ToolResult.success(result)

        except requests.exceptions.Timeout:
            return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
        except requests.exceptions.TooManyRedirects:
            return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
        except requests.exceptions.RequestException as e:
            return ToolResult.fail(f"Error fetching URL: {str(e)}")
        except Exception as e:
            logger.error(f"Web fetch error: {e}", exc_info=True)
            return ToolResult.fail(f"Error: {str(e)}")

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format"""
        try:
            result = urlparse(url)
            return result.scheme in ["http", "https"] and bool(result.netloc)
        except Exception:
            return False

    def _fetch_url(self, url: str) -> requests.Response:
        """
        Fetch URL with proper headers and error handling

        :param url: URL to fetch
        :return: Response object
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }

        # Note: the requests library follows redirects automatically (allow_redirects=True);
        # the redirect limit comes from session.max_redirects set in _create_session()
        response = self.session.get(
            url,
            headers=headers,
            timeout=self.timeout,
            allow_redirects=True
        )

        response.raise_for_status()
        return response

    def _extract_content(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """
        Extract readable content from HTML

        :param html: HTML content
        :param url: Original URL
        :param status_code: HTTP status code
        :param content_type: Content type header
        :param extract_mode: 'markdown' or 'text'
        :param max_chars: Maximum characters to return
        :return: Extracted content and metadata
        """
        # Check content type
        if "text/html" not in content_type.lower():
            # Non-HTML content
            text = html[:max_chars]
            truncated = len(html) > max_chars

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "extractor": "raw",
                "text": text,
                "length": len(text),
                "truncated": truncated,
                "message": f"Non-HTML content (type: {content_type})"
            }

        # Extract readable content from HTML
        if self.readability_available:
            return self._extract_with_readability(
                html, url, status_code, content_type, extract_mode, max_chars
            )
        else:
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_with_readability(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Extract content using Mozilla's Readability"""
        try:
            from readability import Document

            # Parse with Readability
            doc = Document(html)
            title = doc.title()
            content_html = doc.summary()

            # Convert to markdown or text
            if extract_mode == "markdown":
                text = self._html_to_markdown(content_html)
            else:
                text = self._html_to_text(content_html)

            # Truncate if needed
            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "title": title,
                "extractor": "readability",
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated
            }

        except Exception as e:
            logger.warning(f"Readability extraction failed: {e}")
            # Fallback to basic extraction
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_basic(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Basic HTML extraction without Readability"""
        # Extract title
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        title = title_match.group(1).strip() if title_match else "Untitled"

        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Clean up whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        # Truncate if needed
        truncated = len(text) > max_chars
        if truncated:
            text = text[:max_chars]

        return {
            "url": url,
            "status": status_code,
            "content_type": content_type,
            "title": title,
            "extractor": "basic",
            "extract_mode": extract_mode,
            "text": text,
            "length": len(text),
            "truncated": truncated,
            "warning": "Using basic extraction. Install readability-lxml for better results."
        }

    def _html_to_markdown(self, html: str) -> str:
        """Convert HTML to Markdown (basic implementation)"""
        try:
            # Try to use html2text if available
            import html2text
            h = html2text.HTML2Text()
            h.ignore_links = False
            h.ignore_images = False
            h.body_width = 0  # Don't wrap lines
            return h.handle(html)
        except ImportError:
            # Fallback to basic conversion
            return self._html_to_text(html)

    def _html_to_text(self, html: str) -> str:
        """Convert HTML to plain text"""
        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Convert common tags to text equivalents
        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
        text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)

        # Remove all other HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Decode HTML entities
        import html
        text = html.unescape(text)

        # Clean up whitespace
        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        text = text.strip()

        return text

    def close(self):
        """Close the session"""
        if hasattr(self, 'session'):
            self.session.close()
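A minimal usage sketch for the tool above; the URL and config values are illustrative, and only the WebFetch class and the ToolResult interface shown in this diff are assumed:

# Hypothetical example, not part of the commit
fetcher = WebFetch(config={"timeout": 10, "max_redirects": 2})
result = fetcher.execute({
    "url": "https://example.com",
    "extract_mode": "text",
    "max_chars": 2000,
})
print(result)   # a ToolResult wrapping title/text/metadata on success, or an error message
fetcher.close()  # release the pooled HTTP connections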
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .write import Write

__all__ = ['Write']
96
agent/tools/write/write.py
Normal file
@@ -0,0 +1,96 @@
"""
Write tool - Write file content

Creates or overwrites files, automatically creates parent directories
"""

import os
from typing import Dict, Any
from pathlib import Path

from agent.tools.base_tool import BaseTool, ToolResult


class Write(BaseTool):
    """Tool for writing file content"""

    name: str = "write"
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to write (relative or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            }
        },
        "required": ["path", "content"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.memory_manager = self.config.get("memory_manager", None)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file write operation

        :param args: Contains file path and content
        :return: Operation result
        """
        path = args.get("path", "").strip()
        content = args.get("content", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        try:
            # Create parent directory (if needed)
            parent_dir = os.path.dirname(absolute_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Get bytes written
            bytes_written = len(content.encode('utf-8'))

            # Auto-sync to memory database if this is a memory file
            if self.memory_manager and 'memory/' in path:
                self.memory_manager.mark_dirty()

            result = {
                "message": f"Successfully wrote {bytes_written} bytes to {path}",
                "path": path,
                "bytes_written": bytes_written
            }

            return ToolResult.success(result)

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied writing to {path}")
        except Exception as e:
            return ToolResult.fail(f"Error writing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
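A quick sketch of the tool in isolation; the paths are illustrative and only the Write class above is assumed:

# Hypothetical example, not part of the commit
writer = Write(config={"cwd": "/tmp/agent-workspace"})
res = writer.execute({
    "path": "memory/notes.md",           # relative paths resolve against cwd
    "content": "agent scratch notes\n",
})
# Parent directories are created on demand; passing a memory_manager in config
# would additionally mark the memory index dirty for paths under 'memory/'.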
@@ -1,12 +1,14 @@
# encoding:utf-8

import time
import json

import openai
import openai.error
import requests
from common import const
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
@@ -18,7 +20,7 @@ from config import conf, load_config
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession

# OpenAI chat completion API (available)
class ChatGPTBot(Bot, OpenAIImage):
class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
    def __init__(self):
        super().__init__()
        # set the default api_key
@@ -52,6 +54,18 @@ class ChatGPTBot(Bot, OpenAIImage):
        if conf_model in [const.O1, const.O1_MINI]:  # o1-series models do not support system prompts; use the Wenxin-style session
            self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or const.O1_MINI)

    def get_api_config(self):
        """Get API configuration for OpenAI-compatible base class"""
        return {
            'api_key': conf().get("open_ai_api_key"),
            'api_base': conf().get("open_ai_api_base"),
            'model': conf().get("model", "gpt-3.5-turbo"),
            'default_temperature': conf().get("temperature", 0.9),
            'default_top_p': conf().get("top_p", 1.0),
            'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
            'default_presence_penalty': conf().get("presence_penalty", 0.0),
        }

    def reply(self, query, context=None):
        # acquire reply content
        if context.type == ContextType.TEXT:
@@ -171,7 +185,6 @@ class ChatGPTBot(Bot, OpenAIImage):
        else:
            return result


class AzureChatGPTBot(ChatGPTBot):
    def __init__(self):
        super().__init__()
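For orientation, a sketch of how an OpenAI-compatible base class could consume get_api_config(); OpenAICompatibleBot's real interface is not shown in this diff, so the method below is a labeled assumption, not the project's implementation:

# Hypothetical illustration of the consumer side -- OpenAICompatibleBot's
# actual code is not part of this hunk.
class OpenAICompatibleBotSketch:
    def build_request_args(self):
        cfg = self.get_api_config()  # provided by the concrete bot, as above
        return {
            "api_key": cfg["api_key"],
            "api_base": cfg["api_base"],
            "model": cfg["model"],
            "temperature": cfg["default_temperature"],
            "top_p": cfg["default_top_p"],
        }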
@@ -1,222 +0,0 @@
import re
import time
import json
import uuid
from curl_cffi import requests
from bot.bot import Bot
from bot.claude.claude_ai_session import ClaudeAiSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf


class ClaudeAIBot(Bot, OpenAIImage):
    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.claude_api_cookie = conf().get("claude_api_cookie")
        self.proxy = conf().get("proxy")
        self.con_uuid_dic = {}
        if self.proxy:
            self.proxies = {
                "http": self.proxy,
                "https": self.proxy
            }
        else:
            self.proxies = None
        self.error = ""
        self.org_uuid = self.get_organization_id()

    def generate_uuid(self):
        random_uuid = uuid.uuid4()
        random_uuid_str = str(random_uuid)
        formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
        return formatted_uuid

    def reply(self, query, context: Context = None) -> Reply:
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def get_organization_id(self):
        url = "https://claude.ai/api/organizations"
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'Connection': 'keep-alive',
            'Cookie': f'{self.claude_api_cookie}'
        }
        try:
            response = requests.get(url, headers=headers, impersonate="chrome110", proxies=self.proxies, timeout=400)
            res = json.loads(response.text)
            uuid = res[0]['uuid']
        except:
            if "App unavailable" in response.text:
                logger.error("IP error: The IP is not allowed to be used on Claude")
                self.error = "ip所在地区不被claude支持"
            elif "Invalid authorization" in response.text:
                logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
                self.error = "无法通过claude身份验证,请检查cookie"
            return None
        return uuid

    def conversation_share_check(self, session_id):
        if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
            con_uuid = conf().get("claude_uuid")
            return con_uuid
        if session_id not in self.con_uuid_dic:
            self.con_uuid_dic[session_id] = self.generate_uuid()
            self.create_new_chat(self.con_uuid_dic[session_id])
        return self.con_uuid_dic[session_id]

    def check_cookie(self):
        flag = self.get_organization_id()
        return flag

    def create_new_chat(self, con_uuid):
        """
        Create a new Claude conversation entity
        :param con_uuid: conversation id
        :return:
        """
        url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
        payload = json.dumps({"uuid": con_uuid, "name": ""})
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Origin': 'https://claude.ai',
            'DNT': '1',
            'Connection': 'keep-alive',
            'Cookie': self.claude_api_cookie,
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'TE': 'trailers'
        }
        response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
        # Returns JSON of the newly created conversation information
        return response.json()

    def _chat(self, query, context, retry_count=0) -> Reply:
        """
        Send a chat request
        :param query: prompt text
        :param context: conversation context
        :param retry_count: current recursive retry count
        :return: reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warn("[CLAUDEAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            session_id = context["session_id"]
            if self.org_uuid is None:
                return Reply(ReplyType.ERROR, self.error)

            session = self.sessions.session_query(query, session_id)
            con_uuid = self.conversation_share_check(session_id)

            model = conf().get("model") or "gpt-3.5-turbo"
            # remove system message
            if session.messages[0].get("role") == "system":
                if model == "wenxin" or model == "claude":
                    session.messages.pop(0)
            logger.info(f"[CLAUDEAI] query={query}")

            # do http request
            base_url = "https://claude.ai"
            payload = json.dumps({
                "completion": {
                    "prompt": f"{query}",
                    "timezone": "Asia/Kolkata",
                    "model": "claude-2"
                },
                "organization_uuid": f"{self.org_uuid}",
                "conversation_uuid": f"{con_uuid}",
                "text": f"{query}",
                "attachments": []
            })
            headers = {
                'User-Agent':
                    'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
                'Accept': 'text/event-stream, text/event-stream',
                'Accept-Language': 'en-US,en;q=0.5',
                'Referer': 'https://claude.ai/chats',
                'Content-Type': 'application/json',
                'Origin': 'https://claude.ai',
                'DNT': '1',
                'Connection': 'keep-alive',
                'Cookie': f'{self.claude_api_cookie}',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'same-origin',
                'TE': 'trailers'
            }

            res = requests.post(base_url + "/api/append_message", headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
            if res.status_code == 200 or "pemission" in res.text:
                # execute success
                decoded_data = res.content.decode("utf-8")
                decoded_data = re.sub('\n+', '\n', decoded_data).strip()
                data_strings = decoded_data.split('\n')
                completions = []
                for data_string in data_strings:
                    json_str = data_string[6:].strip()
                    data = json.loads(json_str)
                    if 'completion' in data:
                        completions.append(data['completion'])

                reply_content = ''.join(completions)

                if "rate limi" in reply_content:
                    logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
                    return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
                logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
                self.sessions.session_reply(reply_content, session_id, 100)
                return Reply(ReplyType.TEXT, reply_content)
            else:
                flag = self.check_cookie()
                if flag == None:
                    return Reply(ReplyType.ERROR, self.error)

                response = res.json()
                error = response.get("error")
                logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warn(f"[CLAUDE] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)
                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warn(f"[CLAUDE] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)
@@ -1,9 +0,0 @@
from bot.session_manager import Session


class ClaudeAiSession(Session):
    def __init__(self, session_id, system_prompt=None, model="claude"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # the reverse-engineered Claude endpoint does not support role prompts
        # self.reset()
@@ -1,21 +1,28 @@
# encoding:utf-8

import json
import time

import openai
import openai.error
import anthropic
import requests

from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bot.bot import Bot
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common import const
from common.log import logger
from config import conf

# Optional OpenAI image support
try:
    from bot.openai.open_ai_image import OpenAIImage
    _openai_image_available = True
except Exception as e:
    logger.warning(f"OpenAI image support not available: {e}")
    _openai_image_available = False
    OpenAIImage = object  # Fallback to object

user_session = dict()


@@ -23,13 +30,9 @@ user_session = dict()
class ClaudeAPIBot(Bot, OpenAIImage):
    def __init__(self):
        super().__init__()
        proxy = conf().get("proxy", None)
        base_url = conf().get("open_ai_api_base", None)  # reuse the "open_ai_api_base" setting as base_url
        self.claudeClient = anthropic.Anthropic(
            api_key=conf().get("claude_api_key"),
            proxies=proxy if proxy else None,
            base_url=base_url if base_url else None
        )
        self.api_key = conf().get("claude_api_key")
        self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
        self.proxy = conf().get("proxy", None)
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")

    def reply(self, query, context=None):
@@ -73,39 +76,104 @@ class ClaudeAPIBot(Bot, OpenAIImage):
            reply = Reply(ReplyType.ERROR, retstring)
        return reply

    def reply_text(self, session: BaiduWenxinSession, retry_count=0):
    def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None):
        try:
            actual_model = self._model_mapping(conf().get("model"))
            response = self.claudeClient.messages.create(
                model=actual_model,
                max_tokens=4096,
                system=conf().get("character_desc", ""),
                messages=session.messages

            # Prepare headers
            headers = {
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json"
            }

            # Extract system prompt if present and prepare Claude-compatible messages
            system_prompt = conf().get("character_desc", "")
            claude_messages = []

            for msg in session.messages:
                if msg.get("role") == "system":
                    system_prompt = msg["content"]
                else:
                    claude_messages.append(msg)

            # Prepare request data
            data = {
                "model": actual_model,
                "messages": claude_messages,
                "max_tokens": self._get_max_tokens(actual_model)
            }

            if system_prompt:
                data["system"] = system_prompt

            if tools:
                data["tools"] = tools

            # Make HTTP request
            proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
            response = requests.post(
                f"{self.api_base}/messages",
                headers=headers,
                json=data,
                proxies=proxies
            )
            # response = openai.Completion.create(prompt=str(session), **self.args)
            res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
            total_tokens = response.usage.input_tokens + response.usage.output_tokens
            completion_tokens = response.usage.output_tokens

            if response.status_code != 200:
                raise Exception(f"API request failed: {response.status_code} - {response.text}")

            claude_response = response.json()
            # Handle response content and tool calls
            res_content = ""
            tool_calls = []

            content_blocks = claude_response.get("content", [])
            for block in content_blocks:
                if block.get("type") == "text":
                    res_content += block.get("text", "")
                elif block.get("type") == "tool_use":
                    tool_calls.append({
                        "id": block.get("id", ""),
                        "name": block.get("name", ""),
                        "arguments": block.get("input", {})
                    })

            res_content = res_content.strip().replace("<|endoftext|>", "")
            usage = claude_response.get("usage", {})
            total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
            completion_tokens = usage.get("output_tokens", 0)

            logger.info("[CLAUDE_API] reply={}".format(res_content))
            return {
            if tool_calls:
                logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls))

            result = {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }

            if tool_calls:
                result["tool_calls"] = tool_calls

            return result
        except Exception as e:
            need_retry = retry_count < 2
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if isinstance(e, openai.error.RateLimitError):

            # Handle different types of errors
            error_str = str(e).lower()
            if "rate" in error_str or "limit" in error_str:
                logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦,请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
            elif "timeout" in error_str:
                logger.warn("[CLAUDE_API] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIConnectionError):
            elif "connection" in error_str or "network" in error_str:
                logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
@@ -116,7 +184,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):

            if need_retry:
                logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, retry_count + 1)
                return self.reply_text(session, retry_count + 1, tools)
            else:
                return result

@@ -130,3 +198,288 @@ class ClaudeAPIBot(Bot, OpenAIImage):
        elif model == "claude-3.5-sonnet":
            return const.CLAUDE_35_SONNET
        return model

    def _get_max_tokens(self, model: str) -> int:
        """
        Get max_tokens for the model.
        Reference from pi-mono:
        - Claude 3.5/3.7: 8192
        - Claude 3 Opus: 4096
        - Default: 8192
        """
        if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")):
            return 8192
        elif model and model.startswith("claude-3") and "opus" in model:
            return 4096
        elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")):
            return 64000
        return 8192

    def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
        """
        Call Claude API with tool support for agent integration

        Args:
            messages: List of messages
            tools: List of tool definitions
            stream: Whether to use streaming
            **kwargs: Additional parameters

        Returns:
            Formatted response compatible with OpenAI format or generator for streaming
        """
        actual_model = self._model_mapping(conf().get("model"))

        # Extract system prompt from messages if present
        system_prompt = kwargs.get("system", conf().get("character_desc", ""))
        claude_messages = []

        for msg in messages:
            if msg.get("role") == "system":
                system_prompt = msg["content"]
            else:
                claude_messages.append(msg)

        request_params = {
            "model": actual_model,
            "max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)),
            "messages": claude_messages,
            "stream": stream
        }

        if system_prompt:
            request_params["system"] = system_prompt

        if tools:
            request_params["tools"] = tools

        try:
            if stream:
                return self._handle_stream_response(request_params)
            else:
                return self._handle_sync_response(request_params)
        except Exception as e:
            logger.error(f"Claude API call error: {e}")
            if stream:
                # Return error generator for stream
                def error_generator():
                    yield {
                        "error": True,
                        "message": str(e),
                        "status_code": 500
                    }

                return error_generator()
            else:
                # Return error response for sync
                return {
                    "error": True,
                    "message": str(e),
                    "status_code": 500
                }

    def _handle_sync_response(self, request_params):
        """Handle synchronous Claude API response"""
        # Prepare headers
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

        # Make HTTP request
        proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
        response = requests.post(
            f"{self.api_base}/messages",
            headers=headers,
            json=request_params,
            proxies=proxies
        )

        if response.status_code != 200:
            raise Exception(f"API request failed: {response.status_code} - {response.text}")

        claude_response = response.json()

        # Extract content blocks
        text_content = ""
        tool_calls = []

        content_blocks = claude_response.get("content", [])
        for block in content_blocks:
            if block.get("type") == "text":
                text_content += block.get("text", "")
            elif block.get("type") == "tool_use":
                tool_calls.append({
                    "id": block.get("id", ""),
                    "type": "function",
                    "function": {
                        "name": block.get("name", ""),
                        "arguments": json.dumps(block.get("input", {}))
                    }
                })

        # Build message in OpenAI format
        message = {
            "role": "assistant",
            "content": text_content
        }
        if tool_calls:
            message["tool_calls"] = tool_calls

        # Format response to match OpenAI structure
        usage = claude_response.get("usage", {})
        formatted_response = {
            "id": claude_response.get("id", ""),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": claude_response.get("model", request_params["model"]),
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": claude_response.get("stop_reason", "stop")
                }
            ],
            "usage": {
                "prompt_tokens": usage.get("input_tokens", 0),
                "completion_tokens": usage.get("output_tokens", 0),
                "total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
            }
        }

        return formatted_response

    def _handle_stream_response(self, request_params):
        """Handle streaming Claude API response using HTTP requests"""
        # Prepare headers
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

        # Add stream parameter
        request_params["stream"] = True

        # Track tool use state
        tool_uses_map = {}  # {index: {id, name, input}}
        current_tool_use_index = -1

        try:
            # Make streaming HTTP request
            proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
            response = requests.post(
                f"{self.api_base}/messages",
                headers=headers,
                json=request_params,
                proxies=proxies,
                stream=True
            )

            if response.status_code != 200:
                error_text = response.text
                try:
                    error_data = json.loads(error_text)
                    error_msg = error_data.get("error", {}).get("message", error_text)
                except:
                    error_msg = error_text or "Unknown error"

                yield {
                    "error": True,
                    "status_code": response.status_code,
                    "message": error_msg
                }
                return

            # Process streaming response
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        line = line[6:]  # Remove 'data: ' prefix
                    if line == '[DONE]':
                        break
                    try:
                        event = json.loads(line)
                        event_type = event.get("type")

                        if event_type == "content_block_start":
                            # New content block
                            block = event.get("content_block", {})
                            if block.get("type") == "tool_use":
                                current_tool_use_index = event.get("index", 0)
                                tool_uses_map[current_tool_use_index] = {
                                    "id": block.get("id", ""),
                                    "name": block.get("name", ""),
                                    "input": ""
                                }

                        elif event_type == "content_block_delta":
                            delta = event.get("delta", {})
                            delta_type = delta.get("type")

                            if delta_type == "text_delta":
                                # Text content
                                content = delta.get("text", "")
                                yield {
                                    "id": event.get("id", ""),
                                    "object": "chat.completion.chunk",
                                    "created": int(time.time()),
                                    "model": request_params["model"],
                                    "choices": [{
                                        "index": 0,
                                        "delta": {"content": content},
                                        "finish_reason": None
                                    }]
                                }

                            elif delta_type == "input_json_delta":
                                # Tool input accumulation
                                if current_tool_use_index >= 0:
                                    tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "")

                        elif event_type == "message_delta":
                            # Message complete - yield tool calls if any
                            if tool_uses_map:
                                for idx in sorted(tool_uses_map.keys()):
                                    tool_data = tool_uses_map[idx]
                                    yield {
                                        "id": event.get("id", ""),
                                        "object": "chat.completion.chunk",
                                        "created": int(time.time()),
                                        "model": request_params["model"],
                                        "choices": [{
                                            "index": 0,
                                            "delta": {
                                                "tool_calls": [{
                                                    "index": idx,
                                                    "id": tool_data["id"],
                                                    "type": "function",
                                                    "function": {
                                                        "name": tool_data["name"],
                                                        "arguments": tool_data["input"]
                                                    }
                                                }]
                                            },
                                            "finish_reason": None
                                        }]
                                    }

                    except json.JSONDecodeError:
                        continue

        except requests.RequestException as e:
            logger.error(f"Claude streaming request error: {e}")
            yield {
                "error": True,
                "message": f"Connection error: {str(e)}",
                "status_code": 0
            }
        except Exception as e:
            logger.error(f"Claude streaming error: {e}")
            yield {
                "error": True,
                "message": str(e),
                "status_code": 500
            }
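The streaming handler above adapts Anthropic's SSE event types (content_block_start, content_block_delta, message_delta) into OpenAI-style chat.completion.chunk dicts, so callers can use one loop regardless of backend. A minimal consumer sketch, assuming only the call_with_tools interface defined in this diff; the tool definition is illustrative:

# Hypothetical example, not part of the commit
bot = ClaudeAPIBot()
tools = [{"name": "web_fetch", "description": "Fetch a URL",
          "input_schema": {"type": "object", "properties": {"url": {"type": "string"}}}}]
for chunk in bot.call_with_tools([{"role": "user", "content": "Open example.com"}],
                                 tools=tools, stream=True):
    if chunk.get("error"):
        break  # error chunks carry "message" and "status_code" instead of "choices"
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        print(delta["content"], end="")  # incremental text
    for call in delta.get("tool_calls", []):
        print(call["function"]["name"])  # buffered tool call, arguments as a JSON string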
@@ -6,6 +6,9 @@ Google gemini bot
"""
# encoding:utf-8

import json
import time
import requests
from bot.bot import Bot
import google.generativeai as genai
from bot.session_manager import SessionManager
@@ -29,6 +32,19 @@ class GoogleGeminiBot(Bot):
        self.model = conf().get("model") or "gemini-pro"
        if self.model == "gemini":
            self.model = "gemini-pro"

        # Support a custom API base URL by reusing the open_ai_api_base setting
        self.api_base = conf().get("open_ai_api_base", "").strip()
        if self.api_base:
            # Strip the trailing slash
            self.api_base = self.api_base.rstrip('/')
            # If an OpenAI endpoint is configured, fall back to the default Gemini endpoint
            if "api.openai.com" in self.api_base or not self.api_base:
                self.api_base = "https://generativelanguage.googleapis.com"
            logger.info(f"[Gemini] Using custom API base: {self.api_base}")
        else:
            self.api_base = "https://generativelanguage.googleapis.com"

    def reply(self, query, context: Context = None) -> Reply:
        try:
            if context.type != ContextType.TEXT:
@@ -113,3 +129,557 @@ class GoogleGeminiBot(Bot):
            elif turn == "assistant":
                turn = "user"
        return res

    def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
        """
        Call Gemini API with tool support using REST API (following official docs)

        Args:
            messages: List of messages (OpenAI format)
            tools: List of tool definitions (OpenAI/Claude format)
            stream: Whether to use streaming
            **kwargs: Additional parameters (system, max_tokens, temperature, etc.)

        Returns:
            Formatted response compatible with OpenAI format or generator for streaming
        """
        try:
            model_name = kwargs.get("model", self.model or "gemini-1.5-flash")

            # Build REST API payload
            payload = {"contents": []}

            # Extract and set system instruction
            system_prompt = kwargs.get("system", "")
            if not system_prompt:
                for msg in messages:
                    if msg.get("role") == "system":
                        system_prompt = msg["content"]
                        break

            if system_prompt:
                payload["system_instruction"] = {
                    "parts": [{"text": system_prompt}]
                }

            # Convert messages to Gemini format
            for msg in messages:
                role = msg.get("role")
                content = msg.get("content", "")

                if role == "system":
                    continue

                # Convert role
                gemini_role = "user" if role in ["user", "tool"] else "model"

                # Handle different content formats
                parts = []

                if isinstance(content, str):
                    # Simple text content
                    parts.append({"text": content})

                elif isinstance(content, list):
                    # List of content blocks (Claude format)
                    for block in content:
                        if not isinstance(block, dict):
                            if isinstance(block, str):
                                parts.append({"text": block})
                            continue

                        block_type = block.get("type")

                        if block_type == "text":
                            # Text block
                            parts.append({"text": block.get("text", "")})

                        elif block_type == "tool_result":
                            # Convert Claude tool_result to Gemini functionResponse
                            tool_use_id = block.get("tool_use_id")
                            tool_content = block.get("content", "")

                            # Try to parse tool content as JSON
                            try:
                                if isinstance(tool_content, str):
                                    tool_result_data = json.loads(tool_content)
                                else:
                                    tool_result_data = tool_content
                            except:
                                tool_result_data = {"result": tool_content}

                            # Find the tool name from previous messages
                            # Look for the corresponding tool_call in model's message
                            tool_name = None
                            for prev_msg in reversed(messages):
                                if prev_msg.get("role") == "assistant":
                                    prev_content = prev_msg.get("content", [])
                                    if isinstance(prev_content, list):
                                        for prev_block in prev_content:
                                            if isinstance(prev_block, dict) and prev_block.get("type") == "tool_use":
                                                if prev_block.get("id") == tool_use_id:
                                                    tool_name = prev_block.get("name")
                                                    break
                                    if tool_name:
                                        break

                            # Gemini functionResponse format
                            parts.append({
                                "functionResponse": {
                                    "name": tool_name or "unknown",
                                    "response": tool_result_data
                                }
                            })

                        elif "text" in block:
                            # Generic text field
                            parts.append({"text": block["text"]})

                if parts:
                    payload["contents"].append({
                        "role": gemini_role,
                        "parts": parts
                    })

            # Generation config
            gen_config = {}
            if kwargs.get("temperature") is not None:
                gen_config["temperature"] = kwargs["temperature"]
            if kwargs.get("max_tokens"):
                gen_config["maxOutputTokens"] = kwargs["max_tokens"]
            if gen_config:
                payload["generationConfig"] = gen_config

            # Convert tools to Gemini format (REST API style)
            if tools:
                gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
                if gemini_tools:
                    payload["tools"] = gemini_tools
                    logger.info(f"[Gemini] Added {len(tools)} tools to request")

            # Make REST API call
            base_url = f"{self.api_base}/v1beta"
            endpoint = f"{base_url}/models/{model_name}:generateContent"
            if stream:
                endpoint = f"{base_url}/models/{model_name}:streamGenerateContent?alt=sse"

            headers = {
                "x-goog-api-key": self.api_key,
                "Content-Type": "application/json"
            }

            logger.debug(f"[Gemini] REST API call: {endpoint}")

            response = requests.post(
                endpoint,
                headers=headers,
                json=payload,
                stream=stream,
                timeout=60
            )

            # Check HTTP status for stream mode (for non-stream, it's checked in handler)
            if stream and response.status_code != 200:
                error_text = response.text
                logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")
                def error_generator():
                    yield {
                        "error": True,
                        "message": f"Gemini API error: {error_text}",
                        "status_code": response.status_code
                    }
                return error_generator()

            if stream:
                return self._handle_gemini_rest_stream_response(response, model_name)
            else:
                return self._handle_gemini_rest_sync_response(response, model_name)

        except Exception as e:
            logger.error(f"[Gemini] call_with_tools error: {e}", exc_info=True)
            error_msg = str(e)  # Capture error message before creating generator
            if stream:
                def error_generator():
                    yield {
                        "error": True,
                        "message": error_msg,
                        "status_code": 500
                    }
                return error_generator()
            else:
                return {
                    "error": True,
                    "message": str(e),
                    "status_code": 500
                }

    def _convert_tools_to_gemini_rest_format(self, tools_list):
        """
        Convert tools to Gemini REST API format

        Handles both OpenAI and Claude/Agent formats.
        Returns: [{"functionDeclarations": [...]}]
        """
        function_declarations = []

        for tool in tools_list:
            # Extract name, description, and parameters based on format
            if tool.get("type") == "function":
                # OpenAI format: {"type": "function", "function": {...}}
                func = tool.get("function", {})
                name = func.get("name")
                description = func.get("description", "")
                parameters = func.get("parameters", {})
            else:
                # Claude/Agent format: {"name": "...", "description": "...", "input_schema": {...}}
                name = tool.get("name")
                description = tool.get("description", "")
                parameters = tool.get("input_schema", {})

            if not name:
                logger.warning(f"[Gemini] Skipping tool without name: {tool}")
                continue

            logger.debug(f"[Gemini] Converting tool: {name}")

            function_declarations.append({
                "name": name,
                "description": description,
                "parameters": parameters
            })

        # All functionDeclarations must be in a single tools object (per Gemini REST API spec)
        return [{
            "functionDeclarations": function_declarations
        }] if function_declarations else []

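To make the conversion concrete, here is what one Claude/Agent-format tool would look like before and after passing through _convert_tools_to_gemini_rest_format; the tool itself is illustrative:

# Hypothetical input (Claude/Agent format):
#   [{"name": "web_fetch",
#     "description": "Fetch a URL",
#     "input_schema": {"type": "object",
#                      "properties": {"url": {"type": "string"}},
#                      "required": ["url"]}}]
# Expected output (Gemini REST format, all declarations in one tools object):
#   [{"functionDeclarations": [{"name": "web_fetch",
#                               "description": "Fetch a URL",
#                               "parameters": {"type": "object",
#                                              "properties": {"url": {"type": "string"}},
#                                              "required": ["url"]}}]}]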
    def _handle_gemini_rest_sync_response(self, response, model_name):
        """Handle Gemini REST API sync response and convert to OpenAI format"""
        try:
            if response.status_code != 200:
                error_text = response.text
                logger.error(f"[Gemini] API error ({response.status_code}): {error_text}")
                return {
                    "error": True,
                    "message": f"Gemini API error: {error_text}",
                    "status_code": response.status_code
                }

            data = response.json()
            logger.debug(f"[Gemini] Response data: {json.dumps(data, ensure_ascii=False)[:500]}")

            # Extract from Gemini response format
            candidates = data.get("candidates", [])
            if not candidates:
                logger.warning("[Gemini] No candidates in response")
                return {
                    "error": True,
                    "message": "No candidates in response",
                    "status_code": 500
                }

            candidate = candidates[0]
            content = candidate.get("content", {})
            parts = content.get("parts", [])

            logger.debug(f"[Gemini] Candidate parts count: {len(parts)}")

            # Extract text and function calls
            text_content = ""
            tool_calls = []

            for part in parts:
                # Check for text
                if "text" in part:
                    text_content += part["text"]
                    logger.debug(f"[Gemini] Text part: {part['text'][:100]}...")

                # Check for functionCall (per REST API docs)
                if "functionCall" in part:
                    fc = part["functionCall"]
                    logger.info(f"[Gemini] Function call detected: {fc.get('name')}")

                    tool_calls.append({
                        "id": f"call_{int(time.time() * 1000000)}",
                        "type": "function",
                        "function": {
                            "name": fc.get("name"),
                            "arguments": json.dumps(fc.get("args", {}))
                        }
                    })

            logger.info(f"[Gemini] Response: text={len(text_content)} chars, tool_calls={len(tool_calls)}")

            # Build OpenAI format response
            message_dict = {
                "role": "assistant",
                "content": text_content or None
            }
            if tool_calls:
                message_dict["tool_calls"] = tool_calls

            return {
                "id": f"chatcmpl-{time.time()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "message": message_dict,
                    "finish_reason": "tool_calls" if tool_calls else "stop"
                }],
                "usage": data.get("usageMetadata", {})
            }

        except Exception as e:
            logger.error(f"[Gemini] sync response error: {e}", exc_info=True)
            return {
                "error": True,
                "message": str(e),
                "status_code": 500
            }

    def _handle_gemini_rest_stream_response(self, response, model_name):
        """Handle Gemini REST API stream response"""
        try:
            all_tool_calls = []
            has_sent_tool_calls = False
            has_content = False  # Track if any content was sent

            for line in response.iter_lines():
                if not line:
                    continue

                line = line.decode('utf-8')

                # Skip SSE prefixes
                if line.startswith('data: '):
                    line = line[6:]

                if not line or line == '[DONE]':
                    continue

                try:
                    chunk_data = json.loads(line)
                    logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")

                    candidates = chunk_data.get("candidates", [])
                    if not candidates:
                        logger.debug("[Gemini] No candidates in chunk")
                        continue

                    candidate = candidates[0]
                    content = candidate.get("content", {})
                    parts = content.get("parts", [])

                    if not parts:
                        logger.debug("[Gemini] No parts in candidate content")

                    # Stream text content
                    for part in parts:
                        if "text" in part and part["text"]:
                            has_content = True
                            logger.debug(f"[Gemini] Streaming text: {part['text'][:50]}...")
                            yield {
                                "id": f"chatcmpl-{time.time()}",
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": {"content": part["text"]},
                                    "finish_reason": None
                                }]
                            }

                        # Collect function calls
                        if "functionCall" in part:
                            fc = part["functionCall"]
                            logger.debug(f"[Gemini] Function call detected: {fc.get('name')}")
                            all_tool_calls.append({
                                "index": len(all_tool_calls),  # Add index to differentiate multiple tool calls
                                "id": f"call_{int(time.time() * 1000000)}_{len(all_tool_calls)}",
                                "type": "function",
                                "function": {
                                    "name": fc.get("name"),
                                    "arguments": json.dumps(fc.get("args", {}))
                                }
                            })

                except json.JSONDecodeError as je:
                    logger.debug(f"[Gemini] JSON decode error: {je}")
                    continue

            # Send tool calls if any were collected
            if all_tool_calls and not has_sent_tool_calls:
                logger.info(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
                yield {
                    "id": f"chatcmpl-{time.time()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model_name,
                    "choices": [{
                        "index": 0,
                        "delta": {"tool_calls": all_tool_calls},
                        "finish_reason": None
                    }]
                }
                has_sent_tool_calls = True

            # Log summary
            logger.info(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")

            # Final chunk
            yield {
                "id": f"chatcmpl-{time.time()}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "tool_calls" if all_tool_calls else "stop"
                }]
            }

        except Exception as e:
            logger.error(f"[Gemini] stream response error: {e}", exc_info=True)
            error_msg = str(e)
            yield {
                "error": True,
                "message": error_msg,
                "status_code": 500
            }

    def _convert_tools_to_gemini_format(self, openai_tools):
        """Convert OpenAI tool format to Gemini function declarations"""
        import google.generativeai as genai

        gemini_functions = []
        for tool in openai_tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                gemini_functions.append(
                    genai.protos.FunctionDeclaration(
                        name=func.get("name"),
                        description=func.get("description", ""),
                        parameters=func.get("parameters", {})
                    )
                )

        if gemini_functions:
            return [genai.protos.Tool(function_declarations=gemini_functions)]
        return None

    def _handle_gemini_sync_response(self, model, messages, request_params, model_name):
        """Handle synchronous Gemini API response"""
        import json

        response = model.generate_content(messages, **request_params)

        # Extract text content and function calls
        text_content = ""
        tool_calls = []

        if response.candidates and response.candidates[0].content:
            for part in response.candidates[0].content.parts:
                if hasattr(part, 'text') and part.text:
                    text_content += part.text
                elif hasattr(part, 'function_call') and part.function_call:
                    # Convert Gemini function call to OpenAI format
                    func_call = part.function_call
                    tool_calls.append({
                        "id": f"call_{hash(func_call.name)}",
                        "type": "function",
                        "function": {
                            "name": func_call.name,
                            "arguments": json.dumps(dict(func_call.args))
                        }
                    })

        # Build message in OpenAI format
        message = {
            "role": "assistant",
            "content": text_content
        }
        if tool_calls:
            message["tool_calls"] = tool_calls

        # Format response to match OpenAI structure
        formatted_response = {
            "id": f"gemini_{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": message,
                    "finish_reason": "stop" if not tool_calls else "tool_calls"
                }
            ],
            "usage": {
                "prompt_tokens": 0,  # Gemini doesn't provide token counts in the same way
                "completion_tokens": 0,
                "total_tokens": 0
            }
        }

        logger.info(f"[Gemini] call_with_tools reply, model={model_name}")
        return formatted_response

    def _handle_gemini_stream_response(self, model, messages, request_params, model_name):
        """Handle streaming Gemini API response"""
        import json

        try:
            response_stream = model.generate_content(messages, stream=True, **request_params)

            for chunk in response_stream:
                if chunk.candidates and chunk.candidates[0].content:
                    for part in chunk.candidates[0].content.parts:
                        if hasattr(part, 'text') and part.text:
                            # Text content
                            yield {
                                "id": f"gemini_{int(time.time())}",
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": {"content": part.text},
                                    "finish_reason": None
                                }]
                            }
                        elif hasattr(part, 'function_call') and part.function_call:
                            # Function call
                            func_call = part.function_call
                            yield {
                                "id": f"gemini_{int(time.time())}",
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "tool_calls": [{
                                            "index": 0,
                                            "id": f"call_{hash(func_call.name)}",
                                            "type": "function",
                                            "function": {
                                                "name": func_call.name,
                                                "arguments": json.dumps(dict(func_call.args))
                                            }
                                        }]
                                    },
                                    "finish_reason": None
                                }]
                            }

        except Exception as e:
            logger.error(f"[Gemini] stream response error: {e}")
            yield {
                "error": True,
                "message": str(e),
                "status_code": 500
            }
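Putting the pieces together, a sketch of the payload call_with_tools assembles before POSTing to {api_base}/v1beta/models/{model}:generateContent; the field values are illustrative, not taken from a real request:

# Hypothetical assembled request body, not part of the commit
# {
#   "system_instruction": {"parts": [{"text": "You are a helpful agent."}]},
#   "contents": [
#     {"role": "user", "parts": [{"text": "Open example.com"}]},
#     {"role": "model", "parts": [{"text": "Fetching..."}]},
#     {"role": "user", "parts": [{"functionResponse": {"name": "web_fetch",
#                                                      "response": {"title": "Example Domain"}}}]}
#   ],
#   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 1024},
#   "tools": [{"functionDeclarations": [...]}]
# }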
@@ -6,6 +6,7 @@ import time
import requests

import config
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
@@ -17,7 +18,7 @@ from common import memory, utils
import base64
import os

class LinkAIBot(Bot):
class LinkAIBot(Bot, OpenAICompatibleBot):
    # authentication failed
    AUTH_FAILED_CODE = 401
    NO_QUOTA_CODE = 406
@@ -26,6 +27,18 @@ class LinkAIBot(Bot):
        super().__init__()
        self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
        self.args = {}

    def get_api_config(self):
        """Get API configuration for OpenAI-compatible base class"""
        return {
            'api_key': conf().get("open_ai_api_key"),  # LinkAI uses OpenAI-compatible key
            'api_base': conf().get("open_ai_api_base", "https://api.link-ai.tech/v1"),
            'model': conf().get("model", "gpt-3.5-turbo"),
            'default_temperature': conf().get("temperature", 0.9),
            'default_top_p': conf().get("top_p", 1.0),
            'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
            'default_presence_penalty': conf().get("presence_penalty", 0.0),
        }

    def reply(self, query, context: Context = None) -> Reply:
        if context.type == ContextType.TEXT:
@@ -473,3 +486,150 @@ class LinkAISession(ChatGPTSession):
                    self.messages.pop(i - 1)
                    return self.calc_tokens()
        return cur_tokens


# Add call_with_tools method to LinkAIBot class
def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
    """
    Call the LinkAI API with tool support for agent integration.
    LinkAI is fully compatible with OpenAI's tool calling format.

    Args:
        messages: List of messages
        tools: List of tool definitions (OpenAI format)
        stream: Whether to use streaming
        **kwargs: Additional parameters (max_tokens, temperature, etc.)

    Returns:
        Formatted response in OpenAI format, or a generator for streaming
    """
    try:
        # Build request parameters (LinkAI uses OpenAI-compatible format)
        body = {
            "messages": messages,
            "model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
            "temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
            "top_p": kwargs.get("top_p", conf().get("top_p", 1)),
            "frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
            "presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
            "stream": stream
        }

        # Add max_tokens if specified
        if kwargs.get("max_tokens"):
            body["max_tokens"] = kwargs["max_tokens"]

        # Add app_code if provided
        app_code = kwargs.get("app_code", conf().get("linkai_app_code"))
        if app_code:
            body["app_code"] = app_code

        # Add tools if provided (OpenAI-compatible format)
        if tools:
            body["tools"] = tools
            body["tool_choice"] = kwargs.get("tool_choice", "auto")

        # Prepare headers
        headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
        base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")

        if stream:
            return self._handle_linkai_stream_response(base_url, headers, body)
        else:
            return self._handle_linkai_sync_response(base_url, headers, body)

    except Exception as e:
        logger.error(f"[LinkAI] call_with_tools error: {e}")
        if stream:
            def error_generator():
                yield {
                    "error": True,
                    "message": str(e),
                    "status_code": 500
                }
            return error_generator()
        else:
            return {
                "error": True,
                "message": str(e),
                "status_code": 500
            }


def _handle_linkai_sync_response(self, base_url, headers, body):
    """Handle synchronous LinkAI API response"""
    try:
        res = requests.post(
            url=base_url + "/v1/chat/completions",
            json=body,
            headers=headers,
            timeout=conf().get("request_timeout", 180)
        )

        if res.status_code == 200:
            response = res.json()
            logger.info(f"[LinkAI] call_with_tools reply, model={response.get('model')}, "
                        f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")

            # LinkAI response is already in OpenAI-compatible format
            return response
        else:
            error_data = res.json()
            error_msg = error_data.get("error", {}).get("message", "Unknown error")
            raise Exception(f"LinkAI API error: {res.status_code} - {error_msg}")

    except Exception as e:
        logger.error(f"[LinkAI] sync response error: {e}")
        raise


def _handle_linkai_stream_response(self, base_url, headers, body):
    """Handle streaming LinkAI API response"""
    import json  # local import, since the module header is not guaranteed to import json

    try:
        res = requests.post(
            url=base_url + "/v1/chat/completions",
            json=body,
            headers=headers,
            timeout=conf().get("request_timeout", 180),
            stream=True
        )

        if res.status_code != 200:
            error_text = res.text
            try:
                error_data = json.loads(error_text)
                error_msg = error_data.get("error", {}).get("message", error_text)
            except Exception:
                error_msg = error_text or "Unknown error"

            yield {
                "error": True,
                "status_code": res.status_code,
                "message": error_msg
            }
            return

        # Process streaming response (OpenAI-compatible SSE format)
        for line in res.iter_lines():
            if line:
                line = line.decode('utf-8')
                if line.startswith('data: '):
                    line = line[6:]  # Remove 'data: ' prefix
                    if line == '[DONE]':
                        break
                    try:
                        chunk = json.loads(line)
                        yield chunk
                    except json.JSONDecodeError:
                        continue

    except Exception as e:
        logger.error(f"[LinkAI] stream response error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500
        }


# Attach methods to LinkAIBot class
LinkAIBot.call_with_tools = _linkai_call_with_tools
LinkAIBot._handle_linkai_sync_response = _handle_linkai_sync_response
LinkAIBot._handle_linkai_stream_response = _handle_linkai_stream_response

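A hypothetical usage sketch of the attached method; the tool definition below is illustrative only and assumes `linkai_api_key` has been configured:

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool name
            "description": "Query current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    bot = LinkAIBot()
    response = bot.call_with_tools(
        messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
        tools=[weather_tool],
    )
    if not response.get("error"):
        print(response["choices"][0]["message"])  # may contain tool_calls

Because LinkAI mirrors OpenAI's schema, no translation of the response is needed on the way back.
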
@@ -6,6 +6,7 @@ import openai
import openai.error

from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager
@@ -18,7 +19,7 @@ user_session = dict()


# OpenAI conversation model API (available)
class OpenAIBot(Bot, OpenAIImage):
class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
    def __init__(self):
        super().__init__()
        openai.api_key = conf().get("open_ai_api_key")
@@ -40,6 +41,18 @@ class OpenAIBot(Bot, OpenAIImage):
            "timeout": conf().get("request_timeout", None),  # retry timeout; requests are retried automatically within this window
            "stop": ["\n\n\n"],
        }

    def get_api_config(self):
        """Get API configuration for OpenAI-compatible base class"""
        return {
            'api_key': conf().get("open_ai_api_key"),
            'api_base': conf().get("open_ai_api_base"),
            'model': conf().get("model", "text-davinci-003"),
            'default_temperature': conf().get("temperature", 0.9),
            'default_top_p': conf().get("top_p", 1.0),
            'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
            'default_presence_penalty': conf().get("presence_penalty", 0.0),
        }

    def reply(self, query, context=None):
        # acquire reply content
@@ -120,3 +133,98 @@ class OpenAIBot(Bot, OpenAIImage):
                return self.reply_text(session, retry_count + 1)
            else:
                return result

    def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
        """
        Call the OpenAI API with tool support for agent integration.
        Note: this bot uses the old Completion API, which doesn't support tools;
        for tool support, use ChatGPTBot instead.

        This method switches to the ChatCompletion API when tools are provided.

        Args:
            messages: List of messages
            tools: List of tool definitions (OpenAI format)
            stream: Whether to use streaming
            **kwargs: Additional parameters

        Returns:
            Formatted response in OpenAI format, or a generator for streaming
        """
        try:
            # The old Completion API doesn't support tools,
            # so we use the ChatCompletion API instead
            logger.info("[OPEN_AI] Using ChatCompletion API for tool support")

            # Build request parameters for ChatCompletion
            request_params = {
                "model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
                "messages": messages,
                "temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
                "top_p": kwargs.get("top_p", 1),
                "frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
                "presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
                "stream": stream
            }

            # Add max_tokens if specified
            if kwargs.get("max_tokens"):
                request_params["max_tokens"] = kwargs["max_tokens"]

            # Add tools if provided
            if tools:
                request_params["tools"] = tools
                request_params["tool_choice"] = kwargs.get("tool_choice", "auto")

            # Make API call using ChatCompletion
            if stream:
                return self._handle_stream_response(request_params)
            else:
                return self._handle_sync_response(request_params)

        except Exception as e:
            logger.error(f"[OPEN_AI] call_with_tools error: {e}")
            if stream:
                def error_generator():
                    yield {
                        "error": True,
                        "message": str(e),
                        "status_code": 500
                    }
                return error_generator()
            else:
                return {
                    "error": True,
                    "message": str(e),
                    "status_code": 500
                }

    def _handle_sync_response(self, request_params):
        """Handle synchronous OpenAI ChatCompletion API response"""
        try:
            response = openai.ChatCompletion.create(**request_params)

            logger.info(f"[OPEN_AI] call_with_tools reply, model={response.get('model')}, "
                        f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")

            return response

        except Exception as e:
            logger.error(f"[OPEN_AI] sync response error: {e}")
            raise

    def _handle_stream_response(self, request_params):
        """Handle streaming OpenAI ChatCompletion API response"""
        try:
            stream = openai.ChatCompletion.create(**request_params)

            for chunk in stream:
                yield chunk

        except Exception as e:
            logger.error(f"[OPEN_AI] stream response error: {e}")
            yield {
                "error": True,
                "message": str(e),
                "status_code": 500
            }

@@ -1,7 +1,7 @@
import time

import openai
import openai.error
from bot.openai.openai_compat import RateLimitError

from common.log import logger
from common.token_bucket import TokenBucket
@@ -30,7 +30,7 @@ class OpenAIImage(object):
            image_url = response["data"][0]["url"]
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return True, image_url
        except openai.error.RateLimitError as e:
        except RateLimitError as e:
            logger.warn(e)
            if retry_count < 1:
                time.sleep(5)

102
bot/openai/openai_compat.py
Normal file
@@ -0,0 +1,102 @@
"""
OpenAI compatibility layer for different versions.

This module provides a compatibility layer between OpenAI library versions:
- OpenAI < 1.0 (old API with the openai.error module)
- OpenAI >= 1.0 (new API with direct exception imports)
"""

try:
    # Try new OpenAI >= 1.0 API
    from openai import (
        OpenAIError,
        RateLimitError,
        APIError,
        APIConnectionError,
        AuthenticationError,
        APITimeoutError,
        BadRequestError,
    )

    # Create a mock error module for backward compatibility
    class ErrorModule:
        OpenAIError = OpenAIError
        RateLimitError = RateLimitError
        APIError = APIError
        APIConnectionError = APIConnectionError
        AuthenticationError = AuthenticationError
        Timeout = APITimeoutError  # Renamed in new version
        InvalidRequestError = BadRequestError  # Renamed in new version

    error = ErrorModule()

    # Also export with new names
    Timeout = APITimeoutError
    InvalidRequestError = BadRequestError

except ImportError:
    # Fall back to old OpenAI < 1.0 API
    try:
        import openai.error as error

        # Export individual exceptions for direct import
        OpenAIError = error.OpenAIError
        RateLimitError = error.RateLimitError
        APIError = error.APIError
        APIConnectionError = error.APIConnectionError
        AuthenticationError = error.AuthenticationError
        InvalidRequestError = error.InvalidRequestError
        Timeout = error.Timeout
        BadRequestError = error.InvalidRequestError  # Alias
        APITimeoutError = error.Timeout  # Alias
    except (ImportError, AttributeError):
        # Neither version works, create dummy classes
        class OpenAIError(Exception):
            pass

        class RateLimitError(OpenAIError):
            pass

        class APIError(OpenAIError):
            pass

        class APIConnectionError(OpenAIError):
            pass

        class AuthenticationError(OpenAIError):
            pass

        class InvalidRequestError(OpenAIError):
            pass

        class Timeout(OpenAIError):
            pass

        BadRequestError = InvalidRequestError
        APITimeoutError = Timeout

        # Create error module
        class ErrorModule:
            OpenAIError = OpenAIError
            RateLimitError = RateLimitError
            APIError = APIError
            APIConnectionError = APIConnectionError
            AuthenticationError = AuthenticationError
            InvalidRequestError = InvalidRequestError
            Timeout = Timeout

        error = ErrorModule()

# Export all for easy import
__all__ = [
    'error',
    'OpenAIError',
    'RateLimitError',
    'APIError',
    'APIConnectionError',
    'AuthenticationError',
    'InvalidRequestError',
    'Timeout',
    'BadRequestError',
    'APITimeoutError',
]

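A short sketch of what the compatibility layer buys call sites: both the direct import and the legacy `error.` attribute style resolve to the same class, whichever openai version is installed (assuming the module ships as shown above; `call_with_retry` is a hypothetical helper):

    from bot.openai.openai_compat import RateLimitError, error

    def call_with_retry(fn, retries=1):
        """Retry once on rate limiting, independent of the installed openai version."""
        for attempt in range(retries + 1):
            try:
                return fn()
            except RateLimitError:
                if attempt == retries:
                    raise

    # legacy call sites keep working unchanged:
    assert error.RateLimitError is RateLimitError
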
278
bot/openai_compatible_bot.py
Normal file
@@ -0,0 +1,278 @@
# encoding:utf-8

"""
OpenAI-Compatible Bot Base Class

Provides a common implementation for bots that are compatible with OpenAI's API format.
This includes: OpenAI, LinkAI, Azure OpenAI, and many third-party providers.
"""

import json

import openai

from common.log import logger


class OpenAICompatibleBot:
    """
    Base class for OpenAI-compatible bots.

    Provides a common tool calling implementation that can be inherited by:
    - ChatGPTBot
    - LinkAIBot
    - OpenAIBot
    - AzureChatGPTBot
    - Other OpenAI-compatible providers

    Subclasses only need to override get_api_config() to provide their specific API settings.
    """

    def get_api_config(self):
        """
        Get API configuration for this bot.

        Subclasses should override this to provide their specific config.

        Returns:
            dict: {
                'api_key': str,
                'api_base': str (optional),
                'model': str,
                'default_temperature': float,
                'default_top_p': float,
                'default_frequency_penalty': float,
                'default_presence_penalty': float,
            }
        """
        raise NotImplementedError("Subclasses must implement get_api_config()")

    def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
        """
        Call an OpenAI-compatible API with tool support for agent integration.

        This method handles:
        1. Format conversion (Claude format -> OpenAI format)
        2. System prompt injection
        3. API calling with proper configuration
        4. Error handling

        Args:
            messages: List of messages (may be in Claude format from the agent)
            tools: List of tool definitions (may be in Claude format from the agent)
            stream: Whether to use streaming
            **kwargs: Additional parameters (max_tokens, temperature, system, etc.)

        Returns:
            Formatted response in OpenAI format, or a generator for streaming
        """
        try:
            # Get API configuration from subclass
            api_config = self.get_api_config()

            # Convert messages from Claude format to OpenAI format
            messages = self._convert_messages_to_openai_format(messages)

            # Convert tools from Claude format to OpenAI format
            if tools:
                tools = self._convert_tools_to_openai_format(tools)

            # Handle system prompt (OpenAI uses a system message, Claude a separate parameter)
            system_prompt = kwargs.get('system')
            if system_prompt:
                # Add system message at the beginning if not already present
                if not messages or messages[0].get('role') != 'system':
                    messages = [{"role": "system", "content": system_prompt}] + messages
                else:
                    # Replace existing system message
                    messages[0] = {"role": "system", "content": system_prompt}

            # Build request parameters
            request_params = {
                "model": kwargs.get("model", api_config.get('model', 'gpt-3.5-turbo')),
                "messages": messages,
                "temperature": kwargs.get("temperature", api_config.get('default_temperature', 0.9)),
                "top_p": kwargs.get("top_p", api_config.get('default_top_p', 1.0)),
                "frequency_penalty": kwargs.get("frequency_penalty", api_config.get('default_frequency_penalty', 0.0)),
                "presence_penalty": kwargs.get("presence_penalty", api_config.get('default_presence_penalty', 0.0)),
                "stream": stream
            }

            # Add max_tokens if specified
            if kwargs.get("max_tokens"):
                request_params["max_tokens"] = kwargs["max_tokens"]

            # Add tools if provided
            if tools:
                request_params["tools"] = tools
                request_params["tool_choice"] = kwargs.get("tool_choice", "auto")

            # Make API call with proper configuration
            api_key = api_config.get('api_key')
            api_base = api_config.get('api_base')

            if stream:
                return self._handle_stream_response(request_params, api_key, api_base)
            else:
                return self._handle_sync_response(request_params, api_key, api_base)

        except Exception as e:
            error_msg = str(e)
            logger.error(f"[{self.__class__.__name__}] call_with_tools error: {error_msg}")
            if stream:
                def error_generator():
                    yield {
                        "error": True,
                        "message": error_msg,
                        "status_code": 500
                    }
                return error_generator()
            else:
                return {
                    "error": True,
                    "message": error_msg,
                    "status_code": 500
                }

    def _handle_sync_response(self, request_params, api_key, api_base):
        """Handle synchronous OpenAI API response"""
        try:
            # Build kwargs with explicit API configuration
            kwargs = dict(request_params)
            if api_key:
                kwargs["api_key"] = api_key
            if api_base:
                kwargs["api_base"] = api_base

            response = openai.ChatCompletion.create(**kwargs)
            return response

        except Exception as e:
            logger.error(f"[{self.__class__.__name__}] sync response error: {e}")
            return {
                "error": True,
                "message": str(e),
                "status_code": 500
            }

    def _handle_stream_response(self, request_params, api_key, api_base):
        """Handle streaming OpenAI API response"""
        try:
            # Build kwargs with explicit API configuration
            kwargs = dict(request_params)
            if api_key:
                kwargs["api_key"] = api_key
            if api_base:
                kwargs["api_base"] = api_base

            stream = openai.ChatCompletion.create(**kwargs)

            # Stream chunks to caller
            for chunk in stream:
                yield chunk

        except Exception as e:
            logger.error(f"[{self.__class__.__name__}] stream response error: {e}")
            yield {
                "error": True,
                "message": str(e),
                "status_code": 500
            }

    def _convert_tools_to_openai_format(self, tools):
        """
        Convert tools from Claude format to OpenAI format.

        Claude format: {name, description, input_schema}
        OpenAI format: {type: "function", function: {name, description, parameters}}
        """
        if not tools:
            return None

        openai_tools = []
        for tool in tools:
            # Check if already in OpenAI format
            if 'type' in tool and tool['type'] == 'function':
                openai_tools.append(tool)
            else:
                # Convert from Claude format
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool.get("name"),
                        "description": tool.get("description"),
                        "parameters": tool.get("input_schema", {})
                    }
                })

        return openai_tools

    def _convert_messages_to_openai_format(self, messages):
        """
        Convert messages from Claude format to OpenAI format.

        Claude uses content blocks with types like 'tool_use' and 'tool_result';
        OpenAI uses 'tool_calls' in assistant messages and a 'tool' role for results.
        """
        if not messages:
            return []

        openai_messages = []

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            # Handle string content (already in correct format)
            if isinstance(content, str):
                openai_messages.append(msg)
                continue

            # Handle list content (Claude format with content blocks)
            if isinstance(content, list):
                # Check if this is a tool result message (user role with tool_result blocks)
                if role == "user" and any(block.get("type") == "tool_result" for block in content):
                    # Convert each tool_result block to a separate tool message
                    for block in content:
                        if block.get("type") == "tool_result":
                            openai_messages.append({
                                "role": "tool",
                                "tool_call_id": block.get("tool_use_id"),
                                "content": block.get("content", "")
                            })

                # Check if this is an assistant message with tool_use blocks
                elif role == "assistant":
                    # Separate text content and tool_use blocks
                    text_parts = []
                    tool_calls = []

                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                        elif block.get("type") == "tool_use":
                            tool_calls.append({
                                "id": block.get("id"),
                                "type": "function",
                                "function": {
                                    "name": block.get("name"),
                                    "arguments": json.dumps(block.get("input", {}))
                                }
                            })

                    # Build OpenAI format assistant message
                    openai_msg = {
                        "role": "assistant",
                        "content": " ".join(text_parts) if text_parts else None
                    }

                    if tool_calls:
                        openai_msg["tool_calls"] = tool_calls

                    openai_messages.append(openai_msg)
                else:
                    # Other list content, keep as is
                    openai_messages.append(msg)
            else:
                # Other formats, keep as is
                openai_messages.append(msg)

        return openai_messages

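An illustrative round-trip through the converter (hypothetical tool id and arguments; `bot` is any `OpenAICompatibleBot` subclass):

    claude_msg = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me check."},
            {"type": "tool_use", "id": "toolu_01", "name": "get_weather",
             "input": {"city": "Beijing"}},
        ],
    }
    bot._convert_messages_to_openai_format([claude_msg])
    # -> [{"role": "assistant", "content": "Let me check.",
    #      "tool_calls": [{"id": "toolu_01", "type": "function",
    #                      "function": {"name": "get_weather",
    #                                   "arguments": '{"city": "Beijing"}'}}]}]
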
378
bridge/agent_bridge.py
Normal file
@@ -0,0 +1,378 @@
"""
Agent Bridge - Integrates the Agent system with the existing COW bridge
"""

from typing import Optional, List

from agent.protocol import Agent, LLMModel, LLMRequest
from bot.openai_compatible_bot import OpenAICompatibleBot
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger


def add_openai_compatible_support(bot_instance):
    """
    Dynamically add OpenAI-compatible tool calling support to a bot instance.

    This allows any bot to gain tool calling capability without modifying its code,
    as long as it uses the OpenAI-compatible API format.
    """
    if hasattr(bot_instance, 'call_with_tools'):
        # Bot already has tool calling support
        return bot_instance

    # Create a temporary mixin class that combines the bot with OpenAI compatibility
    class EnhancedBot(bot_instance.__class__, OpenAICompatibleBot):
        """Dynamically enhanced bot with OpenAI-compatible tool calling"""

        def get_api_config(self):
            """
            Infer API config from common configuration patterns.
            Most OpenAI-compatible bots use similar configuration.
            """
            from config import conf

            return {
                'api_key': conf().get("open_ai_api_key"),
                'api_base': conf().get("open_ai_api_base"),
                'model': conf().get("model", "gpt-3.5-turbo"),
                'default_temperature': conf().get("temperature", 0.9),
                'default_top_p': conf().get("top_p", 1.0),
                'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
                'default_presence_penalty': conf().get("presence_penalty", 0.0),
            }

    # Change the bot's class to the enhanced version
    bot_instance.__class__ = EnhancedBot
    logger.info(
        f"[AgentBridge] Enhanced {bot_instance.__class__.__bases__[0].__name__} with OpenAI-compatible tool calling")

    return bot_instance


class AgentLLMModel(LLMModel):
    """
    LLM model adapter that uses COW's existing bot infrastructure
    """

    def __init__(self, bridge: Bridge, bot_type: str = "chat"):
        # Get model name directly from config
        from config import conf
        model_name = conf().get("model", const.GPT_41)
        super().__init__(model=model_name)
        self.bridge = bridge
        self.bot_type = bot_type
        self._bot = None

    @property
    def bot(self):
        """Lazy-load the bot and enhance it with tool calling if needed"""
        if self._bot is None:
            self._bot = self.bridge.get_bot(self.bot_type)
            # Automatically add tool calling support if not present
            self._bot = add_openai_compatible_support(self._bot)
        return self._bot

    def call(self, request: LLMRequest):
        """
        Call the model using COW's bot infrastructure
        """
        try:
            # For non-streaming calls, use the existing reply method.
            # This is a simplified implementation.
            if hasattr(self.bot, 'call_with_tools'):
                # Use the tool-enabled call if available
                kwargs = {
                    'messages': request.messages,
                    'tools': getattr(request, 'tools', None),
                    'stream': False
                }
                # Only pass max_tokens if it's explicitly set
                if request.max_tokens is not None:
                    kwargs['max_tokens'] = request.max_tokens
                response = self.bot.call_with_tools(**kwargs)
                return self._format_response(response)
            else:
                # Fallback to a regular call; this would need to be
                # implemented based on your specific needs
                raise NotImplementedError("Regular call not implemented yet")

        except Exception as e:
            logger.error(f"AgentLLMModel call error: {e}")
            raise

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming using COW's bot infrastructure
        """
        try:
            if hasattr(self.bot, 'call_with_tools'):
                # Use the tool-enabled streaming call if available.
                # Ensure max_tokens is an integer; use a default if None.
                max_tokens = request.max_tokens if request.max_tokens is not None else 4096

                # Extract the system prompt if present
                system_prompt = getattr(request, 'system', None)

                # Build kwargs for call_with_tools
                kwargs = {
                    'messages': request.messages,
                    'tools': getattr(request, 'tools', None),
                    'stream': True,
                    'max_tokens': max_tokens
                }

                # Add the system prompt if present
                if system_prompt:
                    kwargs['system'] = system_prompt

                stream = self.bot.call_with_tools(**kwargs)

                # Convert stream format to our expected format
                for chunk in stream:
                    yield self._format_stream_chunk(chunk)
            else:
                bot_type = type(self.bot).__name__
                raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")

        except Exception as e:
            logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
            raise

    def _format_response(self, response):
        """Format a Claude response to our expected format"""
        # This would need to be implemented based on Claude's response format
        return response

    def _format_stream_chunk(self, chunk):
        """Format a Claude stream chunk to our expected format"""
        # This would need to be implemented based on Claude's stream format
        return chunk


class AgentBridge:
    """
    Bridge class that integrates the single super agent with COW
    """

    def __init__(self, bridge: Bridge):
        self.bridge = bridge
        self.agent: Optional[Agent] = None

    def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
        """
        Create the super agent with COW integration

        Args:
            system_prompt: System prompt
            tools: List of tools (optional)
            **kwargs: Additional agent parameters

        Returns:
            Agent instance
        """
        # Create an LLM model that uses COW's bot infrastructure
        model = AgentLLMModel(self.bridge)

        # Default tools if none provided
        if tools is None:
            # Use ToolManager to load all available tools
            from agent.tools import ToolManager
            tool_manager = ToolManager()
            tool_manager.load_tools()

            tools = []
            for tool_name in tool_manager.tool_classes.keys():
                try:
                    tool = tool_manager.create_tool(tool_name)
                    if tool:
                        tools.append(tool)
                except Exception as e:
                    logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")

        # Create the single super agent
        self.agent = Agent(
            system_prompt=system_prompt,
            description=kwargs.get("description", "AI Super Agent"),
            model=model,
            tools=tools,
            max_steps=kwargs.get("max_steps", 15),
            output_mode=kwargs.get("output_mode", "logger"),
            workspace_dir=kwargs.get("workspace_dir"),  # Pass workspace for skills loading
            enable_skills=kwargs.get("enable_skills", True),  # Enable skills by default
            memory_manager=kwargs.get("memory_manager"),  # Pass memory manager
            max_context_tokens=kwargs.get("max_context_tokens"),
            context_reserve_tokens=kwargs.get("context_reserve_tokens")
        )

        # Log skill loading details
        if self.agent.skill_manager:
            logger.info("[AgentBridge] SkillManager initialized:")
            logger.info(f"[AgentBridge]   - Managed dir: {self.agent.skill_manager.managed_skills_dir}")
            logger.info(f"[AgentBridge]   - Workspace dir: {self.agent.skill_manager.workspace_dir}")
            logger.info(f"[AgentBridge]   - Total skills: {len(self.agent.skill_manager.skills)}")
            for skill_name in self.agent.skill_manager.skills.keys():
                logger.info(f"[AgentBridge]     * {skill_name}")

        return self.agent

    def get_agent(self) -> Optional[Agent]:
        """Get the super agent, creating it if it does not exist yet"""
        if self.agent is None:
            self._init_default_agent()
        return self.agent

    def _init_default_agent(self):
        """Initialize the default super agent with the new prompt system"""
        from config import conf
        import os

        # Get the workspace from config
        workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))

        # Initialize the workspace and create template files
        from agent.prompt import ensure_workspace, load_context_files, PromptBuilder

        workspace_files = ensure_workspace(workspace_root, create_templates=True)
        logger.info(f"[AgentBridge] Workspace initialized at: {workspace_root}")

        # Set up the memory system
        memory_manager = None
        memory_tools = []

        try:
            # Try to initialize the memory system
            from agent.memory import MemoryManager, MemoryConfig
            from agent.tools import MemorySearchTool, MemoryGetTool

            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",  # Use local embedding (no API key needed)
                embedding_model="all-MiniLM-L6-v2"
            )

            # Create the memory manager with the config
            memory_manager = MemoryManager(memory_config)

            # Create memory tools
            memory_tools = [
                MemorySearchTool(memory_manager),
                MemoryGetTool(memory_manager)
            ]

            logger.info("[AgentBridge] Memory system initialized")

        except Exception as e:
            logger.warning(f"[AgentBridge] Memory system not available: {e}")
            logger.info("[AgentBridge] Continuing without memory features")

        # Use ToolManager to dynamically load all available tools
        from agent.tools import ToolManager
        tool_manager = ToolManager()
        tool_manager.load_tools()

        # Create tool instances for all available tools
        tools = []
        file_config = {
            "cwd": workspace_root,
            "memory_manager": memory_manager
        } if memory_manager else {"cwd": workspace_root}

        for tool_name in tool_manager.tool_classes.keys():
            try:
                tool = tool_manager.create_tool(tool_name)
                if tool:
                    # Apply workspace config to file operation tools
                    if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
                        tool.config = file_config
                        tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
                        if 'memory_manager' in file_config:
                            tool.memory_manager = file_config['memory_manager']
                    tools.append(tool)
                    logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")

        # Add memory tools
        if memory_tools:
            tools.extend(memory_tools)
            logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")

        logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")

        # Load context files (SOUL.md, USER.md, etc.)
        context_files = load_context_files(workspace_root)
        logger.info(f"[AgentBridge] Loaded {len(context_files)} context files: {[f.path for f in context_files]}")

        # Build the system prompt using the new prompt builder
        prompt_builder = PromptBuilder(
            workspace_dir=workspace_root,
            language="zh"
        )

        # Get runtime info
        runtime_info = {
            "model": conf().get("model", "unknown"),
            "workspace": workspace_root,
            "channel": "web"  # TODO: get from the actual channel; default to "web" to hide if not specified
        }

        system_prompt = prompt_builder.build(
            tools=tools,
            context_files=context_files,
            memory_manager=memory_manager,
            runtime_info=runtime_info
        )

        logger.info("[AgentBridge] System prompt built successfully")

        # Create the agent with the configured tools and workspace
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            output_mode="logger",
            workspace_dir=workspace_root,  # Pass workspace to agent for skills loading
            enable_skills=True  # Enable skills auto-loading
        )

        # Attach the memory manager to the agent if available
        if memory_manager:
            agent.memory_manager = memory_manager
            logger.info("[AgentBridge] Memory manager attached to agent")

    def agent_reply(self, query: str, context: Context = None,
                    on_event=None, clear_history: bool = False) -> Reply:
        """
        Use the super agent to reply to a query

        Args:
            query: User query
            context: COW context (optional)
            on_event: Event callback (optional)
            clear_history: Whether to clear conversation history

        Returns:
            Reply object
        """
        try:
            # Get the agent (auto-initializes if needed)
            agent = self.get_agent()
            if not agent:
                return Reply(ReplyType.ERROR, "Failed to initialize super agent")

            # Use the agent's run_stream method
            response = agent.run_stream(
                user_message=query,
                on_event=on_event,
                clear_history=clear_history
            )

            return Reply(ReplyType.TEXT, response)

        except Exception as e:
            logger.error(f"Agent reply error: {e}")
            return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")

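A minimal wiring sketch (hypothetical query; `get_agent_bridge` is added to `Bridge` in the diff that follows, and the agent auto-initializes with the default workspace, tools, and memory on first use):

    from bridge.bridge import Bridge

    bridge = Bridge()
    agent_bridge = bridge.get_agent_bridge()
    reply = agent_bridge.agent_reply("Summarize today's notes")
    print(reply.type, reply.content)
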
@@ -23,7 +23,7 @@ class Bridge(object):
        if bot_type:
            self.btype["chat"] = bot_type
        else:
            model_type = conf().get("model") or const.GPT35
            model_type = conf().get("model") or const.GPT_41_MINI
            if model_type in ["text-davinci-003"]:
                self.btype["chat"] = const.OPEN_AI
            if conf().get("use_azure_chatgpt", False):
@@ -64,6 +64,7 @@ class Bridge(object):

        self.bots = {}
        self.chat_bots = {}
        self._agent_bridge = None

    # Interface corresponding to each model
    def get_bot(self, typename):
@@ -104,3 +105,29 @@ class Bridge(object):
        Reset bot routing
        """
        self.__init__()

    def get_agent_bridge(self):
        """
        Get the agent bridge for agent-based conversations
        """
        if self._agent_bridge is None:
            from bridge.agent_bridge import AgentBridge
            self._agent_bridge = AgentBridge(self)
        return self._agent_bridge

    def fetch_agent_reply(self, query: str, context: Context = None,
                          on_event=None, clear_history: bool = False) -> Reply:
        """
        Use the super agent to handle the query

        Args:
            query: User query
            context: Context object
            on_event: Event callback for streaming
            clear_history: Whether to clear conversation history

        Returns:
            Reply object
        """
        agent_bridge = self.get_agent_bridge()
        return agent_bridge.agent_reply(query, context, on_event, clear_history)

@@ -5,6 +5,8 @@ Message sending channel abstract class
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *
from common.log import logger
from config import conf


class Channel(object):
@@ -35,7 +37,30 @@ class Channel(object):
        raise NotImplementedError

    def build_reply_content(self, query, context: Context = None) -> Reply:
        return Bridge().fetch_reply_content(query, context)
        """
        Build reply content, using the agent if enabled in config
        """
        # Check if agent mode is enabled
        use_agent = conf().get("agent", False)

        if use_agent:
            try:
                logger.info("[Channel] Using agent mode")

                # Use the agent bridge to handle the query
                return Bridge().fetch_agent_reply(
                    query=query,
                    context=context,
                    on_event=None,
                    clear_history=False
                )
            except Exception as e:
                logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
                # Fall back to normal mode if the agent fails
                return Bridge().fetch_reply_content(query, context)
        else:
            # Normal mode
            return Bridge().fetch_reply_content(query, context)

    def build_voice_to_text(self, voice_file) -> Reply:
        return Bridge().fetch_voice_to_text(voice_file)

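With the `"agent"` flag set to true (see the config diff further below), every channel's `build_reply_content` transparently routes through the agent. A hedged end-to-end sketch, assuming the host app has already loaded its config:

    from bridge.bridge import Bridge
    from bridge.context import Context, ContextType
    from config import conf

    conf()["agent"] = True  # equivalent to "agent": true in config.json
    context = Context(ContextType.TEXT, "list the files in my workspace")
    reply = Bridge().fetch_agent_reply(query=context.content, context=context)
    print(reply.type, reply.content)
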
167
channel/feishu/README.md
Normal file
@@ -0,0 +1,167 @@
# Feishu Channel Usage Guide

The Feishu channel supports two event-receiving modes; choose whichever fits your deployment environment.

## Mode Comparison

| Mode | Best for | Pros | Cons |
|------|----------|------|------|
| **webhook** | Production | Stable and reliable, officially recommended | Requires a public IP or domain |
| **websocket** | Local development | No public IP needed, convenient for development | Requires an extra dependency |

## Configuration

### Basic configuration

Add the following settings to `config.json`:

```json
{
  "channel_type": "feishu",
  "feishu_app_id": "cli_xxxxx",
  "feishu_app_secret": "your_app_secret",
  "feishu_token": "your_verification_token",
  "feishu_bot_name": "your_bot_name",
  "feishu_event_mode": "webhook",
  "feishu_port": 9891
}
```

### Option reference

- `feishu_app_id`: App ID of the Feishu app
- `feishu_app_secret`: App Secret of the Feishu app
- `feishu_token`: Verification Token for the event subscription
- `feishu_bot_name`: bot name (used to check @-mentions in group chats)
- `feishu_event_mode`: event-receiving mode, one of:
  - `"websocket"`: long-connection mode (default)
  - `"webhook"`: HTTP server mode
- `feishu_port`: HTTP port for webhook mode (default 9891)

## Mode 1: Webhook (recommended for production)

### 1. Configure

```json
{
  "feishu_event_mode": "webhook",
  "feishu_port": 9891
}
```

### 2. Start the service

```bash
python3 app.py
```

The service listens on `http://0.0.0.0:9891`.

### 3. Configure the Feishu app

1. Log in to the [Feishu Open Platform](https://open.feishu.cn/)
2. Open the app details -> Event Subscription
3. Choose **Send events to the developer server**
4. Set the request URL: `http://your-domain:9891/`
5. Add the event: `im.message.receive_v1` (receive message v2.0)
6. Save the configuration

### 4. Notes

- Requires a public IP or domain
- Make sure the firewall allows the configured port
- HTTPS is recommended (requires a reverse proxy)

## Mode 2: WebSocket (recommended for local development)

### 1. Install the dependency

```bash
pip install lark-oapi
```

### 2. Configure

```json
{
  "feishu_event_mode": "websocket"
}
```

### 3. Start the service

```bash
python3 app.py
```

The program automatically establishes a long connection to the Feishu Open Platform.

### 4. Configure the Feishu app

1. Log in to the [Feishu Open Platform](https://open.feishu.cn/)
2. Open the app details -> Event Subscription
3. Choose **Receive events over a long connection**
4. Add the event: `im.message.receive_v1` (receive message v2.0)
5. Save the configuration

### 5. Notes

- No public IP required
- Outbound internet access is needed (to establish the WebSocket connection)
- Each app supports at most 50 connections
- In cluster mode, each message is dispatched to one random client

## Smooth Migration

To switch from webhook mode to websocket mode (or the other way around):

1. Change `feishu_event_mode` in `config.json`
2. If switching to websocket mode, install the `lark-oapi` dependency
3. Restart the service
4. Update the event subscription method on the Feishu Open Platform

**Important**: Use only one mode at a time; running both causes messages to be received twice.

## Message Deduplication

Both modes share the same deduplication mechanism (a sketch follows this list):

- Processed message IDs are stored in an `ExpiredDict`
- Expiry time: 7.1 hours
- Guarantees a message is never processed twice

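A minimal sketch of the pattern (identifiers mirror the channel implementation; `ExpiredDict` semantics are assumed from its use in the channel code):

```python
from common.expired_dict import ExpiredDict

received = ExpiredDict(60 * 60 * 7.1)  # entries expire after 7.1 hours

def should_process(msg_id: str) -> bool:
    """Return False for a message id that was already handled."""
    if received.get(msg_id):
        return False
    received[msg_id] = True
    return True
```
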
## Troubleshooting

### WebSocket mode fails to connect

```
[FeiShu] lark_oapi not installed
```

**Fix**: install the dependency with `pip install lark-oapi`

### Webhook port already in use

```
Address already in use
```

**Fix**: change the `feishu_port` setting or stop the process occupying the port

### No messages received

1. Check the Feishu app's event subscription settings
2. Confirm the `im.message.receive_v1` event has been added
3. Check app permissions: the `im:message` scope is required
4. Inspect the error messages in the logs

## Development Tips

- **Local development**: use websocket mode for fast iteration
- **Testing**: webhook mode plus a tunneling tool (e.g. ngrok) works well
- **Production**: use webhook mode with a proper domain and HTTPS

## References

- [Feishu Open Platform - Event Subscription](https://open.feishu.cn/document/ukTMukTMukTM/uUTNz4SN1MjL1UzM)
- [Feishu SDK - Python](https://github.com/larksuite/oapi-sdk-python)

@@ -1,48 +1,80 @@
"""
Feishu channel integration

Two event-receiving modes are supported:
1. webhook mode: receive events via an HTTP server (public IP required)
2. websocket mode: receive events over a long connection (friendly for local development)

Select the mode via the feishu_event_mode setting: "webhook" or "websocket"

@author Saboteur7
@Date 2023/11/19
"""
# -*- coding=utf-8 -*-

import json
import os
import threading
import uuid

import requests
import web

from bridge.context import Context
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.feishu.feishu_message import FeishuMessage
from common import utils
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from config import conf

URL_VERIFICATION = "url_verification"

# Try to import the Feishu SDK; websocket mode is unavailable when it is missing
try:
    import lark_oapi as lark

    LARK_SDK_AVAILABLE = True
except ImportError:
    LARK_SDK_AVAILABLE = False
    logger.warning(
        "[FeiShu] lark_oapi not installed, websocket mode is not available. Install with: pip install lark-oapi")


@singleton
class FeiShuChanel(ChatChannel):
    feishu_app_id = conf().get('feishu_app_id')
    feishu_app_secret = conf().get('feishu_app_secret')
    feishu_token = conf().get('feishu_token')
    feishu_event_mode = conf().get('feishu_event_mode', 'websocket')  # webhook or websocket

    def __init__(self):
        super().__init__()
        # Cache of recently seen message ids, used for idempotency control
        self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
        logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
            self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
        logger.info("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
            self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
        # No group whitelist or chat prefix required
        conf()["group_name_white_list"] = ["ALL_GROUP"]
        conf()["single_chat_prefix"] = [""]

        # Validate the configuration
        if self.feishu_event_mode == 'websocket' and not LARK_SDK_AVAILABLE:
            logger.error("[FeiShu] websocket mode requires lark_oapi. Please install: pip install lark-oapi")
            raise Exception("lark_oapi not installed")

    def startup(self):
        if self.feishu_event_mode == 'websocket':
            self._startup_websocket()
        else:
            self._startup_webhook()

    def _startup_webhook(self):
        """Start an HTTP server to receive events (webhook mode)"""
        logger.info("[FeiShu] Starting in webhook mode...")
        urls = (
            '/', 'channel.feishu.feishu_channel.FeishuController'
        )
@@ -50,6 +82,109 @@ class FeiShuChanel(ChatChannel):
        port = conf().get("feishu_port", 9891)
        web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))

    def _startup_websocket(self):
        """Receive events over a long connection (websocket mode)"""
        logger.info("[FeiShu] Starting in websocket mode...")

        # Create the event handler
        def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
            """Handle the receive-message event v2.0"""
            try:
                logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")

                # Convert to the standard event format
                event_dict = json.loads(lark.JSON.marshal(data))
                event = event_dict.get("event", {})

                # Handle the message
                self._handle_message_event(event)

            except Exception as e:
                logger.error(f"[FeiShu] websocket handle message error: {e}", exc_info=True)

        # Build the event dispatcher
        event_handler = lark.EventDispatcherHandler.builder("", "") \
            .register_p2_im_message_receive_v1(handle_message_event) \
            .build()

        # Create the long-connection client
        ws_client = lark.ws.Client(
            self.feishu_app_id,
            self.feishu_app_secret,
            event_handler=event_handler,
            log_level=lark.LogLevel.DEBUG if conf().get("debug") else lark.LogLevel.INFO
        )

        # Start the client in a new thread to avoid blocking the main thread
        def start_client():
            try:
                logger.info("[FeiShu] Websocket client starting...")
                ws_client.start()
            except Exception as e:
                logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)

        ws_thread = threading.Thread(target=start_client, daemon=True)
        ws_thread.start()

        # Keep the main thread alive
        logger.info("[FeiShu] Websocket mode started, waiting for events...")
        ws_thread.join()

    def _handle_message_event(self, event: dict):
        """
        Core logic for handling a message event.
        Shared by both webhook and websocket modes.
        """
        if not event.get("message") or not event.get("sender"):
            logger.warning(f"[FeiShu] invalid message, event={event}")
            return

        msg = event.get("message")

        # Idempotency check
        msg_id = msg.get("message_id")
        if self.receivedMsgs.get(msg_id):
            logger.warning(f"[FeiShu] repeat msg filtered, msg_id={msg_id}")
            return
        self.receivedMsgs[msg_id] = True

        is_group = False
        chat_type = msg.get("chat_type")

        if chat_type == "group":
            if not msg.get("mentions") and msg.get("message_type") == "text":
                # In group chats, ignore messages that don't @ anyone
                return
            if msg.get("mentions") and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get(
                    "message_type") == "text":
                # Not @-mentioning this bot; don't respond
                return
            # Group chat
            is_group = True
            receive_id_type = "chat_id"
        elif chat_type == "p2p":
            receive_id_type = "open_id"
        else:
            logger.warning("[FeiShu] message ignore")
            return

        # Build the Feishu message object
        feishu_msg = FeishuMessage(event, is_group=is_group, access_token=self.fetch_access_token())
        if not feishu_msg:
            return

        context = self._compose_context(
            feishu_msg.ctype,
            feishu_msg.content,
            isgroup=is_group,
            msg=feishu_msg,
            receive_id_type=receive_id_type,
            no_need_at=True
        )
        if context:
            self.produce(context)
        logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")

    def send(self, reply: Reply, context: Context):
        msg = context.get("msg")
        is_group = context["isgroup"]
@@ -143,9 +278,39 @@ class FeiShuChanel(ChatChannel):
        os.remove(temp_name)
        return upload_response.json().get("data").get("image_key")

    def _compose_context(self, ctype: ContextType, content, **kwargs):
        context = Context(ctype, content)
        context.kwargs = kwargs
        if "origin_ctype" not in context:
            context["origin_ctype"] = ctype

        cmsg = context["msg"]
        context["session_id"] = cmsg.from_user_id
        context["receiver"] = cmsg.other_user_id

        if ctype == ContextType.TEXT:
            # 1. Text request
            # Image generation handling
            img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
            if img_match_prefix:
                content = content.replace(img_match_prefix, "", 1)
                context.type = ContextType.IMAGE_CREATE
            else:
                context.type = ContextType.TEXT
            context.content = content.strip()

        elif context.type == ContextType.VOICE:
            # 2. Voice request
            if "desire_rtype" not in context and conf().get("voice_reply_voice"):
                context["desire_rtype"] = ReplyType.VOICE

        return context


class FeishuController:
    """
    HTTP server controller, used in webhook mode
    """
    # Class constants
    FAILED_MSG = '{"success": false}'
    SUCCESS_MSG = '{"success": true}'
@@ -175,80 +340,10 @@ class FeishuController:
        # Handle the message event
        event = request.get("event")
        if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
            if not event.get("message") or not event.get("sender"):
                logger.warning(f"[FeiShu] invalid message, msg={request}")
                return self.FAILED_MSG
            msg = event.get("message")
            channel._handle_message_event(event)

            # Idempotency check
            if channel.receivedMsgs.get(msg.get("message_id")):
                logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
                return self.SUCCESS_MSG
            channel.receivedMsgs[msg.get("message_id")] = True

            is_group = False
            chat_type = msg.get("chat_type")
            if chat_type == "group":
                if not msg.get("mentions") and msg.get("message_type") == "text":
                    # In group chats, ignore messages that don't @ anyone
                    return self.SUCCESS_MSG
                if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
                    # Not @-mentioning this bot; don't respond
                    return self.SUCCESS_MSG
                # Group chat
                is_group = True
                receive_id_type = "chat_id"
            elif chat_type == "p2p":
                receive_id_type = "open_id"
            else:
                logger.warning("[FeiShu] message ignore")
                return self.SUCCESS_MSG
            # Build the Feishu message object
            feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
            if not feishu_msg:
                return self.SUCCESS_MSG

            context = self._compose_context(
                feishu_msg.ctype,
                feishu_msg.content,
                isgroup=is_group,
                msg=feishu_msg,
                receive_id_type=receive_id_type,
                no_need_at=True
            )
            if context:
                channel.produce(context)
            logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
            return self.SUCCESS_MSG

        except Exception as e:
            logger.error(e)
            return self.FAILED_MSG

    def _compose_context(self, ctype: ContextType, content, **kwargs):
        context = Context(ctype, content)
        context.kwargs = kwargs
        if "origin_ctype" not in context:
            context["origin_ctype"] = ctype

        cmsg = context["msg"]
        context["session_id"] = cmsg.from_user_id
        context["receiver"] = cmsg.other_user_id

        if ctype == ContextType.TEXT:
            # 1. Text request
            # Image generation handling
            img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
            if img_match_prefix:
                content = content.replace(img_match_prefix, "", 1)
                context.type = ContextType.IMAGE_CREATE
            else:
                context.type = ContextType.TEXT
            context.content = content.strip()

        elif context.type == ContextType.VOICE:
            # 2. Voice request
            if "desire_rtype" not in context and conf().get("voice_reply_voice"):
                context["desire_rtype"] = ReplyType.VOICE

        return context

@@ -150,9 +150,6 @@ class WebChannel(ChatChannel):
        Poll for responses using the session_id.
        """
        try:
            # Don't log polling requests
            web.ctx.log_request = False

            data = web.data()
            json_data = json.loads(data)
            session_id = json_data.get('session_id')
@@ -198,7 +195,7 @@ class WebChannel(ChatChannel):
5. wechatcom_app: 企微自建应用
6. dingtalk: 钉钉
7. feishu: 飞书""")
        logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat(本地运行)或 http://ip:{port}/chat(服务器运行) ")
        logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat (本地运行) 或 http://ip:{port}/chat (服务器运行)")

        # Make sure the static file directory exists
        static_dir = os.path.join(os.path.dirname(__file__), 'static')
@@ -215,19 +212,20 @@ class WebChannel(ChatChannel):
        )
        app = web.application(urls, globals(), autoreload=False)

        # Disable web.py's default log output
        import io
        from contextlib import redirect_stdout
        # Completely disable web.py's HTTP log output
        # Create a no-op log handler function
        def null_log_function(status, environ):
            pass

        # Set web.py's log level to ERROR, showing errors only
        # Replace web.py's log function
        web.httpserver.LogMiddleware.log = lambda self, status, environ: None

        # Set web.py's log level to ERROR
        logging.getLogger("web").setLevel(logging.ERROR)

        # Disable web.httpserver's logging
        logging.getLogger("web.httpserver").setLevel(logging.ERROR)

        # Temporarily redirect stdout to capture web.py's startup message
        with redirect_stdout(io.StringIO()):
            web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
        # Start the server
        web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))


class RootHandler:

@@ -17,7 +17,7 @@
    "@bot"
  ],
  "group_name_white_list": [
    "ChatGPT测试群",
    "Agent测试群",
    "ChatGPT测试群2"
  ],
  "image_create_prefix": [
@@ -30,8 +30,9 @@
  "expires_in_seconds": 3600,
  "character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
  "temperature": 0.7,
  "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
  "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
  "use_linkai": false,
  "linkai_api_key": "",
  "linkai_app_code": ""
  "linkai_app_code": "",
  "agent": false
}

22
config.py
@@ -1,10 +1,10 @@
# encoding:utf-8

import copy
import json
import logging
import os
import pickle
import copy

from common.log import logger

@@ -148,6 +148,7 @@ available_setting = {
    "feishu_app_secret": "",  # Feishu bot app secret
    "feishu_token": "",  # Feishu verification token
    "feishu_bot_name": "",  # Name of the Feishu bot
    "feishu_event_mode": "websocket",  # Feishu event delivery mode: webhook (HTTP server) or websocket (long-lived connection)
    # DingTalk settings
    "dingtalk_client_id": "",  # DingTalk bot Client ID
    "dingtalk_client_secret": "",  # DingTalk bot Client Secret
@@ -183,6 +184,8 @@ available_setting = {
    "Minimax_group_id": "",
    "Minimax_base_url": "",
    "web_port": 9899,
    "agent": False,  # whether to enable Agent mode
    "agent_workspace": "~/cow"  # agent workspace path, used to store skills, memory, etc.
}


@@ -197,16 +200,26 @@ class Config(dict):
        self.user_datas = {}

    def __getitem__(self, key):
        if key not in available_setting:
        # Skip comment fields that start with an underscore
        if not key.startswith("_") and key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        if key not in available_setting:
        # Skip comment fields that start with an underscore
        if not key.startswith("_") and key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
        return super().__setitem__(key, value)

    def get(self, key, default=None):
        # Skip comment fields that start with an underscore
        if key.startswith("_"):
            return super().get(key, default)

        # If the key is not in available_setting, return the default directly
        if key not in available_setting:
            return super().get(key, default)

        try:
            return self[key]
        except KeyError as e:
@@ -284,6 +297,9 @@ def load_config():
    # Some online deployment platforms (e.g. Railway) deploy project from github directly. So you shouldn't put your secrets like api key in a config file, instead use environment variables to override the default config.
    for name, value in os.environ.items():
        name = name.lower()
        # Skip comment fields that start with an underscore
        if name.startswith("_"):
            continue
        if name in available_setting:
            logger.info("[INIT] override config by environ args: {}={}".format(name, value))
            try:

@@ -15,18 +15,14 @@ elevenlabs==1.0.3 # elevenlabs TTS
#install plugin
dulwich

# wechatmp && wechatcom
# wechatmp && wechatcom && feishu
web.py
wechatpy

# chatgpt-tool-hub plugin
chatgpt_tool_hub==0.5.0

# xunfei spark
websocket-client==1.2.0

# claude bot
curl_cffi

# claude API
anthropic==0.25.0


@@ -9,3 +9,6 @@ pre-commit
web.py
linkai>=0.0.6.0
agentmesh-sdk>=0.1.3

# feishu websocket mode
lark-oapi
124
skills/README.md
Normal file
@@ -0,0 +1,124 @@
# Skills Directory

This directory contains skills for the COW agent system. Skills are markdown files that provide specialized instructions for specific tasks.

## What are Skills?

Skills are reusable instruction sets that help the agent perform specific tasks more effectively. Each skill:

- Provides context-specific guidance
- Documents best practices
- Includes examples and usage patterns
- Can have requirements (binaries, environment variables, etc.)

## Skill Structure

Each skill is a markdown file (`SKILL.md`) in its own directory with frontmatter:

```markdown
---
name: skill-name
description: Brief description of what the skill does
metadata: {"cow":{"emoji":"🎯","requires":{"bins":["tool"]}}}
---

# Skill Name

Detailed instructions and examples...
```

## Available Skills

- **calculator**: Mathematical calculations and expressions
- **web-search**: Search the web for current information
- **file-operations**: Read, write, and manage files

## Creating Custom Skills

To create a new skill (a minimal sketch follows these steps):

1. Create a directory: `skills/my-skill/`
2. Create `SKILL.md` with frontmatter and content
3. Restart the agent to load the new skill
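
A minimal sketch of steps 1-2 (hypothetical scaffolding snippet, not part of the repo; the frontmatter mirrors the structure shown above):

```python
from pathlib import Path

SKILL_MD = """---
name: my-skill
description: Describe what the skill does and when to use it
metadata: {"cow":{"emoji":"🎯"}}
---

# My Skill

Detailed instructions and examples...
"""

# Create the skill directory and write its SKILL.md
skill_dir = Path("skills/my-skill")
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(SKILL_MD, encoding="utf-8")
```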

### Frontmatter Fields

- `name`: Skill name (must match directory name)
- `description`: Brief description (required)
- `metadata`: JSON object with additional configuration
  - `emoji`: Display emoji
  - `always`: Always include this skill (default: false)
  - `primaryEnv`: Primary environment variable needed
  - `os`: Supported operating systems (e.g., ["darwin", "linux"])
  - `requires`: Requirements object
    - `bins`: Required binaries
    - `env`: Required environment variables
    - `config`: Required config paths
  - `disable-model-invocation`: If true, the skill won't be shown to the model (default: false)
  - `user-invocable`: If false, users can't invoke it directly (default: true)

### Example Skill

```markdown
---
name: my-tool
description: Use my-tool to process data
metadata: {"cow":{"emoji":"🔧","requires":{"bins":["my-tool"],"env":["MY_TOOL_API_KEY"]}}}
---

# My Tool Skill

Use this skill when you need to process data with my-tool.

## Prerequisites

- Install my-tool: `pip install my-tool`
- Set `MY_TOOL_API_KEY` environment variable

## Usage

\`\`\`python
# Example usage
my_tool_command("input data")
\`\`\`
```

## Skill Loading

Skills are loaded from multiple locations with precedence:

1. **Workspace skills** (highest): `workspace/skills/` - Project-specific skills
2. **Managed skills**: `~/.cow/skills/` - User-installed skills
3. **Bundled skills** (lowest): Built-in skills

Skills with the same name in higher-precedence locations override lower ones.
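
The sketch below illustrates this override rule. It is a minimal, hypothetical resolver (the three locations are the ones listed above, but the function itself is not part of the codebase):

```python
from pathlib import Path

def resolve_skills(workspace: Path, managed: Path, bundled: Path) -> dict:
    """Map skill name -> directory, honoring precedence.

    Scan the lowest-precedence location first so that later
    (higher-precedence) locations overwrite same-named entries.
    """
    skills = {}
    for location in (bundled, managed, workspace):  # lowest -> highest
        if not location.is_dir():
            continue
        for skill_dir in sorted(location.iterdir()):
            if (skill_dir / "SKILL.md").is_file():
                skills[skill_dir.name] = skill_dir  # same name overrides
    return skills
```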

## Skill Requirements

Skills can specify requirements that determine when they're available (see the sketch after this list):

- **OS requirements**: Only load on specific operating systems
- **Binary requirements**: Only load if required binaries are installed
- **Environment variables**: Only load if required env vars are set
- **Config requirements**: Only load if config values are set
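
A rough sketch of how this gating could be checked, assuming the `requires` metadata shown under "Frontmatter Fields" (hypothetical helper, not the actual loader; the config-path check is omitted for brevity):

```python
import os
import platform
import shutil

def skill_available(metadata: dict) -> bool:
    """Return True if a skill's declared requirements are satisfied."""
    cow = metadata.get("cow", {})
    requires = cow.get("requires", {})
    # OS requirement, e.g. ["darwin", "linux"]
    supported = cow.get("os")
    if supported and platform.system().lower() not in supported:
        return False
    # Required binaries must be found on PATH
    if any(shutil.which(b) is None for b in requires.get("bins", [])):
        return False
    # Required environment variables must be set and non-empty
    if any(not os.environ.get(e) for e in requires.get("env", [])):
        return False
    return True
```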

## Best Practices

1. **Clear descriptions**: Write clear, concise skill descriptions
2. **Include examples**: Provide practical usage examples
3. **Document prerequisites**: List all requirements clearly
4. **Use appropriate metadata**: Set correct requirements and flags
5. **Keep skills focused**: Each skill should have a single, clear purpose

## Workspace Skills

You can create workspace-specific skills in your agent's workspace:

```
workspace/
  skills/
    custom-skill/
      SKILL.md
```

These skills are only available when working in that specific workspace.
202
skills/skill-creator/LICENSE.txt
Normal file
@@ -0,0 +1,202 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
286
skills/skill-creator/SKILL.md
Normal file
@@ -0,0 +1,286 @@
---
name: skill-creator
description: Create or update skills. Use when designing, structuring, or packaging skills with scripts, references, and assets. COW simplified version - skills are used locally in workspace.
license: Complete terms in LICENSE.txt
---

# Skill Creator

This skill provides guidance for creating effective skills using the existing tool system.

## About Skills

Skills are modular, self-contained packages that extend the agent's capabilities by providing specialized knowledge, workflows, and tools. They transform a general-purpose agent into a specialized agent equipped with procedural knowledge.

### What Skills Provide

1. **Specialized workflows** - Multi-step procedures for specific domains
2. **Tool integrations** - Instructions for working with specific file formats or APIs
3. **Domain expertise** - Company-specific knowledge, schemas, business logic
4. **Bundled resources** - Scripts, references, and assets for complex tasks

### Core Principle

**Concise is Key**: Only add context the agent doesn't already have. Challenge each piece of information: "Does this justify its token cost?" Prefer concise examples over verbose explanations.

## Skill Structure

Every skill consists of a required SKILL.md file and optional bundled resources:

```
skill-name/
├── SKILL.md (required)
│   ├── YAML frontmatter metadata (required)
│   │   ├── name: (required)
│   │   └── description: (required)
│   └── Markdown instructions (required)
└── Bundled Resources (optional)
    ├── scripts/ - Executable code (Python/Bash/etc.)
    ├── references/ - Documentation intended to be loaded into context as needed
    └── assets/ - Files used in output (templates, icons, fonts, etc.)
```

### SKILL.md Components

**Frontmatter (YAML)** - Required fields:

- **name**: Skill name in hyphen-case (e.g., `weather-api`, `pdf-editor`)
- **description**: **CRITICAL** - Primary triggering mechanism
  - Must clearly describe what the skill does
  - Must explicitly state when to use it
  - Include specific trigger scenarios and keywords
  - All "when to use" info goes here, NOT in the body
  - Example: `"PDF document processing with rotation, merging, splitting, and text extraction. Use when user needs to: (1) Rotate PDF pages, (2) Merge multiple PDFs, (3) Split PDF files, (4) Extract text from PDFs."`

**Body (Markdown)** - Loaded after the skill triggers:

- Detailed usage instructions
- How to call scripts and read references
- Examples and best practices
- Use imperative/infinitive form ("Use X to do Y")

### Bundled Resources

**scripts/** - When to include:
- Code is repeatedly rewritten
- Deterministic execution is needed (avoid LLM randomness)
- Examples: PDF rotation, image processing
- Scripts must be tested before being included

**references/** - When to include:
- Documentation for the agent to reference
- Database schemas, API docs, domain knowledge
- The agent reads these files into context as needed
- For large files (>10k words), include grep patterns in SKILL.md

**assets/** - When to include:
- Files used in output (not loaded into context)
- Templates, icons, boilerplate code
- Copied or modified in the final output

**Important**: Most skills don't need all three. Choose based on actual needs.

### What NOT to Include

Do NOT create auxiliary documentation:
- README.md
- INSTALLATION_GUIDE.md
- CHANGELOG.md
- Other non-essential files

## Skill Creation Process

**COW Simplified Version** - Skills are used locally, so no packaging/sharing is needed.

1. **Understand** - Clarify use cases with concrete examples
2. **Plan** - Identify needed scripts, references, assets
3. **Initialize** - Run init_skill.py to create a template
4. **Edit** - Implement SKILL.md and resources
5. **Validate** (optional) - Run quick_validate.py to check the format
6. **Iterate** - Improve based on real usage

## Skill Naming

- Use lowercase letters, digits, and hyphens only; normalize user-provided titles to hyphen-case (e.g., "Plan Mode" -> `plan-mode`), as in the sketch after this list.
- When generating names, keep them under 64 characters (letters, digits, hyphens).
- Prefer short, verb-led phrases that describe the action.
- Namespace by tool when it improves clarity or triggering (e.g., `gh-address-comments`, `linear-address-issue`).
- Name the skill folder exactly after the skill name.
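
A minimal sketch of this normalization rule (hypothetical helper, not part of the skill):

```python
import re

def to_skill_name(title: str, max_len: int = 64) -> str:
    """Normalize a title to hyphen-case, e.g. "Plan Mode" -> "plan-mode"."""
    # Collapse any run of non-alphanumeric characters into a single hyphen
    name = re.sub(r"[^a-z0-9]+", "-", title.lower()).strip("-")
    # Enforce the length limit without leaving a trailing hyphen
    return name[:max_len].rstrip("-")
```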

## Step-by-Step Guide

### Step 1: Understanding the Skill with Concrete Examples

Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.

To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.

For example, when building an image-editor skill, relevant questions include:

- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"

To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness.

Conclude this step when there is a clear sense of the functionality the skill should support.

### Step 2: Planning the Reusable Skill Contents

To turn concrete examples into an effective skill, analyze each example by:

1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly

Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:

1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill

Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:

1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill

Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:

1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill

To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.

### Step 3: Initialize the Skill

At this point, it is time to actually create the skill.

Skip this step only if the skill being developed already exists and iteration is needed. In this case, continue to the next step.

When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.

Usage:

```bash
scripts/init_skill.py <skill-name> --path <output-directory> [--resources scripts,references,assets] [--examples]
```

Examples:

```bash
scripts/init_skill.py my-skill --path ~/cow/skills
scripts/init_skill.py my-skill --path ~/cow/skills --resources scripts,references
scripts/init_skill.py my-skill --path ~/cow/skills --resources scripts --examples
```

The script:

- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Optionally creates resource directories based on `--resources`
- Optionally adds example files when `--examples` is set

After initialization, customize the SKILL.md and add resources as needed. If you used `--examples`, replace or delete placeholder files.

**Important**: Always create skills in the workspace directory (`~/cow/skills`), NOT in the project directory.

### Step 4: Edit the Skill

When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of the agent to use. Include information that would be beneficial and non-obvious to the agent. Consider what procedural knowledge, domain-specific details, or reusable assets would help another agent instance execute these tasks more effectively.

#### Learn Proven Design Patterns

Consult these helpful guides based on your skill's needs:

- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns

These files contain established best practices for effective skill design.

#### Start with Reusable Skill Contents

To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.

Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion.

If you used `--examples`, delete any placeholder files that are not needed for the skill. Only create resource directories that are actually required.

#### Update SKILL.md

**Writing Guidelines:** Always use imperative/infinitive form.

##### Frontmatter

Write the YAML frontmatter with `name` and `description`:

- `name`: The skill name
- `description`: This is the primary triggering mechanism for your skill, and it helps the agent understand when to use the skill.
  - Include both what the skill does and specific triggers/contexts for when to use it.
  - Include all "when to use" information here, not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
  - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"

Do not include any other fields in the YAML frontmatter.

##### Body

Write instructions for using the skill and its bundled resources.

### Step 5: Validate (Optional)

Validate the skill format:

```bash
scripts/quick_validate.py <path/to/skill-folder>
```

Example:

```bash
scripts/quick_validate.py ~/cow/skills/weather-api
```

Validation checks:
- YAML frontmatter format and required fields
- Skill naming conventions (hyphen-case, lowercase)
- Description completeness and quality
- File organization

**Note**: Validation is optional in COW. It is mainly useful for troubleshooting format issues.

### Step 6: Iterate

Improve based on real usage:

1. Use the skill on real tasks
2. Notice struggles or inefficiencies
3. Identify needed updates to SKILL.md or resources
4. Implement changes and test again

## Progressive Disclosure

Skills use three-level loading (level 1 is sketched after this list):

1. **Metadata** (name + description) - Always in context (~100 words)
2. **SKILL.md body** - Loaded when the skill triggers (<5k words)
3. **Resources** - Loaded as needed by the agent
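
A minimal sketch of the level-1 load, assuming PyYAML and the frontmatter format shown earlier (illustrative only, not the actual loader):

```python
import re
from pathlib import Path

import yaml

def load_skill_metadata(skill_dir: Path) -> dict:
    """Read only the frontmatter (name + description); defer the body."""
    text = (skill_dir / "SKILL.md").read_text()
    match = re.match(r"^---\n(.*?)\n---", text, re.DOTALL)
    meta = yaml.safe_load(match.group(1)) if match else {}
    return {"name": meta.get("name"), "description": meta.get("description")}
```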

**Best practices**:
- Keep SKILL.md under 500 lines
- Split complex content into `references/` files
- Reference these files clearly in SKILL.md

**Pattern**: For skills with multiple variants/frameworks:
- Keep the core workflow in SKILL.md
- Move variant-specific details to separate reference files
- The agent loads only the relevant files

Example:
```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
    ├── aws.md
    ├── gcp.md
    └── azure.md
```

When the user chooses AWS, the agent only reads aws.md.
82
skills/skill-creator/references/output-patterns.md
Normal file
@@ -0,0 +1,82 @@
# Output Patterns

Use these patterns when skills need to produce consistent, high-quality output.

## Template Pattern

Provide templates for the output format. Match the level of strictness to your needs.

**For strict requirements (like API responses or data formats):**

```markdown
## Report structure

ALWAYS use this exact template structure:

# [Analysis Title]

## Executive summary
[One-paragraph overview of key findings]

## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data

## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```

**For flexible guidance (when adaptation is useful):**

```markdown
## Report structure

Here is a sensible default format, but use your best judgment:

# [Analysis Title]

## Executive summary
[Overview]

## Key findings
[Adapt sections based on what you discover]

## Recommendations
[Tailor to the specific context]

Adjust sections as needed for the specific analysis type.
```

## Examples Pattern

For skills where output quality depends on seeing examples, provide input/output pairs:

```markdown
## Commit message format

Generate commit messages following these examples:

**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```
feat(auth): implement JWT-based authentication

Add login endpoint and token validation middleware
```

**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```
fix(reports): correct date formatting in timezone conversion

Use UTC timestamps consistently across report generation
```

Follow this style: type(scope): brief description, then detailed explanation.
```

Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.
28
skills/skill-creator/references/workflows.md
Normal file
@@ -0,0 +1,28 @@
# Workflow Patterns

## Sequential Workflows

For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md:

```markdown
Filling a PDF form involves these steps:

1. Analyze the form (run analyze_form.py)
2. Create field mapping (edit fields.json)
3. Validate mapping (run validate_fields.py)
4. Fill the form (run fill_form.py)
5. Verify output (run verify_output.py)
```

## Conditional Workflows

For tasks with branching logic, guide Claude through decision points:

```markdown
1. Determine the modification type:
   **Creating new content?** → Follow "Creation workflow" below
   **Editing existing content?** → Follow "Editing workflow" below

2. Creation workflow: [steps]
3. Editing workflow: [steps]
```
303
skills/skill-creator/scripts/init_skill.py
Executable file
@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Skill Initializer - Creates a new skill from template

Usage:
    init_skill.py <skill-name> --path <path>

Examples:
    init_skill.py my-new-skill --path skills/public
    init_skill.py my-api-helper --path skills/private
    init_skill.py custom-skill --path /custom/location
"""

import sys
from pathlib import Path


SKILL_TEMPLATE = """---
name: {skill_name}
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
---

# {skill_title}

## Overview

[TODO: 1-2 sentences explaining what this skill enables]

## Structuring This Skill

[TODO: Choose the structure that best fits this skill's purpose. Common patterns:

**1. Workflow-Based** (best for sequential processes)
- Works well when there are clear step-by-step procedures
- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing"
- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2...

**2. Task-Based** (best for tool collections)
- Works well when the skill offers different operations/capabilities
- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text"
- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2...

**3. Reference/Guidelines** (best for standards or specifications)
- Works well for brand guidelines, coding standards, or requirements
- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features"
- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage...

**4. Capabilities-Based** (best for integrated systems)
- Works well when the skill provides multiple interrelated features
- Example: Product Management with "Core Capabilities" → numbered capability list
- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature...

Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).

Delete this entire "Structuring This Skill" section when done - it's just guidance.]

## [TODO: Replace with the first main section based on chosen structure]

[TODO: Add content here. See examples in existing skills:
- Code samples for technical skills
- Decision trees for complex workflows
- Concrete examples with realistic user requests
- References to scripts/templates/references as needed]

## Resources

This skill includes example resource directories that demonstrate how to organize different types of bundled resources:

### scripts/
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.

**Examples from other skills:**
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing

**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.

**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments.

### references/
Documentation and reference material intended to be loaded into context to inform Claude's process and thinking.

**Examples from other skills:**
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
- BigQuery: API reference documentation and query examples
- Finance: Schema documentation, company policies

**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working.

### assets/
Files not intended to be loaded into context, but rather used within the output Claude produces.

**Examples from other skills:**
- Brand styling: PowerPoint template files (.pptx), logo files
- Frontend builder: HTML/React boilerplate project directories
- Typography: Font files (.ttf, .woff2)

**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.

---

**Any unneeded directories can be deleted.** Not every skill requires all three types of resources.
"""

EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}

This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.

Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""

def main():
    print("This is an example script for {skill_name}")
    # TODO: Add actual script logic here
    # This could be data processing, file conversion, API calls, etc.

if __name__ == "__main__":
    main()
'''

EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}

This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.

Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples

## When Reference Docs Are Useful

Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases

## Structure Suggestions

### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits

### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""

EXAMPLE_ASSET = """# Example Asset File

This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.

Asset files are NOT intended to be loaded into context, but rather used within
the output Claude produces.

Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json

## Common Asset Types

- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml

Note: This is a text placeholder. Actual assets can be any file type.
"""


def title_case_skill_name(skill_name):
    """Convert hyphenated skill name to Title Case for display."""
    return ' '.join(word.capitalize() for word in skill_name.split('-'))


def init_skill(skill_name, path):
    """
    Initialize a new skill directory with template SKILL.md.

    Args:
        skill_name: Name of the skill
        path: Path where the skill directory should be created

    Returns:
        Path to created skill directory, or None if error
    """
    # Determine skill directory path
    skill_dir = Path(path).resolve() / skill_name

    # Check if directory already exists
    if skill_dir.exists():
        print(f"❌ Error: Skill directory already exists: {skill_dir}")
        return None

    # Create skill directory
    try:
        skill_dir.mkdir(parents=True, exist_ok=False)
        print(f"✅ Created skill directory: {skill_dir}")
    except Exception as e:
        print(f"❌ Error creating directory: {e}")
        return None

    # Create SKILL.md from template
    skill_title = title_case_skill_name(skill_name)
    skill_content = SKILL_TEMPLATE.format(
        skill_name=skill_name,
        skill_title=skill_title
    )

    skill_md_path = skill_dir / 'SKILL.md'
    try:
        skill_md_path.write_text(skill_content)
        print("✅ Created SKILL.md")
    except Exception as e:
        print(f"❌ Error creating SKILL.md: {e}")
        return None

    # Create resource directories with example files
    try:
        # Create scripts/ directory with example script
        scripts_dir = skill_dir / 'scripts'
        scripts_dir.mkdir(exist_ok=True)
        example_script = scripts_dir / 'example.py'
        example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
        example_script.chmod(0o755)
        print("✅ Created scripts/example.py")

        # Create references/ directory with example reference doc
        references_dir = skill_dir / 'references'
        references_dir.mkdir(exist_ok=True)
        example_reference = references_dir / 'api_reference.md'
        example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
        print("✅ Created references/api_reference.md")

        # Create assets/ directory with example asset placeholder
        assets_dir = skill_dir / 'assets'
        assets_dir.mkdir(exist_ok=True)
        example_asset = assets_dir / 'example_asset.txt'
        example_asset.write_text(EXAMPLE_ASSET)
        print("✅ Created assets/example_asset.txt")
    except Exception as e:
        print(f"❌ Error creating resource directories: {e}")
        return None

    # Print next steps
    print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}")
    print("\nNext steps:")
    print("1. Edit SKILL.md to complete the TODO items and update the description")
    print("2. Customize or delete the example files in scripts/, references/, and assets/")
    print("3. Run the validator when ready to check the skill structure")

    return skill_dir


def main():
    if len(sys.argv) < 4 or sys.argv[2] != '--path':
        print("Usage: init_skill.py <skill-name> --path <path>")
        print("\nSkill name requirements:")
        print("  - Hyphen-case identifier (e.g., 'data-analyzer')")
        print("  - Lowercase letters, digits, and hyphens only")
        print("  - Max 40 characters")
        print("  - Must match directory name exactly")
        print("\nExamples:")
        print("  init_skill.py my-new-skill --path workspace/skills")
        print("  init_skill.py my-api-helper --path /path/to/skills")
        print("  init_skill.py custom-skill --path /custom/location")
        sys.exit(1)

    skill_name = sys.argv[1]
    path = sys.argv[3]

    print(f"🚀 Initializing skill: {skill_name}")
    print(f"   Location: {path}")
    print()

    result = init_skill(skill_name, path)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()
116
skills/skill-creator/scripts/package_skill.py
Executable file
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder

Usage:
    python utils/package_skill.py <path/to/skill-folder> [output-directory]

Example:
    python utils/package_skill.py skills/public/my-skill
    python utils/package_skill.py skills/public/my-skill ./dist
"""

import sys
import os
import zipfile
from pathlib import Path

# Add script directory to path for imports
script_dir = Path(__file__).parent
sys.path.insert(0, str(script_dir))

from quick_validate import validate_skill


def package_skill(skill_path, output_dir=None):
    """
    Package a skill folder into a .skill file.

    Args:
        skill_path: Path to the skill folder
        output_dir: Optional output directory for the .skill file (defaults to current directory)

    Returns:
        Path to the created .skill file, or None if error
    """
    skill_path = Path(skill_path).resolve()

    # Validate skill folder exists
    if not skill_path.exists():
        print(f"❌ Error: Skill folder not found: {skill_path}")
        return None

    if not skill_path.is_dir():
        print(f"❌ Error: Path is not a directory: {skill_path}")
        return None

    # Validate SKILL.md exists
    skill_md = skill_path / "SKILL.md"
    if not skill_md.exists():
        print(f"❌ Error: SKILL.md not found in {skill_path}")
        return None

    # Run validation before packaging
    print("🔍 Validating skill...")
    valid, message = validate_skill(skill_path)
    if not valid:
        print(f"❌ Validation failed: {message}")
        print("   Please fix the validation errors before packaging.")
        return None
    print(f"✅ {message}\n")

    # Determine output location
    skill_name = skill_path.name
    if output_dir:
        output_path = Path(output_dir).resolve()
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path.cwd()

    skill_filename = output_path / f"{skill_name}.skill"

    # Create the .skill file (zip format)
    try:
        with zipfile.ZipFile(skill_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Walk through the skill directory
            for file_path in skill_path.rglob('*'):
                if file_path.is_file():
                    # Calculate the relative path within the zip
                    arcname = file_path.relative_to(skill_path.parent)
                    zipf.write(file_path, arcname)
                    print(f"  Added: {arcname}")

        print(f"\n✅ Successfully packaged skill to: {skill_filename}")
        return skill_filename

    except Exception as e:
        print(f"❌ Error creating .skill file: {e}")
        return None


def main():
    if len(sys.argv) < 2:
        print("Usage: python utils/package_skill.py <path/to/skill-folder> [output-directory]")
        print("\nExample:")
        print("  python utils/package_skill.py skills/public/my-skill")
        print("  python utils/package_skill.py skills/public/my-skill ./dist")
        sys.exit(1)

    skill_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else None

    print(f"📦 Packaging skill: {skill_path}")
    if output_dir:
        print(f"   Output directory: {output_dir}")
    print()

    result = package_skill(skill_path, output_dir)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()
95
skills/skill-creator/scripts/quick_validate.py
Executable file
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
Quick validation script for skills - minimal version
"""

import sys
import os
import re
import yaml
from pathlib import Path

def validate_skill(skill_path):
    """Basic validation of a skill"""
    skill_path = Path(skill_path)

    # Check SKILL.md exists
    skill_md = skill_path / 'SKILL.md'
    if not skill_md.exists():
        return False, "SKILL.md not found"

    # Read and validate frontmatter
    content = skill_md.read_text()
    if not content.startswith('---'):
        return False, "No YAML frontmatter found"

    # Extract frontmatter
    match = re.match(r'^---\n(.*?)\n---', content, re.DOTALL)
    if not match:
        return False, "Invalid frontmatter format"

    frontmatter_text = match.group(1)

    # Parse YAML frontmatter
    try:
        frontmatter = yaml.safe_load(frontmatter_text)
        if not isinstance(frontmatter, dict):
            return False, "Frontmatter must be a YAML dictionary"
    except yaml.YAMLError as e:
        return False, f"Invalid YAML in frontmatter: {e}"

    # Define allowed properties
    ALLOWED_PROPERTIES = {'name', 'description', 'license', 'allowed-tools', 'metadata'}

    # Check for unexpected properties (excluding nested keys under metadata)
    unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES
    if unexpected_keys:
        return False, (
            f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. "
            f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}"
        )

    # Check required fields
    if 'name' not in frontmatter:
        return False, "Missing 'name' in frontmatter"
    if 'description' not in frontmatter:
        return False, "Missing 'description' in frontmatter"

    # Extract name for validation
    name = frontmatter.get('name', '')
    if not isinstance(name, str):
        return False, f"Name must be a string, got {type(name).__name__}"
    name = name.strip()
    if name:
        # Check naming convention (hyphen-case: lowercase with hyphens)
        if not re.match(r'^[a-z0-9-]+$', name):
            return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)"
        if name.startswith('-') or name.endswith('-') or '--' in name:
            return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens"
        # Check name length (max 64 characters per spec)
        if len(name) > 64:
            return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters."

    # Extract and validate description
    description = frontmatter.get('description', '')
    if not isinstance(description, str):
        return False, f"Description must be a string, got {type(description).__name__}"
    description = description.strip()
    if description:
        # Check for angle brackets
        if '<' in description or '>' in description:
            return False, "Description cannot contain angle brackets (< or >)"
        # Check description length (max 1024 characters per spec)
        if len(description) > 1024:
            return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters."

    return True, "Skill is valid!"

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python quick_validate.py <skill_directory>")
        sys.exit(1)

    valid, message = validate_skill(sys.argv[1])
    print(message)
    sys.exit(0 if valid else 1)