From bb850bb6c5d542ff05a395e99a74ac4a9b2c46f7 Mon Sep 17 00:00:00 2001 From: saboteur7 Date: Fri, 30 Jan 2026 09:53:46 +0800 Subject: [PATCH] feat: personal ai agent framework --- .gitignore | 3 + agent/memory/__init__.py | 10 + agent/memory/chunker.py | 139 ++++ agent/memory/config.py | 114 +++ agent/memory/embedding.py | 175 +++++ agent/memory/manager.py | 623 +++++++++++++++++ agent/memory/storage.py | 418 +++++++++++ agent/memory/summarizer.py | 235 +++++++ agent/memory/tools/__init__.py | 10 + agent/memory/tools/memory_get.py | 118 ++++ agent/memory/tools/memory_search.py | 106 +++ agent/protocol/__init__.py | 20 + agent/protocol/agent.py | 292 ++++++++ agent/protocol/agent_stream.py | 461 ++++++++++++ agent/protocol/context.py | 27 + agent/protocol/models.py | 57 ++ agent/protocol/result.py | 96 +++ agent/protocol/task.py | 95 +++ agent/tools/__init__.py | 101 +++ agent/tools/base_tool.py | 99 +++ agent/tools/bash/__init__.py | 3 + agent/tools/bash/bash.py | 187 +++++ agent/tools/browser/browser_action.py | 59 ++ agent/tools/browser/browser_tool.py | 317 +++++++++ agent/tools/browser_tool.py | 18 + agent/tools/calculator/calculator.py | 58 ++ agent/tools/current_time/current_time.py | 75 ++ agent/tools/edit/__init__.py | 3 + agent/tools/edit/edit.py | 164 +++++ agent/tools/file_save/__init__.py | 3 + agent/tools/file_save/file_save.py | 770 +++++++++++++++++++++ agent/tools/find/__init__.py | 3 + agent/tools/find/find.py | 177 +++++ agent/tools/google_search/google_search.py | 48 ++ agent/tools/grep/__init__.py | 3 + agent/tools/grep/grep.py | 248 +++++++ agent/tools/ls/__init__.py | 3 + agent/tools/ls/ls.py | 125 ++++ agent/tools/memory/__init__.py | 10 + agent/tools/memory/memory_get.py | 107 +++ agent/tools/memory/memory_search.py | 96 +++ agent/tools/read/__init__.py | 3 + agent/tools/read/read.py | 336 +++++++++ agent/tools/terminal/__init__.py | 3 + agent/tools/terminal/terminal.py | 100 +++ agent/tools/tool_manager.py | 208 ++++++ 
agent/tools/utils/__init__.py | 40 ++ agent/tools/utils/diff.py | 167 +++++ agent/tools/utils/truncate.py | 292 ++++++++ agent/tools/write/__init__.py | 3 + agent/tools/write/write.py | 91 +++ bot/claude/claude_ai_bot.py | 222 ------ bot/claude/claude_ai_session.py | 9 - bot/claudeapi/claude_api_bot.py | 399 ++++++++++- bridge/agent_bridge.py | 288 ++++++++ bridge/bridge.py | 29 +- channel/channel.py | 27 +- channel/web/web_channel.py | 22 +- config-template.json | 6 +- config.py | 3 +- memory/2026-01-29.md | 5 + memory/MEMORY.md | 21 + 62 files changed, 7675 insertions(+), 275 deletions(-) create mode 100644 agent/memory/__init__.py create mode 100644 agent/memory/chunker.py create mode 100644 agent/memory/config.py create mode 100644 agent/memory/embedding.py create mode 100644 agent/memory/manager.py create mode 100644 agent/memory/storage.py create mode 100644 agent/memory/summarizer.py create mode 100644 agent/memory/tools/__init__.py create mode 100644 agent/memory/tools/memory_get.py create mode 100644 agent/memory/tools/memory_search.py create mode 100644 agent/protocol/__init__.py create mode 100644 agent/protocol/agent.py create mode 100644 agent/protocol/agent_stream.py create mode 100644 agent/protocol/context.py create mode 100644 agent/protocol/models.py create mode 100644 agent/protocol/result.py create mode 100644 agent/protocol/task.py create mode 100644 agent/tools/__init__.py create mode 100644 agent/tools/base_tool.py create mode 100644 agent/tools/bash/__init__.py create mode 100644 agent/tools/bash/bash.py create mode 100644 agent/tools/browser/browser_action.py create mode 100644 agent/tools/browser/browser_tool.py create mode 100644 agent/tools/browser_tool.py create mode 100644 agent/tools/calculator/calculator.py create mode 100644 agent/tools/current_time/current_time.py create mode 100644 agent/tools/edit/__init__.py create mode 100644 agent/tools/edit/edit.py create mode 100644 agent/tools/file_save/__init__.py create mode 100644 
agent/tools/file_save/file_save.py create mode 100644 agent/tools/find/__init__.py create mode 100644 agent/tools/find/find.py create mode 100644 agent/tools/google_search/google_search.py create mode 100644 agent/tools/grep/__init__.py create mode 100644 agent/tools/grep/grep.py create mode 100644 agent/tools/ls/__init__.py create mode 100644 agent/tools/ls/ls.py create mode 100644 agent/tools/memory/__init__.py create mode 100644 agent/tools/memory/memory_get.py create mode 100644 agent/tools/memory/memory_search.py create mode 100644 agent/tools/read/__init__.py create mode 100644 agent/tools/read/read.py create mode 100644 agent/tools/terminal/__init__.py create mode 100644 agent/tools/terminal/terminal.py create mode 100644 agent/tools/tool_manager.py create mode 100644 agent/tools/utils/__init__.py create mode 100644 agent/tools/utils/diff.py create mode 100644 agent/tools/utils/truncate.py create mode 100644 agent/tools/write/__init__.py create mode 100644 agent/tools/write/write.py delete mode 100644 bot/claude/claude_ai_bot.py delete mode 100644 bot/claude/claude_ai_session.py create mode 100644 bridge/agent_bridge.py create mode 100644 memory/2026-01-29.md create mode 100644 memory/MEMORY.md diff --git a/.gitignore b/.gitignore index b88dc49..8548ecd 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__ !plugins/linkai !plugins/agent client_config.json +ref/ +.cursor/ +local/ diff --git a/agent/memory/__init__.py b/agent/memory/__init__.py new file mode 100644 index 0000000..4179bea --- /dev/null +++ b/agent/memory/__init__.py @@ -0,0 +1,10 @@ +""" +Memory module for AgentMesh + +Provides long-term memory capabilities with hybrid search (vector + keyword) +""" + +from agent.memory.manager import MemoryManager +from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config + +__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config'] diff 
"""
Text chunking utilities for memory

Splits text into chunks with token limits and overlap
"""

from typing import List, Tuple
from dataclasses import dataclass


@dataclass
class TextChunk:
    """A contiguous piece of text plus the 1-based line span it came from."""
    text: str
    start_line: int
    end_line: int


class TextChunker:
    """Splits text line-by-line into chunks bounded by an estimated token budget."""

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
        """
        Create a chunker.

        Args:
            max_tokens: Upper bound on estimated tokens per chunk.
            overlap_tokens: Estimated tokens carried over between adjacent chunks.
        """
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        # Heuristic used throughout: roughly 4 characters per token
        # for mixed English/Chinese text.
        self.chars_per_token = 4

    def chunk_text(self, text: str) -> List[TextChunk]:
        """
        Split *text* into overlapping chunks.

        Args:
            text: Input text; whitespace-only input yields no chunks.

        Returns:
            List of TextChunk objects covering the input in order.
        """
        if not text.strip():
            return []

        budget = self.max_tokens * self.chars_per_token
        overlap_budget = self.overlap_tokens * self.chars_per_token

        lines = text.split('\n')
        result: List[TextChunk] = []
        buffer: List[str] = []
        buffered = 0          # character count of `buffer` (newlines not counted)
        first_line = 1        # line number where `buffer` begins

        for lineno, line in enumerate(lines, start=1):
            size = len(line)

            if size > budget:
                # Oversized line: flush the pending chunk, then emit the line
                # in budget-sized pieces (no overlap between those pieces).
                if buffer:
                    result.append(TextChunk('\n'.join(buffer), first_line, lineno - 1))
                    buffer, buffered = [], 0
                result.extend(
                    TextChunk(piece, lineno, lineno)
                    for piece in self._split_long_line(line, budget)
                )
                first_line = lineno + 1
                continue

            if buffer and buffered + size > budget:
                # Budget exceeded: close the chunk and seed the next one with
                # the tail of the previous chunk as overlap.
                result.append(TextChunk('\n'.join(buffer), first_line, lineno - 1))
                carry = self._get_overlap_lines(buffer, overlap_budget)
                buffer = carry + [line]
                buffered = sum(len(item) for item in buffer)
                first_line = lineno - len(carry)
            else:
                buffer.append(line)
                buffered += size

        if buffer:
            result.append(TextChunk('\n'.join(buffer), first_line, len(lines)))

        return result

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Cut one overlong line into max_chars-sized pieces (last may be shorter)."""
        return [line[pos:pos + max_chars] for pos in range(0, len(line), max_chars)]

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """Return the longest suffix of *lines* whose total length fits target_chars."""
        picked: List[str] = []
        used = 0
        for line in reversed(lines):
            if used + len(line) > target_chars:
                break
            picked.append(line)
            used += len(line)
        picked.reverse()
        return picked

    def chunk_markdown(self, text: str) -> List[TextChunk]:
        """
        Chunk markdown text. Currently delegates to chunk_text;
        markdown-structure-aware splitting is a future enhancement.
        """
        return self.chunk_text(text)
"""
Memory configuration module

Provides global memory configuration with simplified workspace structure
"""

import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path


@dataclass
class MemoryConfig:
    """Configuration for memory storage and search"""

    # Where all memory artifacts live (default: ~/cow).
    workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))

    # Embedding backend: "openai" | "local".
    embedding_provider: str = "openai"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking limits (token estimates).
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search result limits.
    max_results: int = 10
    min_score: float = 0.3

    # Relative weights used when merging vector and keyword hits.
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Which memory sources are indexed.
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Synchronization behaviour.
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """Return the workspace root as a Path."""
        return Path(self.workspace_root)

    def get_memory_dir(self) -> Path:
        """Return the directory that holds memory markdown files."""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """Return the SQLite index path, creating its parent directory if needed."""
        parent = self.get_memory_dir() / "long-term"
        parent.mkdir(parents=True, exist_ok=True)
        return parent / "index.db"

    def get_skills_dir(self) -> Path:
        """Return the skills directory."""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Return (and create) the workspace directory for an agent.

        Args:
            agent_name: Accepted for API compatibility; currently unused.

        Returns:
            Path to the workspace directory (guaranteed to exist on return).
        """
        root = self.get_workspace()
        root.mkdir(parents=True, exist_ok=True)
        return root


# Process-wide configuration singleton, managed by the two functions below.
_global_memory_config: Optional[MemoryConfig] = None


def get_default_memory_config() -> MemoryConfig:
    """
    Return the global MemoryConfig, lazily creating a default one
    the first time it is requested.
    """
    global _global_memory_config
    if _global_memory_config is None:
        _global_memory_config = MemoryConfig()
    return _global_memory_config


def set_global_memory_config(config: MemoryConfig):
    """
    Install *config* as the process-wide memory configuration.

    Call this before constructing any MemoryManager instances so that
    they all pick up the same settings.

    Args:
        config: MemoryConfig instance to use globally.
    """
    global _global_memory_config
    _global_memory_config = config
"""
Embedding providers for memory

Supports OpenAI and local embedding models
"""

from typing import List, Optional
from abc import ABC, abstractmethod
import hashlib
import json


class EmbeddingProvider(ABC):
    """Abstract interface every embedding backend implements."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for a single text."""
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return embedding vectors for a list of texts."""
        pass

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Vector length produced by this provider."""
        pass


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embeddings via the OpenAI API."""

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Create an OpenAI-backed provider.

        Args:
            model: Model name (text-embedding-3-small or text-embedding-3-large).
            api_key: OpenAI API key.
            api_base: Optional API base URL.
        """
        self.model = model
        self.api_key = api_key
        # NOTE(review): self.api_base stores the defaulted URL, but the client
        # below receives the raw api_base (possibly None) — confirm intended.
        self.api_base = api_base or "https://api.openai.com/v1"

        # Import lazily so the dependency is only required when this provider is used.
        try:
            from openai import OpenAI
            self.client = OpenAI(api_key=api_key, base_url=api_base)
        except ImportError:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")

        # "small" models emit 1536-dim vectors; other variants here are treated as 3072.
        self._dimensions = 1536 if "small" in model else 3072

    def embed(self, text: str) -> List[float]:
        """Embed a single text with one API call."""
        resp = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return resp.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts in one API call; an empty list short-circuits."""
        if not texts:
            return []

        resp = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        return [item.embedding for item in resp.data]

    @property
    def dimensions(self) -> int:
        return self._dimensions


class LocalEmbeddingProvider(EmbeddingProvider):
    """Embeddings computed locally with sentence-transformers."""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Create a local provider.

        Args:
            model: Model name from sentence-transformers.
        """
        self.model_name = model

        # Import lazily; dimensions are whatever the loaded model reports.
        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model)
            self._dimensions = self.model.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )

    def embed(self, text: str) -> List[float]:
        """Embed a single text on the local model."""
        vec = self.model.encode(text, convert_to_numpy=True)
        return vec.tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts in one local forward pass; empty input short-circuits."""
        if not texts:
            return []

        vecs = self.model.encode(texts, convert_to_numpy=True)
        return vecs.tolist()

    @property
    def dimensions(self) -> int:
        return self._dimensions


class EmbeddingCache:
    """In-memory cache keyed by (provider, model, text) to avoid recomputation."""

    def __init__(self):
        # NOTE(review): unbounded dict — grows for the lifetime of the process.
        self.cache = {}

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None on a miss."""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store an embedding under its (provider, model, text) key."""
        self.cache[self._compute_key(text, provider, model)] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """MD5 digest of "provider:model:text" — stable across runs."""
        payload = f"{provider}:{model}:{text}"
        return hashlib.md5(payload.encode('utf-8')).hexdigest()

    def clear(self):
        """Drop every cached embedding."""
        self.cache.clear()


def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None
) -> EmbeddingProvider:
    """
    Factory for embedding providers.

    Args:
        provider: Provider name ("openai" or "local").
        model: Model name; falls back to a per-provider default.
        api_key: API key for remote providers.
        api_base: API base URL for remote providers.

    Returns:
        EmbeddingProvider instance.

    Raises:
        ValueError: If *provider* is not a known name.
    """
    if provider == "openai":
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=api_base
        )
    elif provider == "local":
        return LocalEmbeddingProvider(model=model or "all-MiniLM-L6-v2")
    else:
        raise ValueError(f"Unknown embedding provider: {provider}")
"""
Memory manager for AgentMesh

Provides high-level interface for memory operations
"""

import os
import time
from typing import List, Optional, Dict, Any
from pathlib import Path
import hashlib
from datetime import datetime, timedelta

from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed


class MemoryManager:
    """
    Memory manager with hybrid search capabilities

    Provides long-term memory for agents with vector and keyword search.
    Embeddings are optional: when no provider can be built, the manager
    degrades to keyword-only (FTS5) search.
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_model: Optional[Any] = None
    ):
        """
        Initialize memory manager

        Args:
            config: Memory configuration (uses global config if not provided)
            embedding_provider: Custom embedding provider (optional)
            llm_model: LLM model for summarization (optional)
        """
        self.config = config or get_default_memory_config()

        # Initialize storage (SQLite + FTS5)
        db_path = self.config.get_db_path()
        self.storage = MemoryStorage(db_path)

        # Initialize chunker
        self.chunker = TextChunker(
            max_tokens=self.config.chunk_max_tokens,
            overlap_tokens=self.config.chunk_overlap_tokens
        )

        # Initialize embedding provider (optional, best-effort)
        self.embedding_provider = None
        if embedding_provider:
            self.embedding_provider = embedding_provider
        else:
            # Try to create embedding provider, but allow failure: keyword
            # search and file operations still work without embeddings.
            try:
                api_key = os.environ.get('OPENAI_API_KEY')
                api_base = os.environ.get('OPENAI_API_BASE')

                self.embedding_provider = create_embedding_provider(
                    provider=self.config.embedding_provider,
                    model=self.config.embedding_model,
                    api_key=api_key,
                    api_base=api_base
                )
            except Exception as e:
                print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
                print("ℹ️ Memory will work with keyword search only (no semantic search)")

        # Flush manager writes durable memories to disk before compaction
        workspace_dir = self.config.get_workspace()
        self.flush_manager = MemoryFlushManager(
            workspace_dir=workspace_dir,
            llm_model=llm_model
        )

        # Ensure workspace directories exist
        self._init_workspace()

        # Dirty flag: set when on-disk memory changed and the index needs a sync
        self._dirty = False

    def _init_workspace(self):
        """Initialize workspace directories and default memory files."""
        memory_dir = self.config.get_memory_dir()
        memory_dir.mkdir(parents=True, exist_ok=True)

        workspace_dir = self.config.get_workspace()
        create_memory_files_if_needed(workspace_dir)

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        include_shared: bool = True
    ) -> List[SearchResult]:
        """
        Search memory with hybrid search (vector + keyword)

        Args:
            query: Search query
            user_id: User ID for scoped search
            max_results: Maximum results to return (None -> config default)
            min_score: Minimum score threshold (None -> config default;
                an explicit 0.0 is respected, not replaced by the default)
            include_shared: Include shared memories

        Returns:
            List of search results sorted by relevance
        """
        # `is None` checks so explicit 0 / 0.0 are honored (`or` would
        # silently swap them for the config defaults).
        max_results = max_results if max_results is not None else self.config.max_results
        min_score = min_score if min_score is not None else self.config.min_score

        # Determine scopes
        scopes = []
        if include_shared:
            scopes.append("shared")
        if user_id:
            scopes.append("user")

        if not scopes:
            return []

        # Sync if needed
        if self.config.sync_on_search and self._dirty:
            await self.sync()

        # Perform vector search (if embedding provider available)
        vector_results = []
        if self.embedding_provider:
            query_embedding = self.embedding_provider.embed(query)
            vector_results = self.storage.search_vector(
                query_embedding=query_embedding,
                user_id=user_id,
                scopes=scopes,
                limit=max_results * 2  # Get more candidates for merging
            )

        # Perform keyword search
        keyword_results = self.storage.search_keyword(
            query=query,
            user_id=user_id,
            scopes=scopes,
            limit=max_results * 2
        )

        # Merge, then filter by min score and limit
        merged = self._merge_results(
            vector_results,
            keyword_results,
            self.config.vector_weight,
            self.config.keyword_weight
        )

        filtered = [r for r in merged if r.score >= min_score]
        return filtered[:max_results]

    async def add_memory(
        self,
        content: str,
        user_id: Optional[str] = None,
        scope: str = "shared",
        source: str = "memory",
        path: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Add new memory content

        Args:
            content: Memory content (blank content is ignored)
            user_id: User ID for user-scoped memory
            scope: Memory scope ("shared", "user", "session")
            source: Memory source ("memory" or "session")
            path: File path (auto-generated from a content hash if not provided)
            metadata: Additional metadata
        """
        if not content.strip():
            return

        # Generate path if not provided
        if not path:
            content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
            if user_id and scope == "user":
                path = f"memory/users/{user_id}/memory_{content_hash}.md"
            else:
                path = f"memory/shared/memory_{content_hash}.md"

        # Chunk content
        chunks = self.chunker.chunk_text(content)

        # Generate embeddings (if provider available)
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
            chunk_hash = MemoryStorage.compute_hash(chunk.text)

            memory_chunks.append(MemoryChunk(
                id=chunk_id,
                agent_id="default",
                user_id=user_id,
                scope=scope,
                source=source,
                path=path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=chunk_hash,
                metadata=metadata
            ))

        # Save to storage
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata.
        # Fix: record the current time. The previous code used
        # os.path.getmtime(__file__), which is this module's mtime on disk,
        # not "now" as the comment claimed.
        file_hash = MemoryStorage.compute_hash(content)
        self.storage.update_file_metadata(
            path=path,
            source=source,
            file_hash=file_hash,
            mtime=int(time.time()),
            size=len(content)
        )

    async def sync(self, force: bool = False):
        """
        Synchronize memory index from files on disk

        Args:
            force: Force full reindex (re-chunk and re-embed even when the
                stored file hash is unchanged)
        """
        memory_dir = self.config.get_memory_dir()
        workspace_dir = self.config.get_workspace()

        # Scan memory/MEMORY.md first (rglob below visits it again, but the
        # hash check makes the second visit a no-op unless force is set)
        memory_file = memory_dir / "MEMORY.md"
        if memory_file.exists():
            await self._sync_file(memory_file, "memory", "shared", None, force=force)

        # Scan memory directory (including daily summaries)
        if memory_dir.exists():
            for file_path in memory_dir.rglob("*.md"):
                # Determine scope and user_id from path
                rel_path = file_path.relative_to(workspace_dir)
                parts = rel_path.parts

                if "daily" in parts:
                    # Daily summary files
                    # NOTE(review): len(parts) > 3 is a heuristic for
                    # user-scoped daily paths — confirm against the writer side
                    if "users" in parts or len(parts) > 3:
                        # User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
                        user_idx = parts.index("daily") + 1
                        user_id = parts[user_idx] if user_idx < len(parts) else None
                        scope = "user"
                    else:
                        # Shared daily summary: memory/daily/2024-01-29.md
                        user_id = None
                        scope = "shared"
                elif "users" in parts:
                    # User-scoped memory
                    user_idx = parts.index("users") + 1
                    user_id = parts[user_idx] if user_idx < len(parts) else None
                    scope = "user"
                else:
                    # Shared memory
                    user_id = None
                    scope = "shared"

                await self._sync_file(file_path, "memory", scope, user_id, force=force)

        self._dirty = False

    async def _sync_file(
        self,
        file_path: Path,
        source: str,
        scope: str,
        user_id: Optional[str],
        force: bool = False
    ):
        """Sync a single file into the index; skips unchanged files unless *force*."""
        # Compute file hash (explicit utf-8, consistent with the other reads
        # in this module and independent of the platform default encoding)
        content = file_path.read_text(encoding='utf-8')
        file_hash = MemoryStorage.compute_hash(content)

        # Get relative path
        workspace_dir = self.config.get_workspace()
        rel_path = str(file_path.relative_to(workspace_dir))

        # Check if file changed (force bypasses the check for a full reindex)
        if not force:
            stored_hash = self.storage.get_file_hash(rel_path)
            if stored_hash == file_hash:
                return  # No changes

        # Delete old chunks
        self.storage.delete_by_path(rel_path)

        # Chunk and embed
        chunks = self.chunker.chunk_text(content)
        if not chunks:
            return

        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
            chunk_hash = MemoryStorage.compute_hash(chunk.text)

            memory_chunks.append(MemoryChunk(
                id=chunk_id,
                agent_id="default",
                user_id=user_id,
                scope=scope,
                source=source,
                path=rel_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=chunk_hash,
                metadata=None
            ))

        # Save
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata
        stat = file_path.stat()
        self.storage.update_file_metadata(
            path=rel_path,
            source=source,
            file_hash=file_hash,
            mtime=int(stat.st_mtime),
            size=stat.st_size
        )

    def should_flush_memory(
        self,
        current_tokens: int,
        context_window: int = 128000,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Check if memory flush should be triggered

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size (default: 128K)
            reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
            soft_threshold: Trigger N tokens before threshold (default: 4K)

        Returns:
            True if memory flush should run
        """
        return self.flush_manager.should_flush(
            current_tokens=current_tokens,
            context_window=context_window,
            reserve_tokens=reserve_tokens,
            soft_threshold=soft_threshold
        )

    async def execute_memory_flush(
        self,
        agent_executor,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Execute memory flush before compaction

        This runs a silent agent turn to write durable memories to disk.

        Args:
            agent_executor: Async function to execute agent with prompt
            current_tokens: Current session token count
            user_id: Optional user ID
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if flush completed successfully
        """
        success = await self.flush_manager.execute_flush(
            agent_executor=agent_executor,
            current_tokens=current_tokens,
            user_id=user_id,
            **executor_kwargs
        )

        if success:
            # Mark dirty so next search will sync the new memories
            self._dirty = True

        return success

    def build_memory_guidance(self, lang: str = "en", include_context: bool = True) -> str:
        """
        Build natural memory guidance for agent system prompt

        1. Loads MEMORY.md as bootstrap context (blends into background)
        2. Daily files are loaded on-demand via the memory_search tool
        3. Agent should NOT proactively mention memories unless user asks

        Args:
            lang: Language for guidance ("en" or "zh")
            include_context: Whether to append bootstrap memory context

        Returns:
            Memory guidance text (and optionally context) for system prompt
        """
        today_file = self.flush_manager.get_today_memory_file().name

        if lang == "zh":
            guidance = f"""## 记忆召回
回答关于过去工作、决策、日期、人物、偏好或待办事项的问题前:先用 memory_search 搜索 MEMORY.md + memory/*.md;然后用 memory_get 只读取需要的行。如果搜索后仍不确定,说明你已检查过。

## 记忆存储
当用户分享持久偏好、决策或重要事实时(无论是否明确要求"记住"),主动存储:
- 持久信息(偏好、决策、人物信息)→ memory/MEMORY.md
- 当天的笔记和上下文 → memory/{today_file}
- 静默存储,仅在用户明确要求时确认

## 记忆使用原则
- 不要主动提起或列举记忆内容
- 只在用户明确询问相关信息时才使用记忆
- 记忆是背景知识,不是要展示的内容
- 自然使用记忆,就像你本来就知道这些信息"""
        else:
            guidance = f"""## Memory Recall
Before answering anything about prior work, decisions, dates, people, preferences, or todos: run memory_search on MEMORY.md + memory/*.md; then use memory_get to pull only the needed lines. If low confidence after search, say you checked.

## Memory Storage
When user shares durable preferences, decisions, or important facts (whether or not they explicitly say "remember"), proactively store:
- Durable info (preferences, decisions, people) → memory/MEMORY.md
- Daily notes and context → memory/{today_file}
- Store silently; only confirm when explicitly requested

## Memory Usage Principles
- Don't proactively mention or list memory contents
- Only use memories when user explicitly asks about them
- Memories are background knowledge, not content to showcase
- Use memories naturally as if you inherently knew this information"""

        if include_context:
            # Load bootstrap context (MEMORY.md only)
            bootstrap_context = self.load_bootstrap_memories()
            if bootstrap_context:
                guidance += f"\n\n## Background Context\n\n{bootstrap_context}"

        return guidance

    def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
        """
        Load bootstrap memory files for session start

        - Only loads memory/MEMORY.md (long-term curated memory)
        - Daily files (YYYY-MM-DD.md) are accessed via memory_search, not bootstrap
        - User-specific MEMORY.md is also loaded if user_id provided

        Returns memory content WITHOUT obvious headers so it blends naturally
        into the context as background knowledge.

        Args:
            user_id: Optional user ID for user-specific memories

        Returns:
            Memory content to inject into system prompt
        """
        memory_dir = self.config.get_memory_dir()

        sections = []

        # 1. Load memory/MEMORY.md ONLY (long-term curated memory)
        memory_file = memory_dir / "MEMORY.md"
        if memory_file.exists():
            try:
                content = memory_file.read_text(encoding='utf-8').strip()
                if content:
                    sections.append(content)
            except Exception as e:
                print(f"Warning: Failed to read memory/MEMORY.md: {e}")

        # 2. Load user-specific MEMORY.md if user_id provided
        if user_id:
            user_memory_file = memory_dir / "users" / user_id / "MEMORY.md"
            if user_memory_file.exists():
                try:
                    content = user_memory_file.read_text(encoding='utf-8').strip()
                    if content:
                        sections.append(content)
                except Exception as e:
                    print(f"Warning: Failed to read user memory: {e}")

        if not sections:
            return ""

        # Join sections without obvious headers - let memories blend naturally
        return "\n\n".join(sections)

    def get_status(self) -> Dict[str, Any]:
        """Get memory status summary (chunk/file counts, search mode, dirty flag)."""
        stats = self.storage.get_stats()
        return {
            'chunks': stats['chunks'],
            'files': stats['files'],
            'workspace': str(self.config.get_workspace()),
            'dirty': self._dirty,
            'embedding_enabled': self.embedding_provider is not None,
            'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
            'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
            'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
        }

    def mark_dirty(self):
        """Mark memory as dirty (needs sync)."""
        self._dirty = True

    def close(self):
        """Close memory manager and release storage resources."""
        self.storage.close()

    # Helper methods

    def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
        """Generate a deterministic chunk ID from path and line span."""
        content = f"{path}:{start_line}:{end_line}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def _merge_results(
        self,
        vector_results: List[SearchResult],
        keyword_results: List[SearchResult],
        vector_weight: float,
        keyword_weight: float
    ) -> List[SearchResult]:
        """Merge vector and keyword hits into weighted, score-sorted results."""
        # Key results by (path, start_line, end_line) so the same chunk
        # found by both searches is scored once with both components
        merged_map = {}

        for result in vector_results:
            key = (result.path, result.start_line, result.end_line)
            merged_map[key] = {
                'result': result,
                'vector_score': result.score,
                'keyword_score': 0.0
            }

        for result in keyword_results:
            key = (result.path, result.start_line, result.end_line)
            if key in merged_map:
                merged_map[key]['keyword_score'] = result.score
            else:
                merged_map[key] = {
                    'result': result,
                    'vector_score': 0.0,
                    'keyword_score': result.score
                }

        # Calculate combined scores
        merged_results = []
        for entry in merged_map.values():
            combined_score = (
                vector_weight * entry['vector_score'] +
                keyword_weight * entry['keyword_score']
            )

            result = entry['result']
            merged_results.append(SearchResult(
                path=result.path,
                start_line=result.start_line,
                end_line=result.end_line,
                score=combined_score,
                snippet=result.snippet,
                source=result.source,
                user_id=result.user_id
            ))

        # Sort by score, best first
        merged_results.sort(key=lambda r: r.score, reverse=True)
        return merged_results


# ---- agent/memory/storage.py (excerpt; the module continues past this view) ----

"""
Storage layer for memory using SQLite + FTS5

Provides vector and keyword search capabilities
"""

import sqlite3
import json
from dataclasses import dataclass


@dataclass
class MemoryChunk:
    """Represents a memory chunk with text and embedding"""
    id: str
    user_id: Optional[str]
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory"
@dataclass
class MemoryChunk:
    """A chunk of memory text plus its optional embedding vector."""
    id: str
    user_id: Optional[str]
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory" | "session"
    path: str
    start_line: int
    end_line: int
    text: str
    embedding: Optional[List[float]]
    hash: str
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class SearchResult:
    """Search hit with a 0-1 relevance score and a short snippet."""
    path: str
    start_line: int
    end_line: int
    score: float
    snippet: str
    source: str
    user_id: Optional[str] = None


class MemoryStorage:
    """SQLite-backed chunk store with FTS5 keyword search.

    Correctness notes:
    - chunks_fts is an *external content* FTS5 table (content='chunks'),
      so index entries must be removed with the special 'delete' INSERT
      command; plain DELETE/UPDATE on the FTS table leaves stale entries.
    - Rows are replaced via DELETE + INSERT rather than INSERT OR
      REPLACE, because REPLACE only fires delete triggers when
      PRAGMA recursive_triggers is on — without that, the FTS index
      would keep the old text and return it for the new row.
    - bm25() ranks are negative (more negative = better), which the
      scoring helper accounts for.
    """

    # Shared INSERT statement for save_chunk / save_chunks_batch.
    _INSERT_SQL = """
        INSERT INTO chunks
        (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """

    def __init__(self, db_path: Path):
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self):
        """Create schema: chunks table, indexes, FTS5 mirror, sync triggers."""
        self.conn = sqlite3.connect(str(self.db_path))
        self.conn.row_factory = sqlite3.Row

        # WAL improves concurrent read/write behaviour.
        self.conn.execute("PRAGMA journal_mode=WAL")

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                scope TEXT NOT NULL DEFAULT 'shared',
                source TEXT NOT NULL DEFAULT 'memory',
                path TEXT NOT NULL,
                start_line INTEGER NOT NULL,
                end_line INTEGER NOT NULL,
                text TEXT NOT NULL,
                embedding TEXT,
                hash TEXT NOT NULL,
                metadata TEXT,
                created_at INTEGER DEFAULT (strftime('%s', 'now')),
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        self.conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_user ON chunks(user_id)")
        self.conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_scope ON chunks(scope)")
        self.conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_hash ON chunks(path, hash)")

        # External-content FTS5 table mirroring chunks.
        self.conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                text,
                id UNINDEXED,
                user_id UNINDEXED,
                path UNINDEXED,
                source UNINDEXED,
                scope UNINDEXED,
                content='chunks',
                content_rowid='rowid'
            )
        """)

        # Sync triggers.  External-content tables REQUIRE the 'delete'
        # command form for removals; a plain DELETE/UPDATE on chunks_fts
        # silently corrupts the index.  NOTE: IF NOT EXISTS means a DB
        # created by an older schema keeps its old (broken) triggers —
        # such DBs need a one-off migration.
        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
            END
        """)

        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        # File-level metadata for change detection.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS files (
                path TEXT PRIMARY KEY,
                source TEXT NOT NULL DEFAULT 'memory',
                hash TEXT NOT NULL,
                mtime INTEGER NOT NULL,
                size INTEGER NOT NULL,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        self.conn.commit()

    @staticmethod
    def _chunk_params(chunk: MemoryChunk) -> tuple:
        """Flatten a chunk into the parameter tuple for _INSERT_SQL."""
        return (
            chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path,
            chunk.start_line, chunk.end_line, chunk.text,
            json.dumps(chunk.embedding) if chunk.embedding else None,
            chunk.hash,
            json.dumps(chunk.metadata) if chunk.metadata else None,
        )

    def save_chunk(self, chunk: MemoryChunk):
        """Insert or replace a chunk, keeping the FTS index in sync."""
        # DELETE first (not INSERT OR REPLACE) so the delete trigger
        # fires and purges the FTS entry of any previous version.
        self.conn.execute("DELETE FROM chunks WHERE id = ?", (chunk.id,))
        self.conn.execute(self._INSERT_SQL, self._chunk_params(chunk))
        self.conn.commit()

    def save_chunks_batch(self, chunks: List[MemoryChunk]):
        """Insert or replace many chunks in one commit."""
        self.conn.executemany(
            "DELETE FROM chunks WHERE id = ?", [(c.id,) for c in chunks])
        self.conn.executemany(
            self._INSERT_SQL, [self._chunk_params(c) for c in chunks])
        self.conn.commit()

    def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
        """Fetch a chunk by ID, or None when absent."""
        row = self.conn.execute(
            "SELECT * FROM chunks WHERE id = ?", (chunk_id,)).fetchone()
        return self._row_to_chunk(row) if row else None

    def search_vector(
        self,
        query_embedding: List[float],
        user_id: Optional[str] = None,
        scopes: List[str] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """Vector similarity search via in-memory cosine similarity.

        Loads every embedded chunk in scope and ranks it — fine for small
        memory stores; sqlite-vec can replace this later for scale.
        """
        # Work on a copy so a caller-supplied list is never mutated.
        scopes = ["shared"] if scopes is None else list(scopes)
        if user_id:
            scopes.append("user")

        scope_placeholders = ','.join('?' * len(scopes))
        params: List[Any] = list(scopes)

        if user_id:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND (scope = 'shared' OR user_id = ?)
                AND embedding IS NOT NULL
            """
            params.append(user_id)
        else:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND embedding IS NOT NULL
            """

        rows = self.conn.execute(query, params).fetchall()

        # Score every candidate; keep strictly-positive similarities.
        scored = []
        for row in rows:
            embedding = json.loads(row['embedding'])
            similarity = self._cosine_similarity(query_embedding, embedding)
            if similarity > 0:
                scored.append((similarity, row))

        scored.sort(key=lambda pair: pair[0], reverse=True)

        return [
            SearchResult(
                path=row['path'],
                start_line=row['start_line'],
                end_line=row['end_line'],
                score=score,
                snippet=self._truncate_text(row['text'], 500),
                source=row['source'],
                user_id=row['user_id']
            )
            for score, row in scored[:limit]
        ]

    def search_keyword(
        self,
        query: str,
        user_id: Optional[str] = None,
        scopes: List[str] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """Keyword search over the FTS5 index (BM25-ranked)."""
        scopes = ["shared"] if scopes is None else list(scopes)
        if user_id:
            scopes.append("user")

        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []

        scope_placeholders = ','.join('?' * len(scopes))
        params: List[Any] = [fts_query] + scopes

        if user_id:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.id = chunks_fts.id
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                AND (chunks.scope = 'shared' OR chunks.user_id = ?)
                ORDER BY rank
                LIMIT ?
            """
            params.extend([user_id, limit])
        else:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.id = chunks_fts.id
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                ORDER BY rank
                LIMIT ?
            """
            params.append(limit)

        rows = self.conn.execute(sql_query, params).fetchall()

        return [
            SearchResult(
                path=row['path'],
                start_line=row['start_line'],
                end_line=row['end_line'],
                score=self._bm25_rank_to_score(row['rank']),
                snippet=self._truncate_text(row['text'], 500),
                source=row['source'],
                user_id=row['user_id']
            )
            for row in rows
        ]

    def delete_by_path(self, path: str):
        """Delete all chunks originating from a file (FTS synced via trigger)."""
        self.conn.execute("DELETE FROM chunks WHERE path = ?", (path,))
        self.conn.commit()

    def get_file_hash(self, path: str) -> Optional[str]:
        """Return the stored content hash for a file, or None."""
        row = self.conn.execute(
            "SELECT hash FROM files WHERE path = ?", (path,)).fetchone()
        return row['hash'] if row else None

    def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
        """Upsert change-detection metadata for a file."""
        self.conn.execute("""
            INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
            VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (path, source, file_hash, mtime, size))
        self.conn.commit()

    def get_stats(self) -> Dict[str, int]:
        """Return {'chunks': N, 'files': M} row counts."""
        chunks_count = self.conn.execute(
            "SELECT COUNT(*) as cnt FROM chunks").fetchone()['cnt']
        files_count = self.conn.execute(
            "SELECT COUNT(*) as cnt FROM files").fetchone()['cnt']
        return {'chunks': chunks_count, 'files': files_count}

    def close(self):
        """Close the database connection (idempotent)."""
        if self.conn:
            self.conn.close()
            self.conn = None

    # Helper methods

    def _row_to_chunk(self, row) -> MemoryChunk:
        """Materialize a database row into a MemoryChunk."""
        return MemoryChunk(
            id=row['id'],
            user_id=row['user_id'],
            scope=row['scope'],
            source=row['source'],
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            text=row['text'],
            embedding=json.loads(row['embedding']) if row['embedding'] else None,
            hash=row['hash'],
            metadata=json.loads(row['metadata']) if row['metadata'] else None
        )

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """Cosine similarity of two vectors; 0.0 on mismatch or zero norm."""
        if len(vec1) != len(vec2):
            return 0.0
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = sum(a * a for a in vec1) ** 0.5
        norm2 = sum(b * b for b in vec2) ** 0.5
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)

    @staticmethod
    def _build_fts_query(raw_query: str) -> Optional[str]:
        """Turn free text into a safe FTS5 query (quoted AND of tokens).

        Tokens are restricted to word characters (incl. CJK) so user
        input can never inject FTS5 query syntax.
        """
        import re
        tokens = re.findall(r'[A-Za-z0-9_\u4e00-\u9fff]+', raw_query)
        if not tokens:
            return None
        return ' AND '.join(f'"{t}"' for t in tokens)

    @staticmethod
    def _bm25_rank_to_score(rank: float) -> float:
        """Map an FTS5 bm25() rank onto a 0-1 relevance score.

        SQLite's bm25() returns *negative* values where more negative
        means a better match, so the rank is negated before normalizing.
        (A max(0, rank) clamp would collapse every real match to 1.0.)
        """
        if rank is None:
            return 0.0
        relevance = max(0.0, -rank)
        return relevance / (1.0 + relevance)

    @staticmethod
    def _truncate_text(text: str, max_chars: int) -> str:
        """Truncate text to max_chars, appending an ellipsis if cut."""
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + "..."

    @staticmethod
    def compute_hash(content: str) -> str:
        """SHA-256 hex digest of the given content string."""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()
""" + if current_tokens <= 0: + return False + + threshold = max(0, context_window - reserve_tokens - soft_threshold) + if threshold <= 0: + return False + + # Check if we've crossed the threshold + if current_tokens < threshold: + return False + + # Avoid duplicate flush in same compaction cycle + if self.last_flush_token_count is not None: + if current_tokens <= self.last_flush_token_count + soft_threshold: + return False + + return True + + def get_today_memory_file(self, user_id: Optional[str] = None) -> Path: + """ + Get today's memory file path: memory/YYYY-MM-DD.md + + Args: + user_id: Optional user ID for user-specific memory + + Returns: + Path to today's memory file + """ + today = datetime.now().strftime("%Y-%m-%d") + + if user_id: + user_dir = self.memory_dir / "users" / user_id + user_dir.mkdir(parents=True, exist_ok=True) + return user_dir / f"{today}.md" + else: + return self.memory_dir / f"{today}.md" + + def get_main_memory_file(self, user_id: Optional[str] = None) -> Path: + """ + Get main memory file path: memory/MEMORY.md + + Args: + user_id: Optional user ID for user-specific memory + + Returns: + Path to main memory file + """ + if user_id: + user_dir = self.memory_dir / "users" / user_id + user_dir.mkdir(parents=True, exist_ok=True) + return user_dir / "MEMORY.md" + else: + return self.memory_dir / "MEMORY.md" + + def create_flush_prompt(self) -> str: + """ + Create prompt for memory flush turn + + Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT + """ + today = datetime.now().strftime("%Y-%m-%d") + return ( + f"Pre-compaction memory flush. " + f"Store durable memories now (use memory/{today}.md for daily notes; " + f"create memory/ if needed). " + f"If nothing to store, reply with NO_REPLY." + ) + + def create_flush_system_prompt(self) -> str: + """ + Create system prompt for memory flush turn + + Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT + """ + return ( + "Pre-compaction memory flush turn. 
" + "The session is near auto-compaction; capture durable memories to disk. " + "You may reply, but usually NO_REPLY is correct." + ) + + async def execute_flush( + self, + agent_executor: Callable, + current_tokens: int, + user_id: Optional[str] = None, + **executor_kwargs + ) -> bool: + """ + Execute memory flush by running a silent agent turn + + Args: + agent_executor: Function to execute agent with prompt + current_tokens: Current token count + user_id: Optional user ID + **executor_kwargs: Additional kwargs for agent executor + + Returns: + True if flush completed successfully + """ + try: + # Create flush prompts + prompt = self.create_flush_prompt() + system_prompt = self.create_flush_system_prompt() + + # Execute agent turn (silent, no user-visible reply expected) + await agent_executor( + prompt=prompt, + system_prompt=system_prompt, + silent=True, # NO_REPLY expected + **executor_kwargs + ) + + # Track flush + self.last_flush_token_count = current_tokens + self.last_flush_timestamp = datetime.now() + + return True + + except Exception as e: + print(f"Memory flush failed: {e}") + return False + + def get_status(self) -> dict: + """Get memory flush status""" + return { + 'last_flush_tokens': self.last_flush_token_count, + 'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None, + 'today_file': str(self.get_today_memory_file()), + 'main_file': str(self.get_main_memory_file()) + } + + +def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None): + """ + Create default memory files if they don't exist + + Args: + workspace_dir: Workspace directory + user_id: Optional user ID for user-specific files + """ + memory_dir = workspace_dir / "memory" + memory_dir.mkdir(parents=True, exist_ok=True) + + # Create main MEMORY.md in memory directory + if user_id: + user_dir = memory_dir / "users" / user_id + user_dir.mkdir(parents=True, exist_ok=True) + main_memory = user_dir / "MEMORY.md" + else: + 
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """Seed the default memory files when they are missing.

    Creates <workspace>/memory (or its per-user subdirectory) containing
    an empty MEMORY.md and a headed daily note for today.  Existing
    files are never overwritten.

    Args:
        workspace_dir: Workspace root directory.
        user_id: When given, files go under memory/users/<user_id>/.
    """
    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    if user_id:
        target_dir = memory_dir / "users" / user_id
        target_dir.mkdir(parents=True, exist_ok=True)
    else:
        target_dir = memory_dir

    # Long-term memory starts empty so stored facts blend naturally into
    # the prompt rather than sitting under an obvious "Memory" header.
    main_memory = target_dir / "MEMORY.md"
    if not main_memory.exists():
        main_memory.write_text("")

    # The daily note gets a small header identifying its date.
    today = datetime.now().strftime("%Y-%m-%d")
    today_memory = target_dir / f"{today}.md"
    if not today_memory.exists():
        today_memory.write_text(
            f"# Daily Memory: {today}\n\n"
            f"Day-to-day notes and running context.\n\n"
        )
" + "Use after memory_search to get full context from historical memory files." + ) + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')" + }, + "start_line": { + "type": "integer", + "description": "Starting line number (optional, default: 1)", + "default": 1 + }, + "num_lines": { + "type": "integer", + "description": "Number of lines to read (optional, reads all if not specified)" + } + }, + "required": ["path"] + } + + async def execute(self, **kwargs) -> str: + """ + Execute memory file read + + Args: + path: File path + start_line: Start line + num_lines: Number of lines + + Returns: + File content + """ + path = kwargs.get("path") + start_line = kwargs.get("start_line", 1) + num_lines = kwargs.get("num_lines") + + if not path: + return "Error: path parameter is required" + + try: + workspace_dir = self.memory_manager.config.get_workspace() + file_path = workspace_dir / path + + if not file_path.exists(): + return f"Error: File not found: {path}" + + content = file_path.read_text() + lines = content.split('\n') + + # Handle line range + if start_line < 1: + start_line = 1 + + start_idx = start_line - 1 + + if num_lines: + end_idx = start_idx + num_lines + selected_lines = lines[start_idx:end_idx] + else: + selected_lines = lines[start_idx:] + + result = '\n'.join(selected_lines) + + # Add metadata + total_lines = len(lines) + shown_lines = len(selected_lines) + + output = [ + f"File: {path}", + f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})", + "", + result + ] + + return '\n'.join(output) + + except Exception as e: + return f"Error reading memory file: {str(e)}" diff --git a/agent/memory/tools/memory_search.py 
class MemorySearchTool(BaseTool):
    """Tool that lets the agent search historical memory files."""

    def __init__(self, memory_manager: MemoryManager, user_id: Optional[str] = None):
        """
        Args:
            memory_manager: MemoryManager instance backing the search.
            user_id: Optional user ID; when set, results are scoped to
                this user plus shared memories.
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id
        self._name = "memory_search"
        self._description = (
            "Search historical memory files (beyond today/yesterday) using semantic and keyword search. "
            "Recent context (MEMORY.md + today + yesterday) is already loaded. "
            "Use this ONLY for older dates, specific past events, or when current context lacks needed info."
        )

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        # JSON schema advertised to the model.
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query (can be natural language question or keywords)"
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results to return (default: 10)",
                    "default": 10
                },
                "min_score": {
                    "type": "number",
                    "description": "Minimum relevance score (0-1, default: 0.3)",
                    "default": 0.3
                }
            },
            "required": ["query"]
        }

    async def execute(self, **kwargs) -> str:
        """Run a hybrid memory search and format the hits as text.

        Args:
            query: Search query.
            max_results: Maximum number of hits (default 10).
            min_score: Minimum relevance score (default 0.3).

        Returns:
            A formatted hit list, a no-results note, or an error string.
        """
        query = kwargs.get("query")
        max_results = kwargs.get("max_results", 10)
        min_score = kwargs.get("min_score", 0.3)

        if not query:
            return "Error: query parameter is required"

        try:
            hits = await self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            )

            if not hits:
                return f"No relevant memories found for query: {query}"

            lines = [f"Found {len(hits)} relevant memories:\n"]
            for idx, hit in enumerate(hits, 1):
                lines.append(f"\n{idx}. {hit.path} (lines {hit.start_line}-{hit.end_line})")
                lines.append(f"   Score: {hit.score:.3f}")
                lines.append(f"   Snippet: {hit.snippet}")
            return "\n".join(lines)

        except Exception as e:
            return f"Error searching memory: {str(e)}"
+ :param output_mode: Control how execution progress is displayed: + "print" for console output or "logger" for using logger + :param max_steps: Maximum number of steps the agent can take (default: 100) + :param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model) + :param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated) + :param memory_manager: Optional MemoryManager instance for memory operations + :param name: [Deprecated] The name of the agent (no longer used in single-agent system) + """ + self.name = name or "Agent" + self.system_prompt = system_prompt + self.model: LLMModel = model # Instance of LLMModel + self.description = description + self.tools: list = [] + self.max_steps = max_steps # max tool-call steps, default 100 + self.max_context_tokens = max_context_tokens # max tokens in context + self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests + self.captured_actions = [] # Initialize captured actions list + self.output_mode = output_mode + self.last_usage = None # Store last API response usage info + self.messages = [] # Unified message history for stream mode + self.memory_manager = memory_manager # Memory manager for auto memory flush + if tools: + for tool in tools: + self.add_tool(tool) + + def add_tool(self, tool: BaseTool): + """ + Add a tool to the agent. + + :param tool: The tool to add (either a tool instance or a tool name) + """ + # If tool is already an instance, use it directly + tool.model = self.model + self.tools.append(tool) + + def _get_model_context_window(self) -> int: + """ + Get the model's context window size in tokens. + Auto-detect based on model name. 
+ + Model context windows: + - Claude 3.5/3.7 Sonnet: 200K tokens + - Claude 3 Opus: 200K tokens + - GPT-4 Turbo/128K: 128K tokens + - GPT-4: 8K-32K tokens + - GPT-3.5: 16K tokens + - DeepSeek: 64K tokens + + :return: Context window size in tokens + """ + if self.model and hasattr(self.model, 'model'): + model_name = self.model.model.lower() + + # Claude models - 200K context + if 'claude-3' in model_name or 'claude-sonnet' in model_name: + return 200000 + + # GPT-4 models + elif 'gpt-4' in model_name: + if 'turbo' in model_name or '128k' in model_name: + return 128000 + elif '32k' in model_name: + return 32000 + else: + return 8000 + + # GPT-3.5 + elif 'gpt-3.5' in model_name: + if '16k' in model_name: + return 16000 + else: + return 4000 + + # DeepSeek + elif 'deepseek' in model_name: + return 64000 + + # Default conservative value + return 10000 + + def _get_context_reserve_tokens(self) -> int: + """ + Get the number of tokens to reserve for new requests. + This prevents context overflow by keeping a buffer. + + :return: Number of tokens to reserve + """ + if self.context_reserve_tokens is not None: + return self.context_reserve_tokens + + # Reserve ~20% of context window for new requests + context_window = self._get_model_context_window() + return max(4000, int(context_window * 0.2)) + + def _estimate_message_tokens(self, message: dict) -> int: + """ + Estimate token count for a message using chars/4 heuristic. + This is a conservative estimate (tends to overestimate). 
+ + :param message: Message dict with 'role' and 'content' + :return: Estimated token count + """ + content = message.get('content', '') + if isinstance(content, str): + return max(1, len(content) // 4) + elif isinstance(content, list): + # Handle multi-part content (text + images) + total_chars = 0 + for part in content: + if isinstance(part, dict) and part.get('type') == 'text': + total_chars += len(part.get('text', '')) + elif isinstance(part, dict) and part.get('type') == 'image': + # Estimate images as ~1200 tokens + total_chars += 4800 + return max(1, total_chars // 4) + return 1 + + def _find_tool(self, tool_name: str): + """Find and return a tool with the specified name""" + for tool in self.tools: + if tool.name == tool_name: + # Only pre-process stage tools can be actively called + if tool.stage == ToolStage.PRE_PROCESS: + tool.model = self.model + tool.context = self # Set tool context + return tool + else: + # If it's a post-process tool, return None to prevent direct calling + logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.") + return None + return None + + # output function based on mode + def output(self, message="", end="\n"): + if self.output_mode == "print": + print(message, end=end) + elif message: + logger.info(message) + + def _execute_post_process_tools(self): + """Execute all post-process stage tools""" + # Get all post-process stage tools + post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS] + + # Execute each tool + for tool in post_process_tools: + # Set tool context + tool.context = self + + # Record start time for execution timing + start_time = time.time() + + # Execute tool (with empty parameters, tool will extract needed info from context) + result = tool.execute({}) + + # Calculate execution time + execution_time = time.time() - start_time + + # Capture tool use for tracking + self.capture_tool_use( + tool_name=tool.name, + input_params={}, # Post-process 
tools typically don't take parameters + output=result.result, + status=result.status, + error_message=str(result.result) if result.status == "error" else None, + execution_time=execution_time + ) + + # Log result + if result.status == "success": + # Print tool execution result in the desired format + self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}") + else: + # Print failure in print mode + self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}") + + def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None, + execution_time=0.0): + """ + Capture a tool use action. + + :param thought: thought content + :param tool_name: Name of the tool used + :param input_params: Parameters passed to the tool + :param output: Output from the tool + :param status: Status of the tool execution + :param error_message: Error message if the tool execution failed + :param execution_time: Time taken to execute the tool + """ + tool_result = ToolResult( + tool_name=tool_name, + input_params=input_params, + output=output, + status=status, + error_message=error_message, + execution_time=execution_time + ) + + action = AgentAction( + agent_id=self.id if hasattr(self, 'id') else str(id(self)), + agent_name=self.name, + action_type=AgentActionType.TOOL_USE, + tool_result=tool_result, + thought=thought + ) + + self.captured_actions.append(action) + + return action + + def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str: + """ + Execute single agent task with streaming (based on tool-call) + + This method supports: + - Streaming output + - Multi-turn reasoning based on tool-call + - Event callbacks + - Persistent conversation history across calls + + Args: + user_message: User message + on_event: Event callback function callback(event: dict) + event = {"type": str, "timestamp": float, "data": dict} + clear_history: If True, clear conversation history before this 
call (default: False) + + Returns: + Final response text + + Example: + # Multi-turn conversation with memory + response1 = agent.run_stream("My name is Alice") + response2 = agent.run_stream("What's my name?") # Will remember Alice + + # Single-turn without memory + response = agent.run_stream("Hello", clear_history=True) + """ + # Clear history if requested + if clear_history: + self.messages = [] + + # Get model to use + if not self.model: + raise ValueError("No model available for agent") + + # Create stream executor with agent's message history + executor = AgentStreamExecutor( + agent=self, + model=self.model, + system_prompt=self.system_prompt, + tools=self.tools, + max_turns=self.max_steps, + on_event=on_event, + messages=self.messages # Pass agent's message history + ) + + # Execute + response = executor.run_stream(user_message) + + # Update agent's message history from executor + self.messages = executor.messages + + # Execute all post-process tools + self._execute_post_process_tools() + + return response + + def clear_history(self): + """Clear conversation history and captured actions""" + self.messages = [] + self.captured_actions = [] \ No newline at end of file diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py new file mode 100644 index 0000000..e7b3418 --- /dev/null +++ b/agent/protocol/agent_stream.py @@ -0,0 +1,461 @@ +""" +Agent Stream Execution Module - Multi-turn reasoning based on tool-call + +Provides streaming output, event system, and complete tool-call loop +""" +import json +import time +from typing import List, Dict, Any, Optional, Callable + +from common.log import logger +from agent.protocol.models import LLMRequest, LLMModel +from agent.tools.base_tool import BaseTool, ToolResult + + +class AgentStreamExecutor: + """ + Agent Stream Executor + + Handles multi-turn reasoning loop based on tool-call: + 1. LLM generates response (may include tool calls) + 2. Execute tools + 3. Return results to LLM + 4. 
    def __init__(
        self,
        agent,  # Agent instance (duck-typed; provides memory_manager, context-window helpers)
        model: LLMModel,
        system_prompt: str,
        tools: List[BaseTool],
        max_turns: int = 50,
        on_event: Optional[Callable] = None,
        messages: Optional[List[Dict]] = None
    ):
        """
        Initialize stream executor.

        Args:
            agent: Agent instance (for accessing context and memory/window config)
            model: LLM model used for every turn
            system_prompt: System prompt (passed separately to the provider, not as a message)
            tools: List of available tools (converted to a name->tool dict below)
            max_turns: Maximum number of reasoning turns before giving up
            on_event: Optional event callback; receives {"type", "timestamp", "data"} dicts
            messages: Optional existing message history (for persistent conversations);
                the SAME list object is kept so the caller sees updates in place
        """
        self.agent = agent
        self.model = model
        self.system_prompt = system_prompt
        # Convert tools list to dict keyed by tool name; pass through if already a dict
        self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
        self.max_turns = max_turns
        self.on_event = on_event

        # Message history - use provided messages or create new list
        self.messages = messages if messages is not None else []

    def _emit_event(self, event_type: str, data: dict = None):
        """Emit an event to the registered callback; callback errors are logged, never raised."""
        if self.on_event:
            try:
                self.on_event({
                    "type": event_type,
                    "timestamp": time.time(),
                    "data": data or {}
                })
            except Exception as e:
                # A broken observer must not abort the reasoning loop.
                logger.error(f"Event callback error: {e}")

    def run_stream(self, user_message: str) -> str:
        """
        Execute the streaming reasoning loop.

        Each turn: call the LLM, execute any requested tools, feed the results
        back as a user message (Claude tool_result blocks), and repeat until the
        model produces a response with no tool calls or max_turns is reached.

        Args:
            user_message: User message

        Returns:
            Final response text (assistant text of the last turn)
        """
        # Log user message
        logger.info(f"\n{'='*50}")
        logger.info(f"👤 用户: {user_message}")
        logger.info(f"{'='*50}")

        # Add user message (Claude format - use content blocks for consistency)
        self.messages.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": user_message
                }
            ]
        })

        self._emit_event("agent_start")

        final_response = ""
        turn = 0

        try:
            while turn < self.max_turns:
                turn += 1
                logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
                self._emit_event("turn_start", {"turn": turn})

                # Check if memory flush is needed (before calling LLM).
                # NOTE(review): last_usage presumably comes from the previous
                # provider response's usage field — confirm against the agent.
                if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
                    usage = self.agent.last_usage
                    if usage and 'input_tokens' in usage:
                        current_tokens = usage.get('input_tokens', 0)
                        context_window = self.agent._get_model_context_window()
                        reserve_tokens = self.agent.context_reserve_tokens or 20000

                        if self.agent.memory_manager.should_flush_memory(
                            current_tokens=current_tokens,
                            context_window=context_window,
                            reserve_tokens=reserve_tokens
                        ):
                            self._emit_event("memory_flush_start", {
                                "current_tokens": current_tokens,
                                "threshold": context_window - reserve_tokens - 4000
                            })

                            # TODO: Execute memory flush in background
                            # This would require async support
                            logger.info(f"Memory flush recommended at {current_tokens} tokens")

                # Call LLM
                assistant_msg, tool_calls = self._call_llm_stream()
                final_response = assistant_msg

                # No tool calls, end loop
                if not tool_calls:
                    if assistant_msg:
                        logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                    logger.info(f"✅ 完成 (无工具调用)")
                    self._emit_event("turn_end", {
                        "turn": turn,
                        "has_tool_calls": False
                    })
                    break

                # Log tool calls in compact format
                tool_names = [tc['name'] for tc in tool_calls]
                logger.info(f"🔧 调用工具: {', '.join(tool_names)}")

                # Execute tools
                tool_results = []
                tool_result_blocks = []

                for tool_call in tool_calls:
                    result = self._execute_tool(tool_call)
                    tool_results.append(result)

                    # Log tool result in compact format
                    status_emoji = "✅" if result.get("status") == "success" else "❌"
                    result_str = str(result.get('result', ''))
                    logger.info(f"  {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

                    # Build tool result block (Claude format)
                    # Content should be a string representation of the result
                    result_content = json.dumps(result) if not isinstance(result, str) else result
                    tool_result_blocks.append({
                        "type": "tool_result",
                        "tool_use_id": tool_call["id"],
                        "content": result_content
                    })

                # Add tool results to message history as user message (Claude format)
                self.messages.append({
                    "role": "user",
                    "content": tool_result_blocks
                })

                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": True,
                    "tool_count": len(tool_calls)
                })

            if turn >= self.max_turns:
                logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")

        except Exception as e:
            logger.error(f"❌ Agent执行错误: {e}")
            self._emit_event("error", {"error": str(e)})
            raise

        finally:
            logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
            self._emit_event("agent_end", {"final_response": final_response})

        return final_response

    def _call_llm_stream(self) -> tuple[str, List[Dict]]:
        """
        Call the LLM with streaming, accumulating text and tool-call deltas.

        Tool-call arguments arrive as incremental JSON string fragments keyed by
        index; they are buffered per index and parsed once the stream ends.

        Returns:
            (response_text, tool_calls) where each tool call is
            {"id": str, "name": str, "arguments": dict}
        """
        # Trim messages if needed (using agent's context management)
        self._trim_messages()

        # Prepare messages
        messages = self._prepare_messages()

        # Debug: log message structure
        logger.debug(f"Sending {len(messages)} messages to LLM")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                content_types = [c.get("type") for c in content if isinstance(c, dict)]
                logger.debug(f"  Message {i}: role={role}, content_blocks={content_types}")
            else:
                logger.debug(f"  Message {i}: role={role}, content_length={len(str(content))}")

        # Prepare tool definitions (OpenAI/Claude format)
        tools_schema = None
        if self.tools:
            tools_schema = []
            for tool in self.tools.values():
                tools_schema.append({
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.params  # Claude uses input_schema
                })

        # Create request
        request = LLMRequest(
            messages=messages,
            temperature=0,
            stream=True,
            tools=tools_schema,
            system=self.system_prompt  # Pass system prompt separately for Claude API
        )

        self._emit_event("message_start", {"role": "assistant"})

        # Streaming response
        full_content = ""
        tool_calls_buffer = {}  # {index: {id, name, arguments}}

        try:
            stream = self.model.call_stream(request)

            for chunk in stream:
                # Check for errors
                if isinstance(chunk, dict) and chunk.get("error"):
                    error_msg = chunk.get("message", "Unknown error")
                    status_code = chunk.get("status_code", "N/A")
                    logger.error(f"API Error: {error_msg} (Status: {status_code})")
                    logger.error(f"Full error chunk: {chunk}")
                    raise Exception(f"{error_msg} (Status: {status_code})")

                # Parse chunk (OpenAI-style delta format)
                if isinstance(chunk, dict) and "choices" in chunk:
                    choice = chunk["choices"][0]
                    delta = choice.get("delta", {})

                    # Handle text content
                    if "content" in delta and delta["content"]:
                        content_delta = delta["content"]
                        full_content += content_delta
                        self._emit_event("message_update", {"delta": content_delta})

                    # Handle tool calls: accumulate fragments per tool-call index
                    if "tool_calls" in delta:
                        for tc_delta in delta["tool_calls"]:
                            index = tc_delta.get("index", 0)

                            if index not in tool_calls_buffer:
                                tool_calls_buffer[index] = {
                                    "id": "",
                                    "name": "",
                                    "arguments": ""
                                }

                            if "id" in tc_delta:
                                tool_calls_buffer[index]["id"] = tc_delta["id"]

                            if "function" in tc_delta:
                                func = tc_delta["function"]
                                if "name" in func:
                                    tool_calls_buffer[index]["name"] = func["name"]
                                if "arguments" in func:
                                    tool_calls_buffer[index]["arguments"] += func["arguments"]

        except Exception as e:
            logger.error(f"LLM call error: {e}")
            raise

        # Parse tool calls from the accumulated buffers (sorted by stream index)
        tool_calls = []
        for idx in sorted(tool_calls_buffer.keys()):
            tc = tool_calls_buffer[idx]
            try:
                arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
            except json.JSONDecodeError as e:
                # Malformed arguments degrade to {} rather than aborting the turn.
                logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
                arguments = {}

            tool_calls.append({
                "id": tc["id"],
                "name": tc["name"],
                "arguments": arguments
            })

        # Add assistant message to history (Claude format uses content blocks)
        assistant_msg = {"role": "assistant", "content": []}

        # Add text content block if present
        if full_content:
            assistant_msg["content"].append({
                "type": "text",
                "text": full_content
            })

        # Add tool_use blocks if present
        if tool_calls:
            for tc in tool_calls:
                assistant_msg["content"].append({
                    "type": "tool_use",
                    "id": tc["id"],
                    "name": tc["name"],
                    "input": tc["arguments"]
                })

        # Only append if content is not empty (empty assistant blocks are invalid)
        if assistant_msg["content"]:
            self.messages.append(assistant_msg)

        self._emit_event("message_end", {
            "content": full_content,
            "tool_calls": tool_calls
        })

        return full_content, tool_calls

    def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
        """
        Execute a single tool call.

        Args:
            tool_call: {"id": str, "name": str, "arguments": dict}

        Returns:
            Dict with keys "status", "result", "execution_time"; never raises —
            failures are reported as {"status": "error", ...}.
        """
        tool_name = tool_call["name"]
        tool_id = tool_call["id"]
        arguments = tool_call["arguments"]

        self._emit_event("tool_execution_start", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            "arguments": arguments
        })

        try:
            tool = self.tools.get(tool_name)
            if not tool:
                raise ValueError(f"Tool '{tool_name}' not found")

            # Set tool context so the tool can reach the model/agent if needed
            tool.model = self.model
            tool.context = self.agent

            # Execute tool and time it
            start_time = time.time()
            result: ToolResult = tool.execute_tool(arguments)
            execution_time = time.time() - start_time

            result_dict = {
                "status": result.status,
                "result": result.result,
                "execution_time": execution_time
            }

            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **result_dict
            })

            return result_dict

        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            error_result = {
                "status": "error",
                "result": str(e),
                "execution_time": 0
            }
            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **error_result
            })
            return error_result

    def _trim_messages(self):
        """
        Trim message history to stay within context limits.
        Uses the agent's context management configuration.

        Keeps the newest messages whose estimated tokens fit under
        (context_window - reserve_tokens - system_prompt_tokens).

        NOTE(review): trimming whole messages can separate an assistant
        tool_use block from its matching tool_result message — verify the
        provider tolerates this before relying on trimming mid-conversation.
        """
        if not self.messages or not self.agent:
            return

        # Get context window and reserve tokens from agent
        context_window = self.agent._get_model_context_window()
        reserve_tokens = self.agent._get_context_reserve_tokens()
        max_tokens = context_window - reserve_tokens

        # Estimate current tokens
        current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)

        # Add system prompt tokens
        system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
        current_tokens += system_tokens

        # If under limit, no need to trim
        if current_tokens <= max_tokens:
            return

        # Keep messages from newest, accumulating tokens until the budget is hit
        available_tokens = max_tokens - system_tokens
        kept_messages = []
        accumulated_tokens = 0

        for msg in reversed(self.messages):
            msg_tokens = self.agent._estimate_message_tokens(msg)
            if accumulated_tokens + msg_tokens <= available_tokens:
                kept_messages.insert(0, msg)
                accumulated_tokens += msg_tokens
            else:
                break

        old_count = len(self.messages)
        self.messages = kept_messages
        new_count = len(self.messages)

        if old_count > new_count:
            logger.info(
                f"Context trimmed: {old_count} -> {new_count} messages "
                f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
                f"limit: {max_tokens})"
            )

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """
        Prepare messages to send to the LLM.

        Note: For Claude API, system prompt should be passed separately via the
        system parameter, not as a message. The model adapter handles this.
        """
        # Don't add system message here - it will be handled separately by the LLM adapter
        return self.messages
+ """ + # Don't add system message here - it will be handled separately by the LLM adapter + return self.messages \ No newline at end of file diff --git a/agent/protocol/context.py b/agent/protocol/context.py new file mode 100644 index 0000000..1c37850 --- /dev/null +++ b/agent/protocol/context.py @@ -0,0 +1,27 @@ +class TeamContext: + def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100): + """ + Initialize the TeamContext with a name, description, rules, a list of agents, and a user question. + :param name: The name of the group context. + :param description: A description of the group context. + :param rule: The rules governing the group context. + :param agents: A list of agents in the context. + """ + self.name = name + self.description = description + self.rule = rule + self.agents = agents + self.user_task = "" # For backward compatibility + self.task = None # Will be a Task instance + self.model = None # Will be an instance of LLMModel + self.task_short_name = None # Store the task directory name + # List of agents that have been executed + self.agent_outputs: list = [] + self.current_steps = 0 + self.max_steps = max_steps + + +class AgentOutput: + def __init__(self, agent_name: str, output: str): + self.agent_name = agent_name + self.output = output \ No newline at end of file diff --git a/agent/protocol/models.py b/agent/protocol/models.py new file mode 100644 index 0000000..9157211 --- /dev/null +++ b/agent/protocol/models.py @@ -0,0 +1,57 @@ +""" +Models module for agent system. +Provides basic model classes needed by tools and bridge integration. 
"""
Models and result types for the agent system.

Provides the request/model abstractions needed by tools and the bridge
integration, plus the action/result dataclasses agents emit.
"""

import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

# NOTE: the previous unused import `from agent.protocol.task import Task, TaskStatus`
# was removed — nothing here references those names and the import created a
# needless circular-dependency risk between protocol modules.


class LLMRequest:
    """Request model for LLM operations.

    Unknown keyword arguments are attached as attributes so adapters can pass
    provider-specific options (e.g. ``system`` for the Claude API) without
    changing this class.
    """

    def __init__(self, messages: Optional[List[Dict[str, str]]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        # Never share a mutable default; each request gets its own list.
        self.messages = messages or []
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.tools = tools
        # Allow extra attributes (provider-specific options)
        for key, value in kwargs.items():
            setattr(self, key, value)


class LLMModel:
    """Base class for LLM models.

    Concrete adapters must override :meth:`call` and :meth:`call_stream`.
    """

    def __init__(self, model: str = None, **kwargs):
        self.model = model
        self.config = kwargs

    def call(self, request: LLMRequest):
        """
        Call the model with a request (non-streaming).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("LLMModel.call not implemented in this context")

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("LLMModel.call_stream not implemented in this context")


class ModelFactory:
    """Factory for creating model instances."""

    @staticmethod
    def create_model(model_type: str, **kwargs):
        """
        Create a model instance based on type.

        Raises:
            NotImplementedError: always, in this placeholder implementation.
        """
        raise NotImplementedError("ModelFactory.create_model not implemented in this context")


class AgentActionType(Enum):
    """Enum representing different types of agent actions."""
    TOOL_USE = "tool_use"
    THINKING = "thinking"
    FINAL_ANSWER = "final_answer"


@dataclass
class ToolResult:
    """
    Represents the result of a tool use.

    Attributes:
        tool_name: Name of the tool used
        input_params: Parameters passed to the tool
        output: Output from the tool
        status: Status of the tool execution (success/error)
        error_message: Error message if the tool execution failed
        execution_time: Time taken to execute the tool, in seconds
    """
    tool_name: str
    input_params: Dict[str, Any]
    output: Any
    status: str
    error_message: Optional[str] = None
    execution_time: float = 0.0


@dataclass
class AgentAction:
    """
    Represents an action taken by an agent.

    Attributes:
        id: Unique identifier for the action (auto-generated UUID)
        agent_id: ID of the agent that performed the action
        agent_name: Name of the agent that performed the action
        action_type: Type of action (tool use, thinking, final answer)
        content: Content of the action (thought content, final answer content)
        tool_result: Tool use details if action_type is TOOL_USE
        thought: Optional reasoning that accompanied the action
        timestamp: When the action was performed (unix time)
    """
    agent_id: str
    agent_name: str
    action_type: AgentActionType
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    tool_result: Optional[ToolResult] = None
    thought: Optional[str] = None
    timestamp: float = field(default_factory=time.time)


@dataclass
class AgentResult:
    """
    Represents the result of an agent's execution.

    Attributes:
        final_answer: The final answer provided by the agent
        step_count: Number of steps taken by the agent
        status: Status of the execution (success/error)
        error_message: Error message if execution failed
    """
    final_answer: str
    step_count: int
    status: str = "success"
    error_message: Optional[str] = None

    @classmethod
    def success(cls, final_answer: str, step_count: int) -> "AgentResult":
        """Create a successful result."""
        return cls(final_answer=final_answer, step_count=step_count)

    @classmethod
    def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
        """Create an error result; final_answer is prefixed with 'Error: '."""
        return cls(
            final_answer=f"Error: {error_message}",
            step_count=step_count,
            status="error",
            error_message=error_message
        )

    @property
    def is_error(self) -> bool:
        """Check if the result represents an error."""
        return self.status == "error"
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List


class TaskType(Enum):
    """Enum representing different types of tasks."""
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    FILE = "file"
    MIXED = "mixed"


class TaskStatus(Enum):
    """Enum representing the status of a task."""
    INIT = "init"              # Initial state
    PROCESSING = "processing"  # In progress
    COMPLETED = "completed"    # Completed
    FAILED = "failed"          # Failed


@dataclass
class Task:
    """
    A unit of work for an agent: text content plus optional media attachments.

    Attributes:
        id: Unique identifier for the task
        content: The primary text content of the task
        type: Type of the task
        status: Current status of the task
        created_at: Timestamp when the task was created
        updated_at: Timestamp when the task was last updated
        metadata: Additional metadata for the task
        images: List of image URLs or base64 encoded images
        videos: List of video URLs
        audios: List of audio URLs or base64 encoded audios
        files: List of file URLs or paths
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    type: TaskType = TaskType.TEXT
    status: TaskStatus = TaskStatus.INIT
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Media content
    images: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    files: List[str] = field(default_factory=list)

    def __init__(self, content: str = "", **kwargs):
        """
        Create a task from text content plus optional attribute overrides.

        A hand-written __init__ is kept (the dataclass decorator does not
        replace it) so that `content` stays the first positional argument
        while every other field can be overridden by keyword.

        Args:
            content: The text content of the task
            **kwargs: Additional attributes to set
        """
        self.content = content
        # Each entry maps a field name to a zero-arg factory producing its default.
        fallbacks = {
            "id": lambda: str(uuid.uuid4()),
            "type": lambda: TaskType.TEXT,
            "status": lambda: TaskStatus.INIT,
            "created_at": time.time,
            "updated_at": time.time,
            "metadata": dict,
            "images": list,
            "videos": list,
            "audios": list,
            "files": list,
        }
        for attr, make_default in fallbacks.items():
            value = kwargs[attr] if attr in kwargs else make_default()
            setattr(self, attr, value)

    def get_text(self) -> str:
        """Return the task's text content."""
        return self.content

    def update_status(self, status: TaskStatus) -> None:
        """Set a new status and refresh the updated_at timestamp."""
        self.status = status
        self.updated_at = time.time()
+ + # Google Search (requires requests) + try: + from agent.tools.google_search.google_search import GoogleSearch + tools['GoogleSearch'] = GoogleSearch + except ImportError: + pass + + # File Save (may have dependencies) + try: + from agent.tools.file_save.file_save import FileSave + tools['FileSave'] = FileSave + except ImportError: + pass + + # Terminal (basic, should work) + try: + from agent.tools.terminal.terminal import Terminal + tools['Terminal'] = Terminal + except ImportError: + pass + + return tools + +# Load optional tools +_optional_tools = _import_optional_tools() +GoogleSearch = _optional_tools.get('GoogleSearch') +FileSave = _optional_tools.get('FileSave') +Terminal = _optional_tools.get('Terminal') + + +# Delayed import for BrowserTool +def _import_browser_tool(): + try: + from agent.tools.browser.browser_tool import BrowserTool + return BrowserTool + except ImportError: + # Return a placeholder class that will prompt the user to install dependencies when instantiated + class BrowserToolPlaceholder: + def __init__(self, *args, **kwargs): + raise ImportError( + "The 'browser-use' package is required to use BrowserTool. " + "Please install it with 'pip install browser-use>=0.1.40'." + ) + + return BrowserToolPlaceholder + + +# Dynamically set BrowserTool +BrowserTool = _import_browser_tool() + +# Export all tools (including optional ones that might be None) +__all__ = [ + 'BaseTool', + 'ToolManager', + 'Calculator', + 'CurrentTime', + 'Read', + 'Write', + 'Edit', + 'Bash', + 'Grep', + 'Find', + 'Ls', + 'MemorySearchTool', + 'MemoryGetTool', + # Optional tools (may be None if dependencies not available) + 'GoogleSearch', + 'FileSave', + 'Terminal', + 'BrowserTool' +] + +""" +Tools module for Agent. 
"""Base classes shared by all agent tools."""
from enum import Enum
from typing import Any, Optional
from common.log import logger
# NOTE: the previously unused `import copy` was removed.


class ToolStage(Enum):
    """Enum representing tool decision stages."""
    PRE_PROCESS = "pre_process"    # Tools that need to be actively selected by the agent
    POST_PROCESS = "post_process"  # Tools that automatically execute after final_answer


class ToolResult:
    """Tool execution result: a status string plus an arbitrary payload."""

    def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
        self.status = status      # "success" or "error"
        self.result = result      # tool payload, or error description on failure
        self.ext_data = ext_data  # optional extra data for the caller

    @staticmethod
    def success(result, ext_data: Any = None):
        """Build a successful result."""
        return ToolResult(status="success", result=result, ext_data=ext_data)

    @staticmethod
    def fail(result, ext_data: Any = None):
        """Build a failed result; `result` carries the error description."""
        return ToolResult(status="error", result=result, ext_data=ext_data)


class BaseTool:
    """Base class for all tools."""

    # Default decision stage is pre-process
    stage = ToolStage.PRE_PROCESS

    # Class attributes must be overridden by subclasses
    name: str = "base_tool"
    description: str = "Base tool"
    params: dict = {}  # JSON Schema describing the tool's input
    model: Optional[Any] = None  # LLM model instance, type depends on bot implementation

    @classmethod
    def get_json_schema(cls) -> dict:
        """Get the standard description of the tool.

        NOTE(review): this uses the key "parameters" while the stream executor
        builds schemas with "input_schema" from `params` directly — confirm
        which consumers rely on this method.
        """
        return {
            "name": cls.name,
            "description": cls.description,
            "parameters": cls.params
        }

    def execute_tool(self, params: dict) -> ToolResult:
        """Execute the tool with exception safety.

        Always returns a ToolResult. BUG FIX: previously an exception was
        logged and the method fell through to return None, which crashed
        callers that read `result.status`; errors are now wrapped in
        ToolResult.fail so the declared return type holds.
        """
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            return ToolResult.fail(str(e))

    def execute(self, params: dict) -> ToolResult:
        """Specific logic to be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def _parse_schema(cls) -> dict:
        """Convert the JSON Schema in `params` to Pydantic-style field tuples."""
        fields = {}
        for name, prop in cls.params["properties"].items():
            # Convert JSON Schema types to Python types
            type_map = {
                "string": str,
                "number": float,
                "integer": int,
                "boolean": bool,
                "array": list,
                "object": dict
            }
            fields[name] = (
                type_map[prop["type"]],
                prop.get("default", ...)  # Ellipsis marks a required field
            )
        return fields

    def should_auto_execute(self, context) -> bool:
        """
        Determine if this tool should be automatically executed based on context.

        :param context: The agent context
        :return: True if the tool should be executed, False otherwise
        """
        # Only tools in post-process stage will be automatically executed
        return self.stage == ToolStage.POST_PROCESS

    def close(self):
        """
        Close any resources used by the tool.
        This method should be overridden by tools that need to clean up resources
        such as browser connections, file handles, etc.

        By default, this method does nothing.
        """
        pass
+ +IMPORTANT SAFETY GUIDELINES: +- You can freely create, modify, and delete files within the current workspace +- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first +- Be especially careful with: file deletions, system modifications, network operations, or commands that might affect system stability +- When in doubt, describe the command's purpose and ask for permission before executing""" + + params: dict = { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command to execute" + }, + "timeout": { + "type": "integer", + "description": "Timeout in seconds (optional, default: 30)" + } + }, + "required": ["command"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + self.cwd = self.config.get("cwd", os.getcwd()) + # Ensure working directory exists + if not os.path.exists(self.cwd): + os.makedirs(self.cwd, exist_ok=True) + self.default_timeout = self.config.get("timeout", 30) + # Enable safety mode by default (can be disabled in config) + self.safety_mode = self.config.get("safety_mode", True) + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute a bash command + + :param args: Dictionary containing the command and optional timeout + :return: Command output or error + """ + command = args.get("command", "").strip() + timeout = args.get("timeout", self.default_timeout) + + if not command: + return ToolResult.fail("Error: command parameter is required") + + # Optional safety check - only warn about extremely dangerous commands + if self.safety_mode: + warning = self._get_safety_warning(command) + if warning: + return ToolResult.fail( + f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.") + + try: + # Execute command + 
result = subprocess.run( + command, + shell=True, + cwd=self.cwd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout + ) + + # Combine stdout and stderr + output = result.stdout + if result.stderr: + output += "\n" + result.stderr + + # Check if we need to save full output to temp file + temp_file_path = None + total_bytes = len(output.encode('utf-8')) + + if total_bytes > DEFAULT_MAX_BYTES: + # Save full output to temp file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f: + f.write(output) + temp_file_path = f.name + + # Apply tail truncation + truncation = truncate_tail(output) + output_text = truncation.content or "(no output)" + + # Build result + details = {} + + if truncation.truncated: + details["truncation"] = truncation.to_dict() + if temp_file_path: + details["full_output_path"] = temp_file_path + + # Build notice + start_line = truncation.total_lines - truncation.output_lines + 1 + end_line = truncation.total_lines + + if truncation.last_line_partial: + # Edge case: last line alone > 30KB + last_line = output.split('\n')[-1] if output else "" + last_line_size = format_size(len(last_line.encode('utf-8'))) + output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]" + elif truncation.truncated_by == "lines": + output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]" + else: + output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). 
Full output: {temp_file_path}]" + + # Check exit code + if result.returncode != 0: + output_text += f"\n\nCommand exited with code {result.returncode}" + return ToolResult.fail({ + "output": output_text, + "exit_code": result.returncode, + "details": details if details else None + }) + + return ToolResult.success({ + "output": output_text, + "exit_code": result.returncode, + "details": details if details else None + }) + + except subprocess.TimeoutExpired: + return ToolResult.fail(f"Error: Command timed out after {timeout} seconds") + except Exception as e: + return ToolResult.fail(f"Error executing command: {str(e)}") + + def _get_safety_warning(self, command: str) -> str: + """ + Get safety warning for potentially dangerous commands + Only warns about extremely dangerous system-level operations + + :param command: Command to check + :return: Warning message if dangerous, empty string if safe + """ + cmd_lower = command.lower().strip() + + # Only block extremely dangerous system operations + dangerous_patterns = [ + # System shutdown/reboot + ("shutdown", "This command will shut down the system"), + ("reboot", "This command will reboot the system"), + ("halt", "This command will halt the system"), + ("poweroff", "This command will power off the system"), + + # Critical system modifications + ("rm -rf /", "This command will delete the entire filesystem"), + ("rm -rf /*", "This command will delete the entire filesystem"), + ("dd if=/dev/zero", "This command can destroy disk data"), + ("mkfs", "This command will format a filesystem, destroying all data"), + ("fdisk", "This command modifies disk partitions"), + + # User/system management (only if targeting system users) + ("userdel root", "This command will delete the root user"), + ("passwd root", "This command will change the root password"), + ] + + for pattern, warning in dangerous_patterns: + if pattern in cmd_lower: + return warning + + # Check for recursive deletion outside workspace + if "rm" in cmd_lower and 
"-rf" in cmd_lower: + # Allow deletion within current workspace + if not any(path in cmd_lower for path in ["./", self.cwd.lower()]): + # Check if targeting system directories + system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"] + if any(sysdir in cmd_lower for sysdir in system_dirs): + return "This command will recursively delete system directories" + + return "" # No warning needed diff --git a/agent/tools/browser/browser_action.py b/agent/tools/browser/browser_action.py new file mode 100644 index 0000000..f5d69d5 --- /dev/null +++ b/agent/tools/browser/browser_action.py @@ -0,0 +1,59 @@ +class BrowserAction: + """Base class for browser actions""" + code = "" + description = "" + + +class Navigate(BrowserAction): + """Navigate to a URL in the current tab""" + code = "navigate" + description = "Navigate to URL in the current tab" + + +class ClickElement(BrowserAction): + """Click an element on the page""" + code = "click_element" + description = "Click element" + + +class ExtractContent(BrowserAction): + """Extract content from the page""" + code = "extract_content" + description = "Extract the page content to retrieve specific information for a goal" + + +class InputText(BrowserAction): + """Input text into an element""" + code = "input_text" + description = "Input text into a input interactive element" + + +class ScrollDown(BrowserAction): + """Scroll down the page""" + code = "scroll_down" + description = "Scroll down the page by pixel amount" + + +class ScrollUp(BrowserAction): + """Scroll up the page""" + code = "scroll_up" + description = "Scroll up the page by pixel amount - if no amount is specified, scroll up one page" + + +class OpenTab(BrowserAction): + """Open a URL in a new tab""" + code = "open_tab" + description = "Open url in new tab" + + +class SwitchTab(BrowserAction): + """Switch to a tab""" + code = "switch_tab" + description = "Switched to tab" + + +class SendKeys(BrowserAction): + """Switch to a tab""" + code = 
"send_keys" + description = "Send strings of special keyboard keys like Escape, Backspace, Insert, PageDown, Delete, Enter, " \ + "ArrowRight, ArrowUp, etc" diff --git a/agent/tools/browser/browser_tool.py b/agent/tools/browser/browser_tool.py new file mode 100644 index 0000000..4028b57 --- /dev/null +++ b/agent/tools/browser/browser_tool.py @@ -0,0 +1,317 @@ +import asyncio +from typing import Any, Dict +import json +import re +import os +import platform +from browser_use import Browser +from browser_use import BrowserConfig +from browser_use.browser.context import BrowserContext, BrowserContextConfig +from agent.tools.base_tool import BaseTool, ToolResult +from agent.tools.browser.browser_action import * +from agent.models import LLMRequest +from agent.models.model_factory import ModelFactory +from browser_use.dom.service import DomService +from common.log import logger + + +# Use lazy import, only import when actually used +def _import_browser_use(): + try: + import browser_use + return browser_use + except ImportError: + raise ImportError( + "The 'browser-use' package is required to use BrowserTool. " + "Please install it with 'pip install browser-use>=0.1.40' or " + "'pip install agentmesh-sdk[full]'." + ) + + +def _get_action_prompt(): + action_classes = [Navigate, ClickElement, ExtractContent, InputText, OpenTab, SwitchTab, ScrollDown, ScrollUp, + SendKeys] + action_prompt = "" + for action_class in action_classes: + action_prompt += f"{action_class.code}: {action_class.description}\n" + return action_prompt.strip() + + +def _header_less() -> bool: + if platform.system() == "Linux" and not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + return True + return False + + +class BrowserTool(BaseTool): + name: str = "browser" + description: str = "A tool to perform browser operations like navigating to URLs, element interaction, " \ + "and extracting content." 
+ params: dict = { + "type": "object", + "properties": { + "operation": { + "type": "string", + "description": f"The browser operation to perform: \n{_get_action_prompt()}" + }, + "url": { + "type": "string", + "description": f"The URL to navigate to (required for '{Navigate.code}', '{OpenTab.code}' actions). " + }, + "goal": { + "type": "string", + "description": f"The goal of extracting page content (required for '{ExtractContent.code}' action)." + }, + "text": { + "type": "string", + "description": f"Text to type (required for '{InputText.code}' action)." + }, + "index": { + "type": "integer", + "description": f"Element index (required for '{ClickElement.code}', '{InputText.code}' actions)", + }, + "tab_id": { + "type": "integer", + "description": f"Page tab ID (required for '{SwitchTab.code}' action)", + }, + "scroll_amount": { + "type": "integer", + "description": f"The number of pixels to scroll (required for '{ScrollDown.code}', '{ScrollUp.code}' action)." + }, + "keys": { + "type": "string", + "description": f"Keys to send (required for '{SendKeys.code}' action)" + } + }, + "required": ["operation"] + } + + # Class variable to ensure only one browser instance is created + browser = None + browser_context: BrowserContext = None + dom_service: DomService = None + _initialized = False + + # Adding an event loop variable + _event_loop = None + + def __init__(self): + # Only import during initialization, not at module level + self.browser_use = _import_browser_use() + # Do not initialize the browser in the constructor, but initialize it on the first execution + pass + + async def _init_browser(self) -> BrowserContext: + """Ensure the browser is initialized""" + if not BrowserTool._initialized: + os.environ['BROWSER_USE_LOGGING_LEVEL'] = 'error' + print("Initializing browser...") + # Initialize the browser synchronously + BrowserTool.browser = Browser(BrowserConfig(headless=_header_less(), + disable_security=True)) + context_config = BrowserContextConfig() + 
context_config.highlight_elements = True + BrowserTool.browser_context = await BrowserTool.browser.new_context(context_config) + BrowserTool._initialized = True + print("Browser initialized successfully") + BrowserTool.dom_service = DomService(await BrowserTool.browser_context.get_current_page()) + return BrowserTool.browser_context + + def execute(self, params: Dict[str, Any]) -> ToolResult: + """ + Execute browser operations based on the provided arguments. + + :param params: Dictionary containing the action and related parameters + :return: Result of the browser operation + """ + # Ensure browser_use is imported + if not hasattr(self, 'browser_use'): + self.browser_use = _import_browser_use() + action = params.get("operation", "").lower() + + try: + # Use a single event loop + if BrowserTool._event_loop is None: + BrowserTool._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(BrowserTool._event_loop) + # Run tasks in the existing event loop + return BrowserTool._event_loop.run_until_complete(self._execute_async(action, params)) + except Exception as e: + print(f"Error executing browser action: {e}") + return ToolResult.fail(result=f"Error executing browser action: {str(e)}") + + async def _get_page_state(self, context: BrowserContext): + state = await self._get_state(context) + include_attributes = ["img", "div", "button", "input"] + elements = state.element_tree.clickable_elements_to_string(include_attributes) + pattern = r'\[\d+\]<[^>]+\/>' + # Find all matching elements + interactive_elements = re.findall(pattern, elements) + page_state = { + "url": state.url, + "title": state.title, + "pixels_above": getattr(state, "pixels_above", 0), + "pixels_below": getattr(state, "pixels_below", 0), + "tabs": [tab.model_dump() for tab in state.tabs], + "interactive_elements": interactive_elements, + } + return page_state + + async def _get_state(self, context: BrowserContext, cache_clickable_elements_hashes=True): + try: + return await context.get_state() + 
except TypeError: + return await context.get_state(cache_clickable_elements_hashes=cache_clickable_elements_hashes) + + async def _get_page_info(self, context: BrowserContext): + page_state = await self._get_page_state(context) + state_str = f"""## Current browser state +The following is the information of the current browser page. Each serial number in interactive_elements represents the element index: +{json.dumps(page_state, indent=4, ensure_ascii=False)} +""" + return state_str + + async def _execute_async(self, action: str, params: Dict[str, Any]) -> ToolResult: + """Asynchronously execute browser operations""" + # Use the browser context from the class variable + context = await self._init_browser() + + if action == Navigate.code: + url = params.get("url") + if not url: + return ToolResult.fail(result="URL is required for navigate action") + if url.startswith("/"): + url = f"file://{url}" + print(f"Navigating to {url}...") + page = await context.get_current_page() + await page.goto(url) + await page.wait_for_load_state() + state = await self._get_page_info(context) + # print(state) + print(f"Navigation complete") + return ToolResult.success(result=f"Navigated to {url}", ext_data=state) + + elif action == OpenTab.code: + url = params.get("url") + if url.startswith("/"): + url = f"file://{url}" + await context.create_new_tab(url) + msg = f"Opened new tab with {url}" + return ToolResult.success(result=msg) + + elif action == ExtractContent.code: + try: + goal = params.get("goal") + page = await context.get_current_page() + if params.get("url"): + await page.goto(params.get("url")) + await page.wait_for_load_state() + import markdownify + content = markdownify.markdownify(await page.content()) + elements = await self._get_page_state(context) + prompt = f"Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. 
If the goal is vague, " \ + f"summarize the page. Respond in json format. elements: {elements.get('interactive_elements')}, extraction goal: {goal}, Page: {content}," + request = LLMRequest( + messages=[{"role": "user", "content": prompt}], + temperature=0, + json_format=True + ) + model = self.model or ModelFactory().get_model(model_name="gpt-4o") + response = model.call(request) + if response.success: + extract_content = response.data["choices"][0]["message"]["content"] + print(f"Extract from page: {extract_content}") + return ToolResult.success(result=f"Extract from page: {extract_content}", + ext_data=await self._get_page_info(context)) + else: + return ToolResult.fail(result=f"Extract from page failed: {response.get_error_msg()}") + except Exception as e: + logger.error(e) + + elif action == ClickElement.code: + index = params.get("index") + element = await context.get_dom_element_by_index(index) + await context._click_element_node(element) + msg = f"Clicked element at index {index}" + print(msg) + return ToolResult.success(result=msg, ext_data=await self._get_page_info(context)) + + elif action == InputText.code: + index = params.get("index") + text = params.get("text") + element = await context.get_dom_element_by_index(index) + await context._input_text_element_node(element, text) + await asyncio.sleep(1) + msg = f"Input text into element successfully, index: {index}, text: {text}" + return ToolResult.success(result=msg, ext_data=await self._get_page_info(context)) + + elif action == SwitchTab.code: + tab_id = params.get("tab_id") + print(f"Switch tab, tab_id={tab_id}") + await context.switch_to_tab(tab_id) + page = await context.get_current_page() + await page.wait_for_load_state() + msg = f"Switched to tab {tab_id}" + return ToolResult.success(result=msg, ext_data=await self._get_page_info(context)) + + elif action in [ScrollDown.code, ScrollUp.code]: + scroll_amount = params.get("scroll_amount") + if not scroll_amount: + scroll_amount = 
context.config.browser_window_size["height"] + print(f"Scrolling by {scroll_amount} pixels") + scroll_amount = scroll_amount if action == ScrollDown.code else (scroll_amount * -1) + await context.execute_javascript(f"window.scrollBy(0, {scroll_amount});") + msg = f"{action} by {scroll_amount} pixels" + return ToolResult.success(result=msg, ext_data=await self._get_page_info(context)) + + elif action == SendKeys.code: + keys = params.get("keys") + page = await context.get_current_page() + await page.keyboard.press(keys) + msg = f"Sent keys: {keys}" + print(msg) + return ToolResult(output=f"Sent keys: {keys}") + + else: + msg = "Failed to operate the browser" + return ToolResult.fail(result=msg) + + def close(self): + """ + Close browser resources. + This method handles the asynchronous closing of browser and browser context. + """ + if not BrowserTool._initialized: + return + + try: + # Use the existing event loop to close browser resources + if BrowserTool._event_loop is not None: + # Define the async close function + async def close_browser_async(): + if BrowserTool.browser_context is not None: + try: + await BrowserTool.browser_context.close() + except Exception as e: + logger.error(f"Error closing browser context: {e}") + + if BrowserTool.browser is not None: + try: + await BrowserTool.browser.close() + except Exception as e: + logger.error(f"Error closing browser: {e}") + + # Reset the initialized flag + BrowserTool._initialized = False + BrowserTool.browser = None + BrowserTool.browser_context = None + BrowserTool.dom_service = None + + # Run the async close function in the existing event loop + BrowserTool._event_loop.run_until_complete(close_browser_async()) + + # Close the event loop + BrowserTool._event_loop.close() + BrowserTool._event_loop = None + except Exception as e: + print(f"Error during browser cleanup: {e}") diff --git a/agent/tools/browser_tool.py b/agent/tools/browser_tool.py new file mode 100644 index 0000000..b134ef7 --- /dev/null +++ 
b/agent/tools/browser_tool.py @@ -0,0 +1,18 @@ +def copy(self): + """ + Special copy method for browser tool to avoid recreating browser instance. + + :return: A new instance with shared browser reference but unique model + """ + new_tool = self.__class__() + + # Copy essential attributes + new_tool.model = self.model + new_tool.context = getattr(self, 'context', None) + new_tool.config = getattr(self, 'config', None) + + # Share the browser instance instead of creating a new one + if hasattr(self, 'browser'): + new_tool.browser = self.browser + + return new_tool \ No newline at end of file diff --git a/agent/tools/calculator/calculator.py b/agent/tools/calculator/calculator.py new file mode 100644 index 0000000..092343d --- /dev/null +++ b/agent/tools/calculator/calculator.py @@ -0,0 +1,58 @@ +import math + +from agent.tools.base_tool import BaseTool, ToolResult + + +class Calculator(BaseTool): + name: str = "calculator" + description: str = "A tool to perform basic mathematical calculations." + params: dict = { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). " + "Ensure your input is a valid Python expression, it will be evaluated directly." 
+ } + }, + "required": ["expression"] + } + config: dict = {} + + def execute(self, args: dict) -> ToolResult: + try: + # Get the expression + expression = args["expression"] + + # Create a safe local environment containing only basic math functions + safe_locals = { + "abs": abs, + "round": round, + "max": max, + "min": min, + "pow": pow, + "sqrt": math.sqrt, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "pi": math.pi, + "e": math.e, + "log": math.log, + "log10": math.log10, + "exp": math.exp, + "floor": math.floor, + "ceil": math.ceil + } + + # Safely evaluate the expression + result = eval(expression, {"__builtins__": {}}, safe_locals) + + return ToolResult.success({ + "result": result, + "expression": expression + }) + except Exception as e: + return ToolResult.success({ + "error": str(e), + "expression": args.get("expression", "") + }) diff --git a/agent/tools/current_time/current_time.py b/agent/tools/current_time/current_time.py new file mode 100644 index 0000000..5fb0f95 --- /dev/null +++ b/agent/tools/current_time/current_time.py @@ -0,0 +1,75 @@ +import datetime +import time + +from agent.tools.base_tool import BaseTool, ToolResult + + +class CurrentTime(BaseTool): + name: str = "time" + description: str = "A tool to get current date and time information." + params: dict = { + "type": "object", + "properties": { + "format": { + "type": "string", + "description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'." + }, + "timezone": { + "type": "string", + "description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'." 
+ } + }, + "required": [] + } + config: dict = {} + + def execute(self, args: dict) -> ToolResult: + try: + # Get the format and timezone parameters, with defaults + time_format = args.get("format", "human").lower() + timezone = args.get("timezone", "local").lower() + + # Get current time + current_time = datetime.datetime.now() + + # Handle timezone if specified + if timezone == "utc": + current_time = datetime.datetime.utcnow() + + # Format the time according to the specified format + if time_format == "iso": + # ISO 8601 format + formatted_time = current_time.isoformat() + elif time_format == "unix": + # Unix timestamp (seconds since epoch) + formatted_time = time.time() + else: + # Human-readable format + formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") + + # Prepare additional time components for the response + year = current_time.year + month = current_time.month + day = current_time.day + hour = current_time.hour + minute = current_time.minute + second = current_time.second + weekday = current_time.strftime("%A") # Full weekday name + + result = { + "current_time": formatted_time, + "components": { + "year": year, + "month": month, + "day": day, + "hour": hour, + "minute": minute, + "second": second, + "weekday": weekday + }, + "format": time_format, + "timezone": timezone + } + return ToolResult.success(result=result) + except Exception as e: + return ToolResult.fail(result=str(e)) diff --git a/agent/tools/edit/__init__.py b/agent/tools/edit/__init__.py new file mode 100644 index 0000000..68b84bb --- /dev/null +++ b/agent/tools/edit/__init__.py @@ -0,0 +1,3 @@ +from .edit import Edit + +__all__ = ['Edit'] diff --git a/agent/tools/edit/edit.py b/agent/tools/edit/edit.py new file mode 100644 index 0000000..54f7529 --- /dev/null +++ b/agent/tools/edit/edit.py @@ -0,0 +1,164 @@ +""" +Edit tool - Precise file editing +Edit files through exact text replacement +""" + +import os +from typing import Dict, Any + +from agent.tools.base_tool import 
BaseTool, ToolResult
from agent.tools.utils.diff import (
    strip_bom,
    detect_line_ending,
    normalize_to_lf,
    restore_line_endings,
    normalize_for_fuzzy_match,
    fuzzy_find_text,
    generate_diff_string
)


class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."

    # JSON-schema description of the accepted arguments.
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        # "cwd" anchors relative paths; defaults to the process working dir.
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation.

        Pipeline: read -> strip BOM -> normalize line endings to LF ->
        fuzzy-locate oldText -> require uniqueness -> splice newText ->
        restore BOM/line endings -> write -> return a diff.

        :param args: Contains file path, old text and new text
        :return: Operation result (diff + first changed line on success)
        """
        path = args.get("path", "").strip()
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # Read file
            with open(absolute_path, 'r', encoding='utf-8') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending so it can be restored on write
            original_ending = detect_line_ending(content)

            # Normalize to LF so oldText/newText compare consistently
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            # Refuse ambiguous edits: the caller must disambiguate
            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_result.index + match_result.match_length:]
            )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings (and BOM, if any)
            final_content = bom + restore_line_endings(new_content, original_ending)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(final_content)

            # Generate diff for the caller to review
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path.

        :param path: Relative or absolute path (supports ~ expansion)
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))


from .file_save import FileSave

__all__ = ['FileSave']


import os
import time
import re
import json
from pathlib import Path
from typing import Dict, Any, Optional, Tuple

from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
from agent.models import LLMRequest
from common.log import logger


class FileSave(BaseTool):
    """Tool for saving content to files in the workspace directory."""

    name = "file_save"
    description = "Save
the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."

    # Set as post-process stage tool: runs after the agent has produced output
    stage = ToolStage.POST_PROCESS

    # JSON-schema description of the accepted arguments; all optional because
    # everything can be derived from the execution context.
    params = {
        "type": "object",
        "properties": {
            "file_name": {
                "type": "string",
                "description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
            },
            "file_type": {
                "type": "string",
                "description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
            },
            "extract_code": {
                "type": "boolean",
                "description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
            }
        },
        "required": []  # No required fields, as everything can be extracted from context
    }

    def __init__(self):
        # context is injected by the framework before execute() is called
        self.context = None
        self.config = {}
        # All files land under workspace/<team_name>/<task_dir>/
        self.workspace_dir = Path("workspace")

    def execute(self, params: Dict[str, Any]) -> ToolResult:
        """
        Save content to a file in the workspace directory.

        Content is pulled from the injected context; file name/type are
        chosen by the model, falling back to the explicit params.

        :param params: The parameters for the file output operation.
        :return: Result of the operation.
        """
        # Extract content from context
        if not hasattr(self, 'context') or not self.context:
            return ToolResult.fail("Error: No context available to extract content from.")

        content = self._extract_content_from_context()

        # If no content could be extracted, return error
        if not content:
            return ToolResult.fail("Error: Couldn't extract content from context.")

        # Use model to determine file parameters
        try:
            task_dir = self._get_task_dir_from_context()
            file_name, file_type, extract_code = self._get_file_params_from_model(content)
        except Exception as e:
            logger.error(f"Error determining file parameters: {str(e)}")
            # Fall back to manual parameter extraction.
            # NOTE(review): caller-supplied params are only honored on this
            # fallback path; the model-driven path above ignores them --
            # confirm that is intended.
            task_dir = params.get("task_dir") or self._get_task_id_from_context() or f"task_{int(time.time())}"
            file_name = params.get("file_name") or self._infer_file_name(content)
            file_type = params.get("file_type") or self._infer_file_type(content)
            extract_code = params.get("extract_code", False)

        # Get team_name from context
        team_name = self._get_team_name_from_context() or "default_team"

        # Create directory structure
        task_dir_path = self.workspace_dir / team_name / task_dir
        task_dir_path.mkdir(parents=True, exist_ok=True)

        if extract_code:
            # Save the complete content as markdown alongside the extracted
            # code files
            md_file_name = f"{file_name}.md"
            md_file_path = task_dir_path / md_file_name

            # Write content to file
            with open(md_file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            return self._handle_multiple_code_blocks(content)

        # Ensure file_name has the correct extension
        if file_type and not file_name.endswith(f".{file_type}"):
            file_name = f"{file_name}.{file_type}"

        # Create the full file path
        file_path = task_dir_path / file_name

        # Get absolute path for storage in team_context
        abs_file_path = file_path.absolute()

        try:
            # Write content to file
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Update the current agent's final_answer to include file information.
            # NOTE(review): agent_outputs[-1] assumes at least one prior agent
            # output exists; an empty list would raise IndexError -- verify
            # callers guarantee this.
            if hasattr(self.context, 'team_context'):
                # Store with absolute path in team_context
                self.context.team_context.agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"

            return ToolResult.success({
                "status": "success",
                "file_path": str(file_path)  # Return relative path in result
            })

        except Exception as e:
            return ToolResult.fail(f"Error saving file: {str(e)}")

    def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
        """
        Handle content with multiple code blocks, extracting and saving each as a separate file.

        :param content: The content containing multiple code blocks
        :return: Result of the operation
        """
        # Extract code blocks with context (including potential file name information)
        code_blocks_with_context = self._extract_code_blocks_with_context(content)

        if not code_blocks_with_context:
            return ToolResult.fail("No code blocks found in the content.")

        # Get task directory and team name
        task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
        team_name = self._get_team_name_from_context() or "default_team"

        # Create directory structure
        task_dir_path = self.workspace_dir / team_name / task_dir
        task_dir_path.mkdir(parents=True, exist_ok=True)

        saved_files = []

        for block_with_context in code_blocks_with_context:
            try:
                # Use model to determine file name for this code block
                block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)

                # Clean the code block (remove md code markers)
                clean_code = self._clean_code_block(block_with_context)

                # Ensure file_name has the correct extension
                if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
                    block_file_name = f"{block_file_name}.{block_file_type}"

                # Create the full file path (no subdirectories)
                file_path = task_dir_path / block_file_name

                # Get absolute path for storage in team_context
                abs_file_path = file_path.absolute()

                #
Write content to file + with open(file_path, 'w', encoding='utf-8') as f: + f.write(clean_code) + + saved_files.append({ + "file_path": str(file_path), + "abs_file_path": str(abs_file_path), # Store absolute path for internal use + "file_name": block_file_name, + "size": len(clean_code), + "status": "success", + "type": "code" + }) + + except Exception as e: + logger.error(f"Error saving code block: {str(e)}") + # Continue with the next block even if this one fails + + if not saved_files: + return ToolResult.fail("Failed to save any code blocks.") + + # Update the current agent's final_answer to include files information + if hasattr(self, 'context') and self.context: + # If the agent has a final_answer attribute, append the files info to it + if hasattr(self.context, 'team_context'): + # Use relative paths for display + display_info = f"\n\nSaved files to {task_dir_path}:\n" + "\n".join( + [f"- {f['file_path']}" for f in saved_files]) + + # Check if we need to append the info + if not self.context.team_context.agent_outputs[-1].output.endswith(display_info): + # Store with absolute paths in team_context + abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join( + [f"- {f['abs_file_path']}" for f in saved_files]) + self.context.team_context.agent_outputs[-1].output += abs_info + + result = { + "status": "success", + "files": [{"file_path": f["file_path"]} for f in saved_files] + } + + return ToolResult.success(result) + + def _extract_code_blocks_with_context(self, content: str) -> list: + """ + Extract code blocks from content, including context lines before the block. + + :param content: The content to extract code blocks from + :return: List of code blocks with context + """ + # Check if content starts with Tuple[str, str]: + """ + Determine the file name for a code block. 
+ + :param block_with_context: The code block with context lines + :return: Tuple of (file_name, file_type) + """ + # Define common code file extensions + COMMON_CODE_EXTENSIONS = { + 'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php', + 'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml', + 'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua' + } + + # Split the block into lines to examine only the context around code block markers + lines = block_with_context.split('\n') + + # Find the code block start marker line index + start_marker_idx = -1 + for i, line in enumerate(lines): + if line.strip().startswith('```') and not line.strip() == '```': + start_marker_idx = i + break + + if start_marker_idx == -1: + # No code block marker found + return "", "" + + # Extract the language from the code block marker + code_marker = lines[start_marker_idx].strip() + language = "" + if len(code_marker) > 3: + language = code_marker[3:].strip().split('=')[0].strip() + + # Define the context range (5 lines before and 2 after the marker) + context_start = max(0, start_marker_idx - 5) + context_end = min(len(lines), start_marker_idx + 3) + + # Extract only the relevant context lines + context_lines = lines[context_start:context_end] + + # First, check for explicit file headers like "## filename.ext" + for line in context_lines: + # Match patterns like "## filename.ext" or "# filename.ext" + header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line) + if header_match: + file_name = header_match.group(1) + file_type = os.path.splitext(file_name)[1].lstrip('.') + if file_type in COMMON_CODE_EXTENSIONS: + return os.path.splitext(file_name)[0], file_type + + # Simple patterns to match explicit file names in the context + file_patterns = [ + # Match explicit file names in headers or text + r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?', + # Match language=filename.ext in code markers + 
r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)', + # Match standalone filenames with extensions + r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b', + # Match file paths in comments + r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)' + ] + + # Check each context line for file name patterns + for line in context_lines: + line = line.strip() + for pattern in file_patterns: + matches = re.findall(pattern, line) + if matches: + for match in matches: + if isinstance(match, tuple): + # If the match is a tuple (filename, extension) + file_name = match[0] + file_type = match[1] + # Verify it's not a code reference like Direction.DOWN + if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']): + return os.path.splitext(file_name)[0], file_type + else: + # If the match is a string (full filename) + file_name = match + file_type = os.path.splitext(file_name)[1].lstrip('.') + # Verify it's not a code reference + if file_type in COMMON_CODE_EXTENSIONS and not any( + keyword in file_name for keyword in ['class.', 'enum.', 'import.']): + return os.path.splitext(file_name)[0], file_type + + # If no explicit file name found, use LLM to infer from code content + # Extract the code content + code_content = block_with_context + + # Get the first 20 lines of code for LLM analysis + code_lines = code_content.split('\n') + code_preview = '\n'.join(code_lines[:20]) + + # Get the model to use + model_to_use = None + if hasattr(self, 'context') and self.context: + if hasattr(self.context, 'model') and self.context.model: + model_to_use = self.context.model + elif hasattr(self.context, 'team_context') and self.context.team_context: + if hasattr(self.context.team_context, 'model') and self.context.team_context.model: + model_to_use = self.context.team_context.model + + # If no model is available in context, use the tool's model + if not model_to_use and hasattr(self, 'model') and self.model: + 
model_to_use = self.model + + if model_to_use: + # Prepare a prompt for the model + prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension. +The file name should be descriptive but concise, using snake_case (lowercase with underscores). +The file type should be a standard file extension (e.g., py, js, html, css, java). + +Code preview (first 20 lines): +{code_preview} + +Return your answer in JSON format with these fields: +- file_name: The suggested file name (without extension) +- file_type: The suggested file extension + +JSON response:""" + + # Create a request to the model + request = LLMRequest( + messages=[{"role": "user", "content": prompt}], + temperature=0, + json_format=True + ) + + try: + response = model_to_use.call(request) + + if not response.is_error: + # Clean the JSON response + json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"]) + result = json.loads(json_content) + + file_name = result.get("file_name", "") + file_type = result.get("file_type", "") + + if file_name and file_type: + return file_name, file_type + except Exception as e: + logger.error(f"Error using model to determine file name: {str(e)}") + + # If we still don't have a file name, use the language as file type + if language and language in COMMON_CODE_EXTENSIONS: + timestamp = int(time.time()) + return f"code_{timestamp}", language + + # If all else fails, return empty strings + return "", "" + + def _clean_json_response(self, text: str) -> str: + """ + Clean JSON response from LLM by removing markdown code block markers. 
+ + :param text: The text containing JSON possibly wrapped in markdown code blocks + :return: Clean JSON string + """ + # Remove markdown code block markers if present + if text.startswith("```json"): + text = text[7:] + elif text.startswith("```"): + # Find the first newline to skip the language identifier line + first_newline = text.find('\n') + if first_newline != -1: + text = text[first_newline + 1:] + + if text.endswith("```"): + text = text[:-3] + + return text.strip() + + def _clean_code_block(self, block_with_context: str) -> str: + """ + Clean a code block by removing markdown code markers and context lines. + + :param block_with_context: Code block with context lines + :return: Clean code ready for execution + """ + # Check if this is a full HTML or XML document + if block_with_context.strip().startswith((" 500 else ""} + + Respond in JSON format only with the following structure: + {{ + "is_code": true/false, # Whether this is primarily code implementation + "filename": "suggested_filename", # Don't include extension, english words + "extension": "appropriate_extension" # Don't include the dot, e.g., "md", "py", "js" + }} + """ + + try: + # Create a request to the model + request = LLMRequest( + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + json_format=True + ) + + # Call the model using the standard interface + response = model.call(request) + + if response.is_error: + logger.warning(f"Error from model: {response.error_message}") + raise Exception(f"Model error: {response.error_message}") + + # Extract JSON from response + result = response.data["choices"][0]["message"]["content"] + + # Clean the JSON response + result = self._clean_json_response(result) + + # Parse the JSON + params = json.loads(result) + + # For backward compatibility, return tuple format + file_name = params.get("filename", "output") + # Remove dot from extension if present + file_type = params.get("extension", "md").lstrip(".") + extract_code = 
params.get("is_code", False) + + return file_name, file_type, extract_code + except Exception as e: + logger.warning(f"Error getting file parameters from model: {e}") + # Default fallback + return "output", "md", False + + def _get_team_name_from_context(self) -> Optional[str]: + """ + Get team name from the agent's context. + + :return: Team name or None if not found + """ + if hasattr(self, 'context') and self.context: + # Try to get team name from team_context + if hasattr(self.context, 'team_context') and self.context.team_context: + return self.context.team_context.name + + # Try direct team_name attribute + if hasattr(self.context, 'name'): + return self.context.name + + return None + + def _get_task_id_from_context(self) -> Optional[str]: + """ + Get task ID from the agent's context. + + :return: Task ID or None if not found + """ + if hasattr(self, 'context') and self.context: + # Try to get task ID from task object + if hasattr(self.context, 'task') and self.context.task: + return self.context.task.id + + # Try team_context's task + if hasattr(self.context, 'team_context') and self.context.team_context: + if hasattr(self.context.team_context, 'task') and self.context.team_context.task: + return self.context.team_context.task.id + + return None + + def _get_task_dir_from_context(self) -> Optional[str]: + """ + Get task directory name from the team context. + + :return: Task directory name or None if not found + """ + if hasattr(self, 'context') and self.context: + # Try to get from team_context + if hasattr(self.context, 'team_context') and self.context.team_context: + if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name: + return self.context.team_context.task_short_name + + # Fall back to task ID if available + return self._get_task_id_from_context() + + def _extract_content_from_context(self) -> str: + """ + Extract content from the agent's context. 
+ + :return: Extracted content + """ + # Check if we have access to the agent's context + if not hasattr(self, 'context') or not self.context: + return "" + + # Try to get the most recent final answer from the agent + if hasattr(self.context, 'final_answer') and self.context.final_answer: + return self.context.final_answer + + # Try to get the most recent final answer from team context + if hasattr(self.context, 'team_context') and self.context.team_context: + if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs: + latest_output = self.context.team_context.agent_outputs[-1].output + return latest_output + + # If we have action history, try to get the most recent final answer + if hasattr(self.context, 'action_history') and self.context.action_history: + for action in reversed(self.context.action_history): + if "final_answer" in action and action["final_answer"]: + return action["final_answer"] + + return "" + + def _extract_code_blocks(self, content: str) -> str: + """ + Extract code blocks from markdown content. + + :param content: The content to extract code blocks from + :return: Extracted code blocks + """ + # Pattern to match markdown code blocks + code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```' + + # Find all code blocks + code_blocks = re.findall(code_block_pattern, content) + + if code_blocks: + # Join all code blocks with newlines + return '\n\n'.join(code_blocks) + + return content # Return original content if no code blocks found + + def _infer_file_name(self, content: str) -> str: + """ + Infer a file name from the content. + + :param content: The content to analyze. + :return: A suggested file name. 
+ """ + # Check for title patterns in markdown + title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE) + if title_match: + # Convert title to a valid filename + title = title_match.group(1).strip() + return self._sanitize_filename(title) + + # Check for class/function definitions in code + code_match = re.search(r'(class|def|function)\s+(\w+)', content) + if code_match: + return self._sanitize_filename(code_match.group(2)) + + # Default name based on content type + if self._is_likely_code(content): + return "code" + elif self._is_likely_markdown(content): + return "document" + elif self._is_likely_json(content): + return "data" + else: + return "output" + + def _infer_file_type(self, content: str) -> str: + """ + Infer the file type/extension from the content. + + :param content: The content to analyze. + :return: A suggested file extension. + """ + # Check for common programming language patterns + if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content): + return "py" # Python + elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content): + return "java" # Java + elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content): + return "js" # JavaScript + elif re.search(r'()', content): + return "html" # HTML + elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content): + return "cpp" # C/C++ + + # Check for markdown + if self._is_likely_markdown(content): + return "md" + + # Check for JSON + if self._is_likely_json(content): + return "json" + + # Default to text + return "txt" + + def _is_likely_code(self, content: str) -> bool: + """Check if the content is likely code.""" + # First check for common HTML/XML patterns + if content.strip().startswith((".*?)', # HTML/XML tags + r'(var|let|const)\s+\w+\s*=', # JavaScript variable declarations + r'#\s*\w+', # CSS ID selectors or Python comments + r'\.\w+\s*\{', # CSS class selectors + r'@media|@import|@font-face' # CSS at-rules 
+ ] + return any(re.search(pattern, content) for pattern in code_patterns) + + def _is_likely_markdown(self, content: str) -> bool: + """Check if the content is likely markdown.""" + md_patterns = [ + r'^#\s+.+$', # Headers + r'^\*\s+.+$', # Unordered lists + r'^\d+\.\s+.+$', # Ordered lists + r'\[.+\]\(.+\)', # Links + r'!\[.+\]\(.+\)' # Images + ] + return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns) + + def _is_likely_json(self, content: str) -> bool: + """Check if the content is likely JSON.""" + try: + content = content.strip() + if (content.startswith('{') and content.endswith('}')) or ( + content.startswith('[') and content.endswith(']')): + json.loads(content) + return True + except: + pass + return False + + def _sanitize_filename(self, name: str) -> str: + """ + Sanitize a string to be used as a filename. + + :param name: The string to sanitize. + :return: A sanitized filename. + """ + # Replace spaces with underscores + name = name.replace(' ', '_') + + # Remove invalid characters + name = re.sub(r'[^\w\-\.]', '', name) + + # Limit length + if len(name) > 50: + name = name[:50] + + return name.lower() + + def _process_file_path(self, file_path: str) -> Tuple[str, str]: + """ + Process a file path to extract the file name and type, and create directories if needed. 
+ + :param file_path: The file path to process + :return: Tuple of (file_name, file_type) + """ + # Get the file name and extension + file_name = os.path.basename(file_path) + file_type = os.path.splitext(file_name)[1].lstrip('.') + + return os.path.splitext(file_name)[0], file_type diff --git a/agent/tools/find/__init__.py b/agent/tools/find/__init__.py new file mode 100644 index 0000000..f2af14f --- /dev/null +++ b/agent/tools/find/__init__.py @@ -0,0 +1,3 @@ +from .find import Find + +__all__ = ['Find'] diff --git a/agent/tools/find/find.py b/agent/tools/find/find.py new file mode 100644 index 0000000..7a2c4a1 --- /dev/null +++ b/agent/tools/find/find.py @@ -0,0 +1,177 @@ +""" +Find tool - Search for files by glob pattern +""" + +import os +import glob as glob_module +from typing import Dict, Any, List + +from agent.tools.base_tool import BaseTool, ToolResult +from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES + + +DEFAULT_LIMIT = 1000 + + +class Find(BaseTool): + """Tool for finding files by pattern""" + + name: str = "find" + description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)." + + params: dict = { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Glob pattern to match files, e.g. 
'*.ts', '**/*.json', or 'src/**/*.spec.ts'" + }, + "path": { + "type": "string", + "description": "Directory to search in (default: current directory)" + }, + "limit": { + "type": "integer", + "description": f"Maximum number of results (default: {DEFAULT_LIMIT})" + } + }, + "required": ["pattern"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + self.cwd = self.config.get("cwd", os.getcwd()) + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute file search + + :param args: Search parameters + :return: Search results or error + """ + pattern = args.get("pattern", "").strip() + search_path = args.get("path", ".").strip() + limit = args.get("limit", DEFAULT_LIMIT) + + if not pattern: + return ToolResult.fail("Error: pattern parameter is required") + + # Resolve search path + absolute_path = self._resolve_path(search_path) + + if not os.path.exists(absolute_path): + return ToolResult.fail(f"Error: Path not found: {search_path}") + + if not os.path.isdir(absolute_path): + return ToolResult.fail(f"Error: Not a directory: {search_path}") + + try: + # Load .gitignore patterns + ignore_patterns = self._load_gitignore(absolute_path) + + # Search for files + results = [] + search_pattern = os.path.join(absolute_path, pattern) + + # Use glob with recursive support + for file_path in glob_module.glob(search_pattern, recursive=True): + # Skip if matches ignore patterns + if self._should_ignore(file_path, absolute_path, ignore_patterns): + continue + + # Get relative path + relative_path = os.path.relpath(file_path, absolute_path) + + # Add trailing slash for directories + if os.path.isdir(file_path): + relative_path += '/' + + results.append(relative_path) + + if len(results) >= limit: + break + + if not results: + return ToolResult.success({"message": "No files found matching pattern", "files": []}) + + # Sort results + results.sort() + + # Format output + raw_output = '\n'.join(results) + truncation = 
truncate_head(raw_output, max_lines=999999) # Only limit by bytes + + output = truncation.content + details = {} + notices = [] + + result_limit_reached = len(results) >= limit + if result_limit_reached: + notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern") + details["result_limit_reached"] = limit + + if truncation.truncated: + notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached") + details["truncation"] = truncation.to_dict() + + if notices: + output += f"\n\n[{'. '.join(notices)}]" + + return ToolResult.success({ + "output": output, + "file_count": len(results), + "details": details if details else None + }) + + except Exception as e: + return ToolResult.fail(f"Error executing find: {str(e)}") + + def _resolve_path(self, path: str) -> str: + """Resolve path to absolute path""" + # Expand ~ to user home directory + path = os.path.expanduser(path) + if os.path.isabs(path): + return path + return os.path.abspath(os.path.join(self.cwd, path)) + + def _load_gitignore(self, directory: str) -> List[str]: + """Load .gitignore patterns from directory""" + patterns = [] + gitignore_path = os.path.join(directory, '.gitignore') + + if os.path.exists(gitignore_path): + try: + with open(gitignore_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#'): + patterns.append(line) + except: + pass + + # Add common ignore patterns + patterns.extend([ + '.git', + '__pycache__', + '*.pyc', + 'node_modules', + '.DS_Store' + ]) + + return patterns + + def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool: + """Check if file should be ignored based on patterns""" + relative_path = os.path.relpath(file_path, base_path) + + for pattern in patterns: + # Simple pattern matching + if pattern in relative_path: + return True + + # Check if it's a directory pattern + if pattern.endswith('/'): + if relative_path.startswith(pattern.rstrip('/')): + 
return True + + return False diff --git a/agent/tools/google_search/google_search.py b/agent/tools/google_search/google_search.py new file mode 100644 index 0000000..f8005e8 --- /dev/null +++ b/agent/tools/google_search/google_search.py @@ -0,0 +1,48 @@ +import requests + +from agent.tools.base_tool import BaseTool, ToolResult + + +class GoogleSearch(BaseTool): + name: str = "google_search" + description: str = "A tool to perform Google searches using the Serper API." + params: dict = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to perform." + } + }, + "required": ["query"] + } + config: dict = {} + + def __init__(self, config=None): + self.config = config or {} + + def execute(self, args: dict) -> ToolResult: + api_key = self.config.get("api_key") # Replace with your actual API key + url = "https://google.serper.dev/search" + headers = { + "X-API-KEY": api_key, + "Content-Type": "application/json" + } + data = { + "q": args.get("query"), + "k": 10 + } + + response = requests.post(url, headers=headers, json=data) + result = response.json() + + if result.get("statusCode") and result.get("statusCode") == 503: + return ToolResult.fail(result=result) + else: + # Check if the returned result contains the 'organic' key and ensure it is a list + if 'organic' in result and isinstance(result.get('organic'), list): + result_data = result['organic'] + else: + # If there are no organic results, return the full response or an empty list + result_data = result.get('organic', []) if isinstance(result.get('organic'), list) else [] + return ToolResult.success(result=result_data) diff --git a/agent/tools/grep/__init__.py b/agent/tools/grep/__init__.py new file mode 100644 index 0000000..e4d57b0 --- /dev/null +++ b/agent/tools/grep/__init__.py @@ -0,0 +1,3 @@ +from .grep import Grep + +__all__ = ['Grep'] diff --git a/agent/tools/grep/grep.py b/agent/tools/grep/grep.py new file mode 100644 index 0000000..1e7d95e --- 
class Grep(BaseTool):
    """Tool for searching file contents via ripgrep."""

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Located once at construction; execute() fails fast when missing.
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Return the path to the ripgrep executable, or None if not installed."""
        try:
            result = subprocess.run(['which', 'rg'], capture_output=True, text=True)
            if result.returncode == 0:
                return result.stdout.strip()
        except OSError:
            # BUGFIX: was a bare `except:`; only OS-level failures
            # (e.g. `which` itself missing) should be swallowed here.
            pass
        return None

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search.

        Runs ripgrep in --json mode, collects up to `limit` matches, then
        renders each match (with optional context lines) from the source
        files, applying per-line and total-byte truncation.

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        absolute_path = self._resolve_path(search_path)
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build the ripgrep command.
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]
        if ignore_case:
            cmd.append('--ignore-case')
        if literal:
            cmd.append('--fixed-strings')
        if glob:
            cmd.extend(['--glob', glob])
        cmd.extend([pattern, absolute_path])

        try:
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            # Collect match events from ripgrep's JSON event stream.
            matches = []
            match_count = 0
            for line in result.stdout.splitlines():
                if not line.strip():
                    continue
                try:
                    event = json.loads(line)
                    if event.get('type') == 'match':
                        data = event.get('data', {})
                        file_path = data.get('path', {}).get('text')
                        line_number = data.get('line_number')
                        if file_path and line_number:
                            matches.append({'file': file_path, 'line': line_number})
                            match_count += 1
                            if match_count >= limit:
                                break
                except json.JSONDecodeError:
                    # Skip malformed event lines rather than aborting.
                    continue

            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})

            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)
            # PERF FIX: cache each file's lines; the original re-read the whole
            # file once PER MATCH, i.e. O(matches x file size) on files with
            # many hits.
            file_cache: Dict[str, List[str]] = {}

            for match in matches:
                file_path = match['file']
                line_number = match['line']

                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)

                try:
                    file_lines = file_cache.get(file_path)
                    if file_lines is None:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            file_lines = f.read().split('\n')
                        file_cache[file_path] = file_lines

                    # Context window around the matched line (0 = match only).
                    start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                    end = min(len(file_lines), line_number + context) if context > 0 else line_number

                    for i in range(start, end):
                        line_text = file_lines[i].replace('\r', '')
                        truncated_text, was_truncated = truncate_line(line_text)
                        if was_truncated:
                            lines_truncated = True

                        current_line = i + 1
                        # ':' marks the matched line, '-' marks context lines.
                        if current_line == line_number:
                            output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                        else:
                            output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")
                except Exception:
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")

            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit
            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()
            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve a possibly-relative, ~-prefixed path against the tool's cwd."""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
class Ls(BaseTool):
    """Tool for listing directory contents"""

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing.

        Entries are sorted case-insensitively, directories get a '/' suffix,
        and output is capped by both an entry limit and a byte budget.

        :param args: Listing parameters ({"path"?, "limit"?})
        :return: Directory contents or error
        """
        path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)

        absolute_path = self._resolve_path(path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {path}")
        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            entries = os.listdir(absolute_path)
            entries.sort(key=lambda x: x.lower())

            results = []
            entry_limit_reached = False

            for entry in entries:
                if len(results) >= limit:
                    entry_limit_reached = True
                    break

                full_path = os.path.join(absolute_path, entry)
                try:
                    if os.path.isdir(full_path):
                        results.append(entry + '/')
                    else:
                        results.append(entry)
                except OSError:
                    # BUGFIX: was a bare `except:`; only stat failures
                    # (permissions, races, broken mounts) should skip an entry.
                    continue

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit
            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve a possibly-relative, ~-prefixed path against the tool's cwd."""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
+ ) + params: dict = { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')" + }, + "start_line": { + "type": "integer", + "description": "Starting line number (optional, default: 1)", + "default": 1 + }, + "num_lines": { + "type": "integer", + "description": "Number of lines to read (optional, reads all if not specified)" + } + }, + "required": ["path"] + } + + def __init__(self, memory_manager): + """ + Initialize memory get tool + + Args: + memory_manager: MemoryManager instance + """ + super().__init__() + self.memory_manager = memory_manager + + def execute(self, args: dict): + """ + Execute memory file read + + Args: + args: Dictionary with path, start_line, num_lines + + Returns: + ToolResult with file content + """ + from agent.tools.base_tool import ToolResult + + path = args.get("path") + start_line = args.get("start_line", 1) + num_lines = args.get("num_lines") + + if not path: + return ToolResult.fail("Error: path parameter is required") + + try: + workspace_dir = self.memory_manager.config.get_workspace() + file_path = workspace_dir / path + + if not file_path.exists(): + return ToolResult.fail(f"Error: File not found: {path}") + + content = file_path.read_text() + lines = content.split('\n') + + # Handle line range + if start_line < 1: + start_line = 1 + + start_idx = start_line - 1 + + if num_lines: + end_idx = start_idx + num_lines + selected_lines = lines[start_idx:end_idx] + else: + selected_lines = lines[start_idx:] + + result = '\n'.join(selected_lines) + + # Add metadata + total_lines = len(lines) + shown_lines = len(selected_lines) + + output = [ + f"File: {path}", + f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})", + "", + result + ] + + return ToolResult.success('\n'.join(output)) + + except Exception as e: + return ToolResult.fail(f"Error reading memory file: {str(e)}") diff --git 
"""
Memory search tool

Allows agents to search their memory using semantic and keyword search
"""

from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool


class MemorySearchTool(BaseTool):
    """Tool for searching agent memory"""

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.3)",
                "default": 0.3
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Initialize memory search tool

        Args:
            memory_manager: MemoryManager instance
            user_id: Optional user ID for scoped search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id

    def execute(self, args: dict):
        """
        Execute memory search

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult with formatted search results
        """
        from agent.tools.base_tool import ToolResult

        query = args.get("query")
        max_results = args.get("max_results", 10)
        min_score = args.get("min_score", 0.3)

        if not query:
            return ToolResult.fail("Error: query parameter is required")

        try:
            results = self._run_search(query, max_results, min_score)

            if not results:
                return ToolResult.success(f"No relevant memories found for query: {query}")

            # Format results: one numbered entry per hit.
            output = [f"Found {len(results)} relevant memories:\n"]
            for i, result in enumerate(results, 1):
                output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
                output.append(f" Score: {result.score:.3f}")
                output.append(f" Snippet: {result.snippet}")

            return ToolResult.success("\n".join(output))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")

    def _run_search(self, query, max_results, min_score):
        """
        Drive the async MemoryManager.search from synchronous tool code.

        FIX: the original called asyncio.run() unconditionally, which raises
        RuntimeError when the calling thread already has a running event loop
        (the normal situation inside an async agent). Detect that case and run
        the coroutine on a private loop in a worker thread instead.
        """
        import asyncio

        coro = self.memory_manager.search(
            query=query,
            user_id=self.user_id,
            max_results=max_results,
            min_score=min_score,
            include_shared=True
        )

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: asyncio.run is safe.
            return asyncio.run(coro)

        # A loop is already running here; block on a worker thread's loop.
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()
+ + params: dict = { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to read (relative or absolute)" + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (1-indexed, optional)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read (optional)" + } + }, + "required": ["path"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + self.cwd = self.config.get("cwd", os.getcwd()) + # Supported image formats + self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'} + # Supported PDF format + self.pdf_extensions = {'.pdf'} + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute file read operation + + :param args: Contains file path and optional offset/limit parameters + :return: File content or error message + """ + path = args.get("path", "").strip() + offset = args.get("offset") + limit = args.get("limit") + + if not path: + return ToolResult.fail("Error: path parameter is required") + + # Resolve path + absolute_path = self._resolve_path(path) + + # Check if file exists + if not os.path.exists(absolute_path): + return ToolResult.fail(f"Error: File not found: {path}") + + # Check if readable + if not os.access(absolute_path, os.R_OK): + return ToolResult.fail(f"Error: File is not readable: {path}") + + # Check file type + file_ext = Path(absolute_path).suffix.lower() + + # Check if image + if file_ext in self.image_extensions: + return self._read_image(absolute_path, file_ext) + + # Check if PDF + if file_ext in self.pdf_extensions: + return self._read_pdf(absolute_path, path, offset, limit) + + # Read text file + return self._read_text(absolute_path, path, offset, limit) + + def _resolve_path(self, path: str) -> str: + """ + Resolve path to absolute path + + :param path: Relative or absolute path + :return: Absolute path + """ + # Expand ~ to user home directory + path = 
os.path.expanduser(path) + if os.path.isabs(path): + return path + return os.path.abspath(os.path.join(self.cwd, path)) + + def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult: + """ + Read image file + + :param absolute_path: Absolute path to the image file + :param file_ext: File extension + :return: Result containing image information + """ + try: + # Read image file + with open(absolute_path, 'rb') as f: + image_data = f.read() + + # Get file size + file_size = len(image_data) + + # Return image information (actual image data can be base64 encoded when needed) + import base64 + base64_data = base64.b64encode(image_data).decode('utf-8') + + # Determine MIME type + mime_type_map = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.webp': 'image/webp' + } + mime_type = mime_type_map.get(file_ext, 'image/jpeg') + + result = { + "type": "image", + "mime_type": mime_type, + "size": file_size, + "size_formatted": format_size(file_size), + "data": base64_data # Base64 encoded image data + } + + return ToolResult.success(result) + + except Exception as e: + return ToolResult.fail(f"Error reading image file: {str(e)}") + + def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult: + """ + Read text file + + :param absolute_path: Absolute path to the file + :param display_path: Path to display + :param offset: Starting line number (1-indexed) + :param limit: Maximum number of lines to read + :return: File content or error message + """ + try: + # Read file + with open(absolute_path, 'r', encoding='utf-8') as f: + content = f.read() + + all_lines = content.split('\n') + total_file_lines = len(all_lines) + + # Apply offset (if specified) + start_line = 0 + if offset is not None: + start_line = max(0, offset - 1) # Convert to 0-indexed + if start_line >= total_file_lines: + return ToolResult.fail( + f"Error: Offset {offset} is beyond end of file 
({total_file_lines} lines total)" + ) + + start_line_display = start_line + 1 # For display (1-indexed) + + # If user specified limit, use it + selected_content = content + user_limited_lines = None + if limit is not None: + end_line = min(start_line + limit, total_file_lines) + selected_content = '\n'.join(all_lines[start_line:end_line]) + user_limited_lines = end_line - start_line + elif offset is not None: + selected_content = '\n'.join(all_lines[start_line:]) + + # Apply truncation (considering line count and byte limits) + truncation = truncate_head(selected_content) + + output_text = "" + details = {} + + if truncation.first_line_exceeds_limit: + # First line exceeds 30KB limit + first_line_size = format_size(len(all_lines[start_line].encode('utf-8'))) + output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]" + details["truncation"] = truncation.to_dict() + elif truncation.truncated: + # Truncation occurred + end_line_display = start_line_display + truncation.output_lines - 1 + next_offset = end_line_display + 1 + + output_text = truncation.content + + if truncation.truncated_by == "lines": + output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]" + else: + output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). 
Use offset={next_offset} to continue.]" + + details["truncation"] = truncation.to_dict() + elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines: + # User specified limit, more content available, but no truncation + remaining = total_file_lines - (start_line + user_limited_lines) + next_offset = start_line + user_limited_lines + 1 + + output_text = truncation.content + output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]" + else: + # No truncation, no exceeding user limit + output_text = truncation.content + + result = { + "content": output_text, + "total_lines": total_file_lines, + "start_line": start_line_display, + "output_lines": truncation.output_lines + } + + if details: + result["details"] = details + + return ToolResult.success(result) + + except UnicodeDecodeError: + return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}") + except Exception as e: + return ToolResult.fail(f"Error reading file: {str(e)}") + + def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult: + """ + Read PDF file content + + :param absolute_path: Absolute path to the file + :param display_path: Path to display + :param offset: Starting line number (1-indexed) + :param limit: Maximum number of lines to read + :return: PDF text content or error message + """ + try: + # Try to import pypdf + try: + from pypdf import PdfReader + except ImportError: + return ToolResult.fail( + "Error: pypdf library not installed. 
Install with: pip install pypdf" + ) + + # Read PDF + reader = PdfReader(absolute_path) + total_pages = len(reader.pages) + + # Extract text from all pages + text_parts = [] + for page_num, page in enumerate(reader.pages, 1): + page_text = page.extract_text() + if page_text.strip(): + text_parts.append(f"--- Page {page_num} ---\n{page_text}") + + if not text_parts: + return ToolResult.success({ + "content": f"[PDF file with {total_pages} pages, but no text content could be extracted]", + "total_pages": total_pages, + "message": "PDF may contain only images or be encrypted" + }) + + # Merge all text + full_content = "\n\n".join(text_parts) + all_lines = full_content.split('\n') + total_lines = len(all_lines) + + # Apply offset and limit (same logic as text files) + start_line = 0 + if offset is not None: + start_line = max(0, offset - 1) + if start_line >= total_lines: + return ToolResult.fail( + f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)" + ) + + start_line_display = start_line + 1 + + selected_content = full_content + user_limited_lines = None + if limit is not None: + end_line = min(start_line + limit, total_lines) + selected_content = '\n'.join(all_lines[start_line:end_line]) + user_limited_lines = end_line - start_line + elif offset is not None: + selected_content = '\n'.join(all_lines[start_line:]) + + # Apply truncation + truncation = truncate_head(selected_content) + + output_text = "" + details = {} + + if truncation.truncated: + end_line_display = start_line_display + truncation.output_lines - 1 + next_offset = end_line_display + 1 + + output_text = truncation.content + + if truncation.truncated_by == "lines": + output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]" + else: + output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). 
"""Terminal tool: run shell commands locally behind a small deny-list filter."""

import platform
import subprocess
from typing import Dict, Any

from agent.tools.base_tool import BaseTool, ToolResult


class Terminal(BaseTool):
    name: str = "terminal"
    description: str = "A tool to run terminal commands on the local system"
    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": f"The terminal command to execute which should be valid in {platform.system()} platform"
            }
        },
        "required": ["command"]
    }
    config: dict = {}

    # Default command timeout in seconds (overridable via config["timeout"]).
    DEFAULT_TIMEOUT = 30

    def __init__(self, config=None):
        self.config = config or {}
        # Commands that are never allowed to run.
        self.command_ban_set = {"halt", "poweroff", "shutdown", "reboot", "rm", "kill",
                                "exit", "sudo", "su", "userdel", "groupdel", "logout", "alias"}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a terminal command safely.

        :param args: Dictionary containing the command to execute
        :return: Result of the command execution
        """
        command = args.get("command", "").strip()

        if not self._is_safe_command(command):
            return ToolResult.fail(result=f"Command '{command}' is not allowed for security reasons.")

        # FIX: resolve the timeout once; the original used 30 in run() but
        # hard-coded 20 in the TimeoutExpired message.
        timeout = self.config.get("timeout", self.DEFAULT_TIMEOUT)

        try:
            result = subprocess.run(
                command,
                shell=True,
                check=True,  # raise CalledProcessError on non-zero return code
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )

            return ToolResult.success({
                "stdout": result.stdout,
                "stderr": result.stderr,
                "return_code": result.returncode,
                "command": command
            })
        except subprocess.CalledProcessError as e:
            # Non-zero exit: surface the captured output as a failure payload.
            return ToolResult.fail({
                "stdout": e.stdout,
                "stderr": e.stderr,
                "return_code": e.returncode,
                "command": command
            })
        except subprocess.TimeoutExpired:
            return ToolResult.fail(result=f"Command timed out after {timeout} seconds.")
        except Exception as e:
            return ToolResult.fail(result=f"Error executing command: {str(e)}")

    def _is_safe_command(self, command: str) -> bool:
        """
        Check if a command is safe to execute.

        This is a best-effort deny-list, not a sandbox: shell=True commands can
        still chain (`;`, `&&`) or invoke interpreters.

        :param command: The command to check
        :return: True if the command is safe, False otherwise
        """
        cmd_parts = command.split()
        if not cmd_parts:
            return False

        # Strip any leading directory so `/bin/rm` is treated like `rm`.
        base_cmd = cmd_parts[0].lower().rsplit('/', 1)[-1]

        if base_cmd in self.command_ban_set:
            return False

        # Block privilege escalation anywhere in the command line.
        if any(banned in command.lower() for banned in ["sudo ", "su -"]):
            return False

        # FIX: the original used `"rm" in base_cmd`, which also rejected
        # unrelated commands whose names merely contain "rm" (e.g. "format").
        # Match the rm binary exactly and check its flags as whole tokens
        # (substring checks like `"-r" in command` matched long options too).
        if base_cmd == "rm":
            flags = {p.lower() for p in cmd_parts[1:]}
            if flags & {"-rf", "-fr", "-r", "-f"}:
                return False

        return True
"""Tool manager: discovers, configures, and instantiates agent tools."""

import importlib
import importlib.util
from pathlib import Path
from typing import Dict, Any, Type

from agent.tools.base_tool import BaseTool
from common.log import logger


class ToolManager:
    """
    Singleton registry of tools.

    Tool *classes* (not instances) are stored so each create_tool() call
    returns a fresh, independently-configured instance.
    """

    _instance = None

    def __new__(cls):
        """Singleton pattern to ensure only one instance of ToolManager exists."""
        if cls._instance is None:
            cls._instance = super(ToolManager, cls).__new__(cls)
            cls._instance.tool_classes = {}  # name -> tool class
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Initialize only once (singleton may be constructed repeatedly).
        if not hasattr(self, 'tool_classes'):
            self.tool_classes = {}

    def load_tools(self, tools_dir: str = "", config_dict=None):
        """
        Load tools from a directory (when given) or from agent.tools.__all__.

        :param tools_dir: Directory to scan for tool modules
        :param config_dict: Optional pre-parsed tools configuration
        """
        if tools_dir:
            self._load_tools_from_directory(tools_dir)
            self._configure_tools_from_config()
        else:
            self._load_tools_from_init()
            self._configure_tools_from_config(config_dict)

    def _register_tool_class(self, cls, origin: str):
        """
        Instantiate `cls` once to read its `name`, then register the class.
        (Extracted: this logic was duplicated in both loaders.)
        """
        try:
            tool_name = cls().name
            self.tool_classes[tool_name] = cls
            logger.debug(f"Loaded tool: {tool_name} from {origin}")
        except ImportError as e:
            # Missing optional browser_use dependency is expected; stay quiet.
            if "browser_use" not in str(e):
                logger.error(f"Error initializing tool class {cls.__name__}: {e}")
        except Exception as e:
            logger.error(f"Error initializing tool class {cls.__name__}: {e}")

    def _load_tools_from_init(self) -> bool:
        """
        Load tool classes listed in agent.tools.__all__.

        :return: True if at least one tool was registered
        """
        try:
            tools_package = importlib.import_module("agent.tools")
            if not hasattr(tools_package, "__all__"):
                return False

            for class_name in tools_package.__all__:
                try:
                    # Skip infrastructure classes.
                    if class_name in ("BaseTool", "ToolManager"):
                        continue
                    cls = getattr(tools_package, class_name, None)
                    if (
                        isinstance(cls, type)
                        and issubclass(cls, BaseTool)
                        and cls is not BaseTool
                    ):
                        self._register_tool_class(cls, f"class {class_name}")
                except Exception as e:
                    logger.error(f"Error importing class {class_name}: {e}")

            return len(self.tool_classes) > 0
        except ImportError:
            logger.warning("Could not import agent.tools package")
            return False
        except Exception as e:
            logger.error(f"Error loading tools from __init__.__all__: {e}")
            return False

    def _load_tools_from_directory(self, tools_dir: str):
        """Dynamically load tool classes from every .py file under tools_dir."""
        tools_path = Path(tools_dir)

        for py_file in tools_path.rglob("*.py"):
            # Skip package plumbing and framework files.
            if py_file.name in ("__init__.py", "base_tool.py", "tool_manager.py"):
                continue

            module_name = py_file.stem
            try:
                # Load the module directly from its file path.
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if spec and spec.loader:
                    module = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(module)

                    for attr_name in dir(module):
                        cls = getattr(module, attr_name)
                        if (
                            isinstance(cls, type)
                            and issubclass(cls, BaseTool)
                            and cls is not BaseTool
                        ):
                            # FIX: this path used bare print() for errors while
                            # the __all__-based loader used the logger; unified
                            # on logger via the shared helper.
                            self._register_tool_class(cls, str(py_file))
            except Exception as e:
                logger.error(f"Error importing module {py_file}: {e}")

    def _configure_tools_from_config(self, config_dict=None):
        """Store per-tool configs and warn about configured-but-unloaded tools."""
        try:
            if config_dict is not None:
                tools_config = config_dict
            else:
                # FIX: the original called `config()` without any import in this
                # module, raising NameError (silently logged below). Import the
                # project config accessor lazily instead.
                # NOTE(review): accessor name assumed to be `conf` from the root
                # config.py -- confirm against that module.
                from config import conf
                tools_config = conf().get("tools", {})

            # Keep configurations for later use when instantiating tools.
            self.tool_configs = tools_config

            for tool_name in tools_config:
                if tool_name in self.tool_classes:
                    continue
                if tool_name == "browser":
                    logger.error(
                        "Browser tool is configured but could not be loaded. "
                        "Please install the required dependency with: "
                        "pip install browser-use>=0.1.40 or pip install agentmesh-sdk[full]"
                    )
                else:
                    logger.warning(f"Tool '{tool_name}' is configured but could not be loaded.")

        except Exception as e:
            logger.error(f"Error configuring tools from config: {e}")

    def create_tool(self, name: str) -> BaseTool:
        """
        Get a new instance of a tool by name.

        :param name: The name of the tool to get.
        :return: A new configured instance of the tool, or None if not found.
        """
        tool_class = self.tool_classes.get(name)
        if tool_class is None:
            return None
        tool_instance = tool_class()
        # Apply stored configuration if available.
        if hasattr(self, 'tool_configs') and name in self.tool_configs:
            tool_instance.config = self.tool_configs[name]
        return tool_instance

    def list_tools(self) -> dict:
        """
        Get information about all loaded tools.

        :return: {tool_name: {"description": ..., "parameters": ...}}
        """
        result = {}
        for name, tool_class in self.tool_classes.items():
            # Temporary instance just to read description/schema.
            temp_instance = tool_class()
            result[name] = {
                "description": temp_instance.description,
                "parameters": temp_instance.get_json_schema()
            }
        return result
"""
Diff tools for file editing
Provides fuzzy matching and diff generation functionality
"""

import difflib
import re
from typing import Tuple

# Precompiled patterns, hoisted so repeated calls avoid per-call lookups.
_WS_RUN_RE = re.compile(r'[ \t]+')        # runs of spaces/tabs
_TRAILING_WS_RE = re.compile(r' +\n')     # trailing spaces before newline
_HUNK_HEADER_RE = re.compile(r'@@ -\d+,?\d* \+(\d+)')  # unified-diff hunk header


def strip_bom(text: str) -> Tuple[str, str]:
    """
    Remove BOM (Byte Order Mark).

    :param text: Original text
    :return: (BOM, text after removing BOM)
    """
    if text.startswith('\ufeff'):
        return '\ufeff', text[1:]
    return '', text


def detect_line_ending(text: str) -> str:
    """
    Detect line ending type.

    :param text: Text content
    :return: '\\r\\n' if any CRLF is present, otherwise '\\n'
    """
    if '\r\n' in text:
        return '\r\n'
    return '\n'


def normalize_to_lf(text: str) -> str:
    """
    Normalize all line endings (CRLF and lone CR) to LF.

    :param text: Original text
    :return: Normalized text
    """
    return text.replace('\r\n', '\n').replace('\r', '\n')


def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Restore original line endings on LF-normalized text.

    :param text: LF normalized text
    :param original_ending: Original line ending ('\\r\\n' or '\\n')
    :return: Text with restored line endings
    """
    if original_ending == '\r\n':
        return text.replace('\n', '\r\n')
    return text


def normalize_for_fuzzy_match(text: str) -> str:
    """
    Normalize text for fuzzy matching: collapse interior whitespace runs,
    drop trailing spaces, and re-normalize indentation to single spaces
    while preserving the indentation depth.

    :param text: Original text
    :return: Normalized text
    """
    # Compress runs of spaces/tabs to a single space.
    text = _WS_RUN_RE.sub(' ', text)
    # Remove trailing spaces before line breaks.
    text = _TRAILING_WS_RE.sub('\n', text)

    normalized_lines = []
    for line in text.split('\n'):
        stripped = line.lstrip()
        if stripped:
            # Keep the (post-compression) indent depth, expressed as spaces.
            indent_count = len(line) - len(stripped)
            normalized_lines.append(' ' * indent_count + stripped)
        else:
            normalized_lines.append('')
    return '\n'.join(normalized_lines)


class FuzzyMatchResult:
    """Result of fuzzy_find_text: where (and in which text) a match landed."""

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        self.found = found
        self.index = index                 # match start in content_for_replacement
        self.match_length = match_length   # length of the matched span
        # The text the caller should perform the replacement on: the original
        # content for an exact match, the normalized content for a fuzzy one.
        self.content_for_replacement = content_for_replacement

    def __repr__(self):
        return (f"{type(self).__name__}(found={self.found}, index={self.index}, "
                f"match_length={self.match_length})")


def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Find text in content: exact match first, then whitespace-fuzzy match.

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result
    """
    # Exact match against the original content.
    index = content.find(old_text)
    if index != -1:
        return FuzzyMatchResult(
            found=True,
            index=index,
            match_length=len(old_text),
            content_for_replacement=content
        )

    # Fuzzy match: compare whitespace-normalized forms of both sides.
    fuzzy_content = normalize_for_fuzzy_match(content)
    fuzzy_old_text = normalize_for_fuzzy_match(old_text)

    index = fuzzy_content.find(fuzzy_old_text)
    if index != -1:
        # NOTE: a fuzzy hit indexes into the *normalized* content, so the
        # caller must replace within that normalized text.
        return FuzzyMatchResult(
            found=True,
            index=index,
            match_length=len(fuzzy_old_text),
            content_for_replacement=fuzzy_content
        )

    return FuzzyMatchResult(found=False)


def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Generate a unified diff between two texts.

    :param old_content: Old content
    :param new_content: New content
    :return: {'diff': unified diff string,
              'first_changed_line': 1-indexed new-file line of first hunk, or None}
    """
    diff_lines = list(difflib.unified_diff(
        old_content.split('\n'),
        new_content.split('\n'),
        lineterm='',
        fromfile='original',
        tofile='modified'
    ))

    # Pull the new-file start line out of the first "@@ -a,b +c,d @@" header.
    first_changed_line = None
    for line in diff_lines:
        if line.startswith('@@'):
            match = _HUNK_HEADER_RE.search(line)
            if match:
                first_changed_line = int(match.group(1))
            break

    return {
        'diff': '\n'.join(diff_lines),
        'first_changed_line': first_changed_line
    }
"content": self.content, + "truncated": self.truncated, + "truncated_by": self.truncated_by, + "total_lines": self.total_lines, + "total_bytes": self.total_bytes, + "output_lines": self.output_lines, + "output_bytes": self.output_bytes, + "last_line_partial": self.last_line_partial, + "first_line_exceeds_limit": self.first_line_exceeds_limit, + "max_lines": self.max_lines, + "max_bytes": self.max_bytes + } + + +def format_size(bytes_count: int) -> str: + """Format bytes as human-readable size""" + if bytes_count < 1024: + return f"{bytes_count}B" + elif bytes_count < 1024 * 1024: + return f"{bytes_count / 1024:.1f}KB" + else: + return f"{bytes_count / (1024 * 1024):.1f}MB" + + +def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult: + """ + Truncate content from the head (keep first N lines/bytes). + Suitable for file reads where you want to see the beginning. + + Never returns partial lines. If first line exceeds byte limit, + returns empty content with first_line_exceeds_limit=True. 
+ + :param content: Content to truncate + :param max_lines: Maximum number of lines (default: 2000) + :param max_bytes: Maximum number of bytes (default: 50KB) + :return: Truncation result + """ + if max_lines is None: + max_lines = DEFAULT_MAX_LINES + if max_bytes is None: + max_bytes = DEFAULT_MAX_BYTES + + total_bytes = len(content.encode('utf-8')) + lines = content.split('\n') + total_lines = len(lines) + + # Check if no truncation is needed + if total_lines <= max_lines and total_bytes <= max_bytes: + return TruncationResult( + content=content, + truncated=False, + truncated_by=None, + total_lines=total_lines, + total_bytes=total_bytes, + output_lines=total_lines, + output_bytes=total_bytes, + last_line_partial=False, + first_line_exceeds_limit=False, + max_lines=max_lines, + max_bytes=max_bytes + ) + + # Check if first line alone exceeds byte limit + first_line_bytes = len(lines[0].encode('utf-8')) + if first_line_bytes > max_bytes: + return TruncationResult( + content="", + truncated=True, + truncated_by="bytes", + total_lines=total_lines, + total_bytes=total_bytes, + output_lines=0, + output_bytes=0, + last_line_partial=False, + first_line_exceeds_limit=True, + max_lines=max_lines, + max_bytes=max_bytes + ) + + # Collect complete lines that fit + output_lines_arr = [] + output_bytes_count = 0 + truncated_by = "lines" + + for i, line in enumerate(lines): + if i >= max_lines: + break + + # Calculate line bytes (add 1 for newline if not first line) + line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0) + + if output_bytes_count + line_bytes > max_bytes: + truncated_by = "bytes" + break + + output_lines_arr.append(line) + output_bytes_count += line_bytes + + # If exited due to line limit + if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes: + truncated_by = "lines" + + output_content = '\n'.join(output_lines_arr) + final_output_bytes = len(output_content.encode('utf-8')) + + return TruncationResult( + content=output_content, + 
truncated=True, + truncated_by=truncated_by, + total_lines=total_lines, + total_bytes=total_bytes, + output_lines=len(output_lines_arr), + output_bytes=final_output_bytes, + last_line_partial=False, + first_line_exceeds_limit=False, + max_lines=max_lines, + max_bytes=max_bytes + ) + + +def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult: + """ + Truncate content from tail (keep last N lines/bytes). + Suitable for bash output where you want to see the ending content (errors, final results). + + If the last line of original content exceeds byte limit, may return partial first line. + + :param content: Content to truncate + :param max_lines: Maximum lines (default: 2000) + :param max_bytes: Maximum bytes (default: 50KB) + :return: Truncation result + """ + if max_lines is None: + max_lines = DEFAULT_MAX_LINES + if max_bytes is None: + max_bytes = DEFAULT_MAX_BYTES + + total_bytes = len(content.encode('utf-8')) + lines = content.split('\n') + total_lines = len(lines) + + # Check if no truncation is needed + if total_lines <= max_lines and total_bytes <= max_bytes: + return TruncationResult( + content=content, + truncated=False, + truncated_by=None, + total_lines=total_lines, + total_bytes=total_bytes, + output_lines=total_lines, + output_bytes=total_bytes, + last_line_partial=False, + first_line_exceeds_limit=False, + max_lines=max_lines, + max_bytes=max_bytes + ) + + # Work backwards from the end + output_lines_arr = [] + output_bytes_count = 0 + truncated_by = "lines" + last_line_partial = False + + for i in range(len(lines) - 1, -1, -1): + if len(output_lines_arr) >= max_lines: + break + + line = lines[i] + # Calculate line bytes (add newline if not the first added line) + line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0) + + if output_bytes_count + line_bytes > max_bytes: + truncated_by = "bytes" + # Edge case: if we haven't added any lines yet and this line exceeds 
maxBytes, + # take the end portion of this line + if len(output_lines_arr) == 0: + truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes) + output_lines_arr.insert(0, truncated_line) + output_bytes_count = len(truncated_line.encode('utf-8')) + last_line_partial = True + break + + output_lines_arr.insert(0, line) + output_bytes_count += line_bytes + + # If exited due to line limit + if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes: + truncated_by = "lines" + + output_content = '\n'.join(output_lines_arr) + final_output_bytes = len(output_content.encode('utf-8')) + + return TruncationResult( + content=output_content, + truncated=True, + truncated_by=truncated_by, + total_lines=total_lines, + total_bytes=total_bytes, + output_lines=len(output_lines_arr), + output_bytes=final_output_bytes, + last_line_partial=last_line_partial, + first_line_exceeds_limit=False, + max_lines=max_lines, + max_bytes=max_bytes + ) + + +def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str: + """ + Truncate string to fit byte limit (from end). + Properly handles multi-byte UTF-8 characters. + + :param text: String to truncate + :param max_bytes: Maximum bytes + :return: Truncated string + """ + encoded = text.encode('utf-8') + if len(encoded) <= max_bytes: + return text + + # Start from end, skip back maxBytes + start = len(encoded) - max_bytes + + # Find valid UTF-8 boundary (character start) + while start < len(encoded) and (encoded[start] & 0xC0) == 0x80: + start += 1 + + return encoded[start:].decode('utf-8', errors='ignore') + + +def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]: + """ + Truncate single line to max characters, add [truncated] suffix. + Used for grep match lines. + + :param line: Line to truncate + :param max_chars: Maximum characters + :return: (truncated text, whether truncated) + """ + if len(line) <= max_chars: + return line, False + return f"{line[:max_chars]}... 
[truncated]", True diff --git a/agent/tools/write/__init__.py b/agent/tools/write/__init__.py new file mode 100644 index 0000000..b9cd426 --- /dev/null +++ b/agent/tools/write/__init__.py @@ -0,0 +1,3 @@ +from .write import Write + +__all__ = ['Write'] diff --git a/agent/tools/write/write.py b/agent/tools/write/write.py new file mode 100644 index 0000000..a246040 --- /dev/null +++ b/agent/tools/write/write.py @@ -0,0 +1,91 @@ +""" +Write tool - Write file content +Creates or overwrites files, automatically creates parent directories +""" + +import os +from typing import Dict, Any +from pathlib import Path + +from agent.tools.base_tool import BaseTool, ToolResult + + +class Write(BaseTool): + """Tool for writing file content""" + + name: str = "write" + description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories." + + params: dict = { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to write (relative or absolute)" + }, + "content": { + "type": "string", + "description": "Content to write to the file" + } + }, + "required": ["path", "content"] + } + + def __init__(self, config: dict = None): + self.config = config or {} + self.cwd = self.config.get("cwd", os.getcwd()) + + def execute(self, args: Dict[str, Any]) -> ToolResult: + """ + Execute file write operation + + :param args: Contains file path and content + :return: Operation result + """ + path = args.get("path", "").strip() + content = args.get("content", "") + + if not path: + return ToolResult.fail("Error: path parameter is required") + + # Resolve path + absolute_path = self._resolve_path(path) + + try: + # Create parent directory (if needed) + parent_dir = os.path.dirname(absolute_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + + # Write file + with open(absolute_path, 'w', encoding='utf-8') as f: + f.write(content) + + # Get bytes written + 
bytes_written = len(content.encode('utf-8')) + + result = { + "message": f"Successfully wrote {bytes_written} bytes to {path}", + "path": path, + "bytes_written": bytes_written + } + + return ToolResult.success(result) + + except PermissionError: + return ToolResult.fail(f"Error: Permission denied writing to {path}") + except Exception as e: + return ToolResult.fail(f"Error writing file: {str(e)}") + + def _resolve_path(self, path: str) -> str: + """ + Resolve path to absolute path + + :param path: Relative or absolute path + :return: Absolute path + """ + # Expand ~ to user home directory + path = os.path.expanduser(path) + if os.path.isabs(path): + return path + return os.path.abspath(os.path.join(self.cwd, path)) diff --git a/bot/claude/claude_ai_bot.py b/bot/claude/claude_ai_bot.py deleted file mode 100644 index faad274..0000000 --- a/bot/claude/claude_ai_bot.py +++ /dev/null @@ -1,222 +0,0 @@ -import re -import time -import json -import uuid -from curl_cffi import requests -from bot.bot import Bot -from bot.claude.claude_ai_session import ClaudeAiSession -from bot.openai.open_ai_image import OpenAIImage -from bot.session_manager import SessionManager -from bridge.context import Context, ContextType -from bridge.reply import Reply, ReplyType -from common.log import logger -from config import conf - - -class ClaudeAIBot(Bot, OpenAIImage): - def __init__(self): - super().__init__() - self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo") - self.claude_api_cookie = conf().get("claude_api_cookie") - self.proxy = conf().get("proxy") - self.con_uuid_dic = {} - if self.proxy: - self.proxies = { - "http": self.proxy, - "https": self.proxy - } - else: - self.proxies = None - self.error = "" - self.org_uuid = self.get_organization_id() - - def generate_uuid(self): - random_uuid = uuid.uuid4() - random_uuid_str = str(random_uuid) - formatted_uuid = 
f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}" - return formatted_uuid - - def reply(self, query, context: Context = None) -> Reply: - if context.type == ContextType.TEXT: - return self._chat(query, context) - elif context.type == ContextType.IMAGE_CREATE: - ok, res = self.create_img(query, 0) - if ok: - reply = Reply(ReplyType.IMAGE_URL, res) - else: - reply = Reply(ReplyType.ERROR, res) - return reply - else: - reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) - return reply - - def get_organization_id(self): - url = "https://claude.ai/api/organizations" - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0', - 'Accept-Language': 'en-US,en;q=0.5', - 'Referer': 'https://claude.ai/chats', - 'Content-Type': 'application/json', - 'Sec-Fetch-Dest': 'empty', - 'Sec-Fetch-Mode': 'cors', - 'Sec-Fetch-Site': 'same-origin', - 'Connection': 'keep-alive', - 'Cookie': f'{self.claude_api_cookie}' - } - try: - response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400) - res = json.loads(response.text) - uuid = res[0]['uuid'] - except: - if "App unavailable" in response.text: - logger.error("IP error: The IP is not allowed to be used on Claude") - self.error = "ip所在地区不被claude支持" - elif "Invalid authorization" in response.text: - logger.error("Cookie error: Invalid authorization of claude, check cookie please.") - self.error = "无法通过claude身份验证,请检查cookie" - return None - return uuid - - def conversation_share_check(self,session_id): - if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "": - con_uuid = conf().get("claude_uuid") - return con_uuid - if session_id not in self.con_uuid_dic: - self.con_uuid_dic[session_id] = self.generate_uuid() - self.create_new_chat(self.con_uuid_dic[session_id]) - return self.con_uuid_dic[session_id] - - def check_cookie(self): 
- flag = self.get_organization_id() - return flag - - def create_new_chat(self, con_uuid): - """ - 新建claude对话实体 - :param con_uuid: 对话id - :return: - """ - url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations" - payload = json.dumps({"uuid": con_uuid, "name": ""}) - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0', - 'Accept-Language': 'en-US,en;q=0.5', - 'Referer': 'https://claude.ai/chats', - 'Content-Type': 'application/json', - 'Origin': 'https://claude.ai', - 'DNT': '1', - 'Connection': 'keep-alive', - 'Cookie': self.claude_api_cookie, - 'Sec-Fetch-Dest': 'empty', - 'Sec-Fetch-Mode': 'cors', - 'Sec-Fetch-Site': 'same-origin', - 'TE': 'trailers' - } - response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400) - # Returns JSON of the newly created conversation information - return response.json() - - def _chat(self, query, context, retry_count=0) -> Reply: - """ - 发起对话请求 - :param query: 请求提示词 - :param context: 对话上下文 - :param retry_count: 当前递归重试次数 - :return: 回复 - """ - if retry_count >= 2: - # exit from retry 2 times - logger.warn("[CLAUDEAI] failed after maximum number of retry times") - return Reply(ReplyType.ERROR, "请再问我一次吧") - - try: - session_id = context["session_id"] - if self.org_uuid is None: - return Reply(ReplyType.ERROR, self.error) - - session = self.sessions.session_query(query, session_id) - con_uuid = self.conversation_share_check(session_id) - - model = conf().get("model") or "gpt-3.5-turbo" - # remove system message - if session.messages[0].get("role") == "system": - if model == "wenxin" or model == "claude": - session.messages.pop(0) - logger.info(f"[CLAUDEAI] query={query}") - - # do http request - base_url = "https://claude.ai" - payload = json.dumps({ - "completion": { - "prompt": f"{query}", - "timezone": "Asia/Kolkata", - "model": "claude-2" - }, - "organization_uuid": f"{self.org_uuid}", - 
"conversation_uuid": f"{con_uuid}", - "text": f"{query}", - "attachments": [] - }) - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0', - 'Accept': 'text/event-stream, text/event-stream', - 'Accept-Language': 'en-US,en;q=0.5', - 'Referer': 'https://claude.ai/chats', - 'Content-Type': 'application/json', - 'Origin': 'https://claude.ai', - 'DNT': '1', - 'Connection': 'keep-alive', - 'Cookie': f'{self.claude_api_cookie}', - 'Sec-Fetch-Dest': 'empty', - 'Sec-Fetch-Mode': 'cors', - 'Sec-Fetch-Site': 'same-origin', - 'TE': 'trailers' - } - - res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400) - if res.status_code == 200 or "pemission" in res.text: - # execute success - decoded_data = res.content.decode("utf-8") - decoded_data = re.sub('\n+', '\n', decoded_data).strip() - data_strings = decoded_data.split('\n') - completions = [] - for data_string in data_strings: - json_str = data_string[6:].strip() - data = json.loads(json_str) - if 'completion' in data: - completions.append(data['completion']) - - reply_content = ''.join(completions) - - if "rate limi" in reply_content: - logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. 
Please go to the official website to check the lifting time") - return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间") - logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible") - self.sessions.session_reply(reply_content, session_id, 100) - return Reply(ReplyType.TEXT, reply_content) - else: - flag = self.check_cookie() - if flag == None: - return Reply(ReplyType.ERROR, self.error) - - response = res.json() - error = response.get("error") - logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, " - f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}") - - if res.status_code >= 500: - # server error, need retry - time.sleep(2) - logger.warn(f"[CLAUDE] do retry, times={retry_count}") - return self._chat(query, context, retry_count + 1) - return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") - - except Exception as e: - logger.exception(e) - # retry - time.sleep(2) - logger.warn(f"[CLAUDE] do retry, times={retry_count}") - return self._chat(query, context, retry_count + 1) diff --git a/bot/claude/claude_ai_session.py b/bot/claude/claude_ai_session.py deleted file mode 100644 index ede9e51..0000000 --- a/bot/claude/claude_ai_session.py +++ /dev/null @@ -1,9 +0,0 @@ -from bot.session_manager import Session - - -class ClaudeAiSession(Session): - def __init__(self, session_id, system_prompt=None, model="claude"): - super().__init__(session_id, system_prompt) - self.model = model - # claude逆向不支持role prompt - # self.reset() diff --git a/bot/claudeapi/claude_api_bot.py b/bot/claudeapi/claude_api_bot.py index e3e7a2b..84be8d4 100644 --- a/bot/claudeapi/claude_api_bot.py +++ b/bot/claudeapi/claude_api_bot.py @@ -1,19 +1,18 @@ # encoding:utf-8 +import json import time -import openai -import openai.error -import anthropic +import requests +from bot.baidu.baidu_wenxin_session import BaiduWenxinSession from bot.bot import Bot from bot.openai.open_ai_image import OpenAIImage -from 
bot.baidu.baidu_wenxin_session import BaiduWenxinSession from bot.session_manager import SessionManager from bridge.context import ContextType from bridge.reply import Reply, ReplyType -from common.log import logger from common import const +from common.log import logger from config import conf user_session = dict() @@ -23,13 +22,9 @@ user_session = dict() class ClaudeAPIBot(Bot, OpenAIImage): def __init__(self): super().__init__() - proxy = conf().get("proxy", None) - base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url - self.claudeClient = anthropic.Anthropic( - api_key=conf().get("claude_api_key"), - proxies=proxy if proxy else None, - base_url=base_url if base_url else None - ) + self.api_key = conf().get("claude_api_key") + self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1" + self.proxy = conf().get("proxy", None) self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003") def reply(self, query, context=None): @@ -73,39 +68,104 @@ class ClaudeAPIBot(Bot, OpenAIImage): reply = Reply(ReplyType.ERROR, retstring) return reply - def reply_text(self, session: BaiduWenxinSession, retry_count=0): + def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None): try: actual_model = self._model_mapping(conf().get("model")) - response = self.claudeClient.messages.create( - model=actual_model, - max_tokens=4096, - system=conf().get("character_desc", ""), - messages=session.messages + + # Prepare headers + headers = { + "x-api-key": self.api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json" + } + + # Extract system prompt if present and prepare Claude-compatible messages + system_prompt = conf().get("character_desc", "") + claude_messages = [] + + for msg in session.messages: + if msg.get("role") == "system": + system_prompt = msg["content"] + else: + claude_messages.append(msg) + + # Prepare request data + data = { + "model": 
actual_model, + "messages": claude_messages, + "max_tokens": self._get_max_tokens(actual_model) + } + + if system_prompt: + data["system"] = system_prompt + + if tools: + data["tools"] = tools + + # Make HTTP request + proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None + response = requests.post( + f"{self.api_base}/messages", + headers=headers, + json=data, + proxies=proxies ) - # response = openai.Completion.create(prompt=str(session), **self.args) - res_content = response.content[0].text.strip().replace("<|endoftext|>", "") - total_tokens = response.usage.input_tokens+response.usage.output_tokens - completion_tokens = response.usage.output_tokens + + if response.status_code != 200: + raise Exception(f"API request failed: {response.status_code} - {response.text}") + + claude_response = response.json() + # Handle response content and tool calls + res_content = "" + tool_calls = [] + + content_blocks = claude_response.get("content", []) + for block in content_blocks: + if block.get("type") == "text": + res_content += block.get("text", "") + elif block.get("type") == "tool_use": + tool_calls.append({ + "id": block.get("id", ""), + "name": block.get("name", ""), + "arguments": block.get("input", {}) + }) + + res_content = res_content.strip().replace("<|endoftext|>", "") + usage = claude_response.get("usage", {}) + total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0) + completion_tokens = usage.get("output_tokens", 0) + logger.info("[CLAUDE_API] reply={}".format(res_content)) - return { + if tool_calls: + logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls)) + + result = { "total_tokens": total_tokens, "completion_tokens": completion_tokens, "content": res_content, } + + if tool_calls: + result["tool_calls"] = tool_calls + + return result except Exception as e: need_retry = retry_count < 2 result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, 
openai.error.RateLimitError): + + # Handle different types of errors + error_str = str(e).lower() + if "rate" in error_str or "limit" in error_str: logger.warn("[CLAUDE_API] RateLimitError: {}".format(e)) result["content"] = "提问太快啦,请休息一下再问我吧" if need_retry: time.sleep(20) - elif isinstance(e, openai.error.Timeout): + elif "timeout" in error_str: logger.warn("[CLAUDE_API] Timeout: {}".format(e)) result["content"] = "我没有收到你的消息" if need_retry: time.sleep(5) - elif isinstance(e, openai.error.APIConnectionError): + elif "connection" in error_str or "network" in error_str: logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e)) need_retry = False result["content"] = "我连接不到你的网络" @@ -116,7 +176,7 @@ class ClaudeAPIBot(Bot, OpenAIImage): if need_retry: logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, retry_count + 1) + return self.reply_text(session, retry_count + 1, tools) else: return result @@ -130,3 +190,288 @@ class ClaudeAPIBot(Bot, OpenAIImage): elif model == "claude-3.5-sonnet": return const.CLAUDE_35_SONNET return model + + def _get_max_tokens(self, model: str) -> int: + """ + Get max_tokens for the model. 
+ Reference from pi-mono: + - Claude 3.5/3.7: 8192 + - Claude 3 Opus: 4096 + - Default: 8192 + """ + if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")): + return 8192 + elif model and model.startswith("claude-3") and "opus" in model: + return 4096 + elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")): + return 64000 + return 8192 + + def call_with_tools(self, messages, tools=None, stream=False, **kwargs): + """ + Call Claude API with tool support for agent integration + + Args: + messages: List of messages + tools: List of tool definitions + stream: Whether to use streaming + **kwargs: Additional parameters + + Returns: + Formatted response compatible with OpenAI format or generator for streaming + """ + actual_model = self._model_mapping(conf().get("model")) + + # Extract system prompt from messages if present + system_prompt = kwargs.get("system", conf().get("character_desc", "")) + claude_messages = [] + + for msg in messages: + if msg.get("role") == "system": + system_prompt = msg["content"] + else: + claude_messages.append(msg) + + request_params = { + "model": actual_model, + "max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)), + "messages": claude_messages, + "stream": stream + } + + if system_prompt: + request_params["system"] = system_prompt + + if tools: + request_params["tools"] = tools + + try: + if stream: + return self._handle_stream_response(request_params) + else: + return self._handle_sync_response(request_params) + except Exception as e: + logger.error(f"Claude API call error: {e}") + if stream: + # Return error generator for stream + def error_generator(): + yield { + "error": True, + "message": str(e), + "status_code": 500 + } + + return error_generator() + else: + # Return error response for sync + return { + "error": True, + "message": str(e), + "status_code": 500 + } + + def _handle_sync_response(self, request_params): + """Handle synchronous Claude API 
response""" + # Prepare headers + headers = { + "x-api-key": self.api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json" + } + + # Make HTTP request + proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None + response = requests.post( + f"{self.api_base}/messages", + headers=headers, + json=request_params, + proxies=proxies + ) + + if response.status_code != 200: + raise Exception(f"API request failed: {response.status_code} - {response.text}") + + claude_response = response.json() + + # Extract content blocks + text_content = "" + tool_calls = [] + + content_blocks = claude_response.get("content", []) + for block in content_blocks: + if block.get("type") == "text": + text_content += block.get("text", "") + elif block.get("type") == "tool_use": + tool_calls.append({ + "id": block.get("id", ""), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": json.dumps(block.get("input", {})) + } + }) + + # Build message in OpenAI format + message = { + "role": "assistant", + "content": text_content + } + if tool_calls: + message["tool_calls"] = tool_calls + + # Format response to match OpenAI structure + usage = claude_response.get("usage", {}) + formatted_response = { + "id": claude_response.get("id", ""), + "object": "chat.completion", + "created": int(time.time()), + "model": claude_response.get("model", request_params["model"]), + "choices": [ + { + "index": 0, + "message": message, + "finish_reason": claude_response.get("stop_reason", "stop") + } + ], + "usage": { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0) + } + } + + return formatted_response + + def _handle_stream_response(self, request_params): + """Handle streaming Claude API response using HTTP requests""" + # Prepare headers + headers = { + "x-api-key": self.api_key, + "anthropic-version": 
"2023-06-01", + "content-type": "application/json" + } + + # Add stream parameter + request_params["stream"] = True + + # Track tool use state + tool_uses_map = {} # {index: {id, name, input}} + current_tool_use_index = -1 + + try: + # Make streaming HTTP request + proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None + response = requests.post( + f"{self.api_base}/messages", + headers=headers, + json=request_params, + proxies=proxies, + stream=True + ) + + if response.status_code != 200: + error_text = response.text + try: + error_data = json.loads(error_text) + error_msg = error_data.get("error", {}).get("message", error_text) + except: + error_msg = error_text or "Unknown error" + + yield { + "error": True, + "status_code": response.status_code, + "message": error_msg + } + return + + # Process streaming response + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + line = line[6:] # Remove 'data: ' prefix + if line == '[DONE]': + break + try: + event = json.loads(line) + event_type = event.get("type") + + if event_type == "content_block_start": + # New content block + block = event.get("content_block", {}) + if block.get("type") == "tool_use": + current_tool_use_index = event.get("index", 0) + tool_uses_map[current_tool_use_index] = { + "id": block.get("id", ""), + "name": block.get("name", ""), + "input": "" + } + + elif event_type == "content_block_delta": + delta = event.get("delta", {}) + delta_type = delta.get("type") + + if delta_type == "text_delta": + # Text content + content = delta.get("text", "") + yield { + "id": event.get("id", ""), + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request_params["model"], + "choices": [{ + "index": 0, + "delta": {"content": content}, + "finish_reason": None + }] + } + + elif delta_type == "input_json_delta": + # Tool input accumulation + if current_tool_use_index >= 0: + 
tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "") + + elif event_type == "message_delta": + # Message complete - yield tool calls if any + if tool_uses_map: + for idx in sorted(tool_uses_map.keys()): + tool_data = tool_uses_map[idx] + yield { + "id": event.get("id", ""), + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request_params["model"], + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": idx, + "id": tool_data["id"], + "type": "function", + "function": { + "name": tool_data["name"], + "arguments": tool_data["input"] + } + }] + }, + "finish_reason": None + }] + } + + except json.JSONDecodeError: + continue + + except requests.RequestException as e: + logger.error(f"Claude streaming request error: {e}") + yield { + "error": True, + "message": f"Connection error: {str(e)}", + "status_code": 0 + } + except Exception as e: + logger.error(f"Claude streaming error: {e}") + yield { + "error": True, + "message": str(e), + "status_code": 500 + } diff --git a/bridge/agent_bridge.py b/bridge/agent_bridge.py new file mode 100644 index 0000000..d9355ec --- /dev/null +++ b/bridge/agent_bridge.py @@ -0,0 +1,288 @@ +""" +Agent Bridge - Integrates Agent system with existing COW bridge +""" + +from typing import Optional, List + +from agent.protocol import Agent, LLMModel, LLMRequest +from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls +from bridge.bridge import Bridge +from bridge.context import Context +from bridge.reply import Reply, ReplyType +from common import const +from common.log import logger + + +class AgentLLMModel(LLMModel): + """ + LLM Model adapter that uses COW's existing bot infrastructure + """ + + def __init__(self, bridge: Bridge, bot_type: str = "chat"): + # Get model name directly from config + from config import conf + model_name = conf().get("model", const.GPT_41) + super().__init__(model=model_name) + self.bridge = bridge + 
self.bot_type = bot_type + self._bot = None + + @property + def bot(self): + """Lazy load the bot""" + if self._bot is None: + self._bot = self.bridge.get_bot(self.bot_type) + return self._bot + + def call(self, request: LLMRequest): + """ + Call the model using COW's bot infrastructure + """ + try: + # For non-streaming calls, we'll use the existing reply method + # This is a simplified implementation + if hasattr(self.bot, 'call_with_tools'): + # Use tool-enabled call if available + kwargs = { + 'messages': request.messages, + 'tools': getattr(request, 'tools', None), + 'stream': False + } + # Only pass max_tokens if it's explicitly set + if request.max_tokens is not None: + kwargs['max_tokens'] = request.max_tokens + response = self.bot.call_with_tools(**kwargs) + return self._format_response(response) + else: + # Fallback to regular call + # This would need to be implemented based on your specific needs + raise NotImplementedError("Regular call not implemented yet") + + except Exception as e: + logger.error(f"AgentLLMModel call error: {e}") + raise + + def call_stream(self, request: LLMRequest): + """ + Call the model with streaming using COW's bot infrastructure + """ + try: + if hasattr(self.bot, 'call_with_tools'): + # Use tool-enabled streaming call if available + # Ensure max_tokens is an integer, use default if None + max_tokens = request.max_tokens if request.max_tokens is not None else 4096 + + # Extract system prompt if present + system_prompt = getattr(request, 'system', None) + + # Build kwargs for call_with_tools + kwargs = { + 'messages': request.messages, + 'tools': getattr(request, 'tools', None), + 'stream': True, + 'max_tokens': max_tokens + } + + # Add system prompt if present + if system_prompt: + kwargs['system'] = system_prompt + + stream = self.bot.call_with_tools(**kwargs) + + # Convert Claude stream format to our expected format + for chunk in stream: + yield self._format_stream_chunk(chunk) + else: + raise NotImplementedError("Streaming 
call not implemented yet") + + except Exception as e: + logger.error(f"AgentLLMModel call_stream error: {e}") + raise + + def _format_response(self, response): + """Format Claude response to our expected format""" + # This would need to be implemented based on Claude's response format + return response + + def _format_stream_chunk(self, chunk): + """Format Claude stream chunk to our expected format""" + # This would need to be implemented based on Claude's stream format + return chunk + + +class AgentBridge: + """ + Bridge class that integrates single super Agent with COW + """ + + def __init__(self, bridge: Bridge): + self.bridge = bridge + self.agent: Optional[Agent] = None + + def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent: + """ + Create the super agent with COW integration + + Args: + system_prompt: System prompt + tools: List of tools (optional) + **kwargs: Additional agent parameters + + Returns: + Agent instance + """ + # Create LLM model that uses COW's bot infrastructure + model = AgentLLMModel(self.bridge) + + # Default tools if none provided + if tools is None: + tools = [ + Calculator(), + CurrentTime(), + Read(), + Write(), + Edit(), + Bash(), + Grep(), + Find(), + Ls() + ] + + # Create the single super agent + self.agent = Agent( + system_prompt=system_prompt, + description=kwargs.get("description", "AI Super Agent"), + model=model, + tools=tools, + max_steps=kwargs.get("max_steps", 15), + output_mode=kwargs.get("output_mode", "logger") + ) + + return self.agent + + def get_agent(self) -> Optional[Agent]: + """Get the super agent, create if not exists""" + if self.agent is None: + self._init_default_agent() + return self.agent + + def _init_default_agent(self): + """Initialize default super agent with config and memory""" + from config import conf + import os + + # Get base system prompt from config + base_prompt = conf().get("character_desc", "你是一个AI助手") + + # Setup memory if enabled + memory_manager = None + 
memory_tools = [] + + try: + # Try to initialize memory system + from agent.memory import MemoryManager, MemoryConfig + from agent.tools import MemorySearchTool, MemoryGetTool + + # Create memory config directly with sensible defaults + workspace_root = os.path.expanduser("~/cow") + memory_config = MemoryConfig( + workspace_root=workspace_root, + embedding_provider="local", # Use local embedding (no API key needed) + embedding_model="all-MiniLM-L6-v2" + ) + + # Create memory manager with the config + memory_manager = MemoryManager(memory_config) + + # Create memory tools + memory_tools = [ + MemorySearchTool(memory_manager), + MemoryGetTool(memory_manager) + ] + + # Build memory guidance and add to system prompt + memory_guidance = memory_manager.build_memory_guidance( + lang="zh", + include_context=True + ) + system_prompt = base_prompt + "\n\n" + memory_guidance + + logger.info(f"[AgentBridge] Memory system initialized") + logger.info(f"[AgentBridge] Workspace: {memory_config.get_workspace()}") + + except Exception as e: + logger.warning(f"[AgentBridge] Memory system not available: {e}") + logger.info("[AgentBridge] Continuing without memory features") + system_prompt = base_prompt + import traceback + traceback.print_exc() + + logger.info("[AgentBridge] Initializing super agent") + + # Configure file tools to work in the correct workspace + file_config = {"cwd": workspace_root} if memory_manager else {} + + # Create default tools with workspace config + from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls + tools = [ + Calculator(), + CurrentTime(), + Read(config=file_config), + Write(config=file_config), + Edit(config=file_config), + Bash(config=file_config), + Grep(config=file_config), + Find(config=file_config), + Ls(config=file_config) + ] + + # Create agent with configured tools + agent = self.create_agent( + system_prompt=system_prompt, + tools=tools, + max_steps=50, + output_mode="logger" + ) + + # Attach memory manager 
to agent if available + if memory_manager: + agent.memory_manager = memory_manager + + # Add memory tools if available + if memory_tools: + for tool in memory_tools: + agent.add_tool(tool) + logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools") + + def agent_reply(self, query: str, context: Context = None, + on_event=None, clear_history: bool = False) -> Reply: + """ + Use super agent to reply to a query + + Args: + query: User query + context: COW context (optional) + on_event: Event callback (optional) + clear_history: Whether to clear conversation history + + Returns: + Reply object + """ + try: + # Get agent (will auto-initialize if needed) + agent = self.get_agent() + if not agent: + return Reply(ReplyType.ERROR, "Failed to initialize super agent") + + # Use agent's run_stream method + response = agent.run_stream( + user_message=query, + on_event=on_event, + clear_history=clear_history + ) + + return Reply(ReplyType.TEXT, response) + + except Exception as e: + logger.error(f"Agent reply error: {e}") + return Reply(ReplyType.ERROR, f"Agent error: {str(e)}") \ No newline at end of file diff --git a/bridge/bridge.py b/bridge/bridge.py index e556245..a7b93c4 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -23,7 +23,7 @@ class Bridge(object): if bot_type: self.btype["chat"] = bot_type else: - model_type = conf().get("model") or const.GPT35 + model_type = conf().get("model") or const.GPT_41_MINI if model_type in ["text-davinci-003"]: self.btype["chat"] = const.OPEN_AI if conf().get("use_azure_chatgpt", False): @@ -64,6 +64,7 @@ class Bridge(object): self.bots = {} self.chat_bots = {} + self._agent_bridge = None # 模型对应的接口 def get_bot(self, typename): @@ -104,3 +105,29 @@ class Bridge(object): 重置bot路由 """ self.__init__() + + def get_agent_bridge(self): + """ + Get agent bridge for agent-based conversations + """ + if self._agent_bridge is None: + from bridge.agent_bridge import AgentBridge + self._agent_bridge = AgentBridge(self) + return 
self._agent_bridge + + def fetch_agent_reply(self, query: str, context: Context = None, + on_event=None, clear_history: bool = False) -> Reply: + """ + Use super agent to handle the query + + Args: + query: User query + context: Context object + on_event: Event callback for streaming + clear_history: Whether to clear conversation history + + Returns: + Reply object + """ + agent_bridge = self.get_agent_bridge() + return agent_bridge.agent_reply(query, context, on_event, clear_history) diff --git a/channel/channel.py b/channel/channel.py index c225342..7f043e5 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -5,6 +5,8 @@ Message sending channel abstract class from bridge.bridge import Bridge from bridge.context import Context from bridge.reply import * +from common.log import logger +from config import conf class Channel(object): @@ -35,7 +37,30 @@ class Channel(object): raise NotImplementedError def build_reply_content(self, query, context: Context = None) -> Reply: - return Bridge().fetch_reply_content(query, context) + """ + Build reply content, using agent if enabled in config + """ + # Check if agent mode is enabled + use_agent = conf().get("agent", False) + + if use_agent: + try: + logger.info("[Channel] Using agent mode") + + # Use agent bridge to handle the query + return Bridge().fetch_agent_reply( + query=query, + context=context, + on_event=None, + clear_history=False + ) + except Exception as e: + logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}") + # Fallback to normal mode if agent fails + return Bridge().fetch_reply_content(query, context) + else: + # Normal mode + return Bridge().fetch_reply_content(query, context) def build_voice_to_text(self, voice_file) -> Reply: return Bridge().fetch_voice_to_text(voice_file) diff --git a/channel/web/web_channel.py b/channel/web/web_channel.py index 7721728..5c8043f 100644 --- a/channel/web/web_channel.py +++ b/channel/web/web_channel.py @@ -150,9 +150,6 @@ class 
WebChannel(ChatChannel): Poll for responses using the session_id. """ try: - # 不记录轮询请求的日志 - web.ctx.log_request = False - data = web.data() json_data = json.loads(data) session_id = json_data.get('session_id') @@ -215,19 +212,20 @@ class WebChannel(ChatChannel): ) app = web.application(urls, globals(), autoreload=False) - # 禁用web.py的默认日志输出 - import io - from contextlib import redirect_stdout + # 完全禁用web.py的HTTP日志输出 + # 创建一个空的日志处理函数 + def null_log_function(status, environ): + pass - # 配置web.py的日志级别为ERROR,只显示错误 + # 替换web.py的日志函数 + web.httpserver.LogMiddleware.log = lambda self, status, environ: None + + # 配置web.py的日志级别为ERROR logging.getLogger("web").setLevel(logging.ERROR) - - # 禁用web.httpserver的日志 logging.getLogger("web.httpserver").setLevel(logging.ERROR) - # 临时重定向标准输出,捕获web.py的启动消息 - with redirect_stdout(io.StringIO()): - web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) + # 启动服务器 + web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) class RootHandler: diff --git a/config-template.json b/config-template.json index 476a5e0..001d2ca 100644 --- a/config-template.json +++ b/config-template.json @@ -3,6 +3,7 @@ "model": "", "open_ai_api_key": "YOUR API KEY", "claude_api_key": "YOUR API KEY", + "claude_api_base": "https://api.anthropic.com", "text_to_image": "dall-e-2", "voice_to_text": "openai", "text_to_voice": "openai", @@ -30,8 +31,9 @@ "expires_in_seconds": 3600, "character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", "temperature": 0.7, - "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。", + "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。", "use_linkai": false, "linkai_api_key": "", - "linkai_app_code": "" + "linkai_app_code": "", + "agent": true } diff --git a/config.py b/config.py index a02bfae..8b60285 100644 --- a/config.py +++ b/config.py 
@@ -1,10 +1,10 @@ # encoding:utf-8 +import copy import json import logging import os import pickle -import copy from common.log import logger @@ -183,6 +183,7 @@ available_setting = { "Minimax_group_id": "", "Minimax_base_url": "", "web_port": 9899, + "agent": False # 是否开启Agent模式 } diff --git a/memory/2026-01-29.md b/memory/2026-01-29.md new file mode 100644 index 0000000..44c18e3 --- /dev/null +++ b/memory/2026-01-29.md @@ -0,0 +1,5 @@ +# 2026-01-29 记录 + +## 老王的重要决定 +- 今天老王告诉我他决定要学AI了,这是一个重要的决策 +- 这可能会是他学习和职业发展的一个转折点 \ No newline at end of file diff --git a/memory/MEMORY.md b/memory/MEMORY.md new file mode 100644 index 0000000..d80e5cd --- /dev/null +++ b/memory/MEMORY.md @@ -0,0 +1,21 @@ +# Memory + +Long-term curated memories and preferences. + +## 用户信息 +- 用户名:老王 + +## 用户信息 +- 用户名:老王 + +## 用户偏好 +- 喜欢吃红烧肉 +- 爱打篮球 + +## 重要决策 +- 决定要学习AI(2026-01-29) + +## Notes + +- Important decisions and facts go here +- This is your long-term knowledge base \ No newline at end of file