Mirror of https://github.com/zhayujie/chatgpt-on-wechat.git (synced 2026-02-16 08:16:06 +08:00)
Merge pull request #2652 from zhayujie/feat-cow-agent
feat: cow super agent
@@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword)
 from agent.memory.manager import MemoryManager
 from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
+from agent.memory.embedding import create_embedding_provider

-__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
+__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']

@@ -41,6 +41,10 @@ class MemoryConfig:
     enable_auto_sync: bool = True
     sync_on_search: bool = True

+    # Memory flush config (independent of the model's context window)
+    flush_token_threshold: int = 50000  # flush triggers at 50K tokens
+    flush_turn_threshold: int = 20  # flush triggers after 20 turns (one user message + one AI reply = one turn)
+
     def get_workspace(self) -> Path:
         """Get workspace root directory"""
         return Path(self.workspace_root)
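The two thresholds above are plain dataclass fields, so they can be tuned per deployment. A minimal sketch (field names come from the hunk above; the rest of MemoryConfig, including any required fields, is assumed to have defaults):

    from agent.memory.config import MemoryConfig, set_global_memory_config

    # Hypothetical tuning: flush memory earlier for very chatty sessions.
    config = MemoryConfig(
        flush_token_threshold=30000,  # instead of the 50K default
        flush_turn_threshold=10,      # instead of the 20-turn default
    )
    set_global_memory_config(config)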
@@ -4,20 +4,19 @@ Embedding providers for memory
 Supports OpenAI and local embedding models
 """

-from typing import List, Optional
-from abc import ABC, abstractmethod
+import hashlib
+import json
+from abc import ABC, abstractmethod
+from typing import List, Optional


 class EmbeddingProvider(ABC):
     """Base class for embedding providers"""

     @abstractmethod
     def embed(self, text: str) -> List[float]:
         """Generate embedding for text"""
         pass

     @abstractmethod
     def embed_batch(self, texts: List[str]) -> List[List[float]]:
         """Generate embeddings for multiple texts"""

@@ -31,7 +30,7 @@ class EmbeddingProvider(ABC):


 class OpenAIEmbeddingProvider(EmbeddingProvider):
-    """OpenAI embedding provider"""
+    """OpenAI embedding provider using REST API"""

     def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
         """

@@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
         self.model = model
         self.api_key = api_key
         self.api_base = api_base or "https://api.openai.com/v1"

-        # Lazy import to avoid dependency issues
-        try:
-            from openai import OpenAI
-            self.client = OpenAI(api_key=api_key, base_url=api_base)
-        except ImportError:
-            raise ImportError("OpenAI package not installed. Install with: pip install openai")
+        if not self.api_key:
+            raise ValueError("OpenAI API key is required")

         # Set dimensions based on model
         self._dimensions = 1536 if "small" in model else 3072

+    def _call_api(self, input_data):
+        """Call OpenAI embedding API using requests"""
+        import requests
+
+        url = f"{self.api_base}/embeddings"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
+        data = {
+            "input": input_data,
+            "model": self.model
+        }
+
+        response = requests.post(url, headers=headers, json=data, timeout=30)
+        response.raise_for_status()
+        return response.json()
+
     def embed(self, text: str) -> List[float]:
         """Generate embedding for text"""
-        response = self.client.embeddings.create(
-            input=text,
-            model=self.model
-        )
-        return response.data[0].embedding
+        result = self._call_api(text)
+        return result["data"][0]["embedding"]

     def embed_batch(self, texts: List[str]) -> List[List[float]]:
         """Generate embeddings for multiple texts"""
         if not texts:
             return []

-        response = self.client.embeddings.create(
-            input=texts,
-            model=self.model
-        )
-        return [item.embedding for item in response.data]
+        result = self._call_api(texts)
+        return [item["embedding"] for item in result["data"]]

     @property
     def dimensions(self) -> int:
         return self._dimensions


-class LocalEmbeddingProvider(EmbeddingProvider):
-    """Local embedding provider using sentence-transformers"""
-
-    def __init__(self, model: str = "all-MiniLM-L6-v2"):
-        """
-        Initialize local embedding provider
-
-        Args:
-            model: Model name from sentence-transformers
-        """
-        self.model_name = model
-
-        try:
-            from sentence_transformers import SentenceTransformer
-            self.model = SentenceTransformer(model)
-            self._dimensions = self.model.get_sentence_embedding_dimension()
-        except ImportError:
-            raise ImportError(
-                "sentence-transformers not installed. "
-                "Install with: pip install sentence-transformers"
-            )
-
-    def embed(self, text: str) -> List[float]:
-        """Generate embedding for text"""
-        embedding = self.model.encode(text, convert_to_numpy=True)
-        return embedding.tolist()
-
-    def embed_batch(self, texts: List[str]) -> List[List[float]]:
-        """Generate embeddings for multiple texts"""
-        if not texts:
-            return []
-
-        embeddings = self.model.encode(texts, convert_to_numpy=True)
-        return embeddings.tolist()
-
-    @property
-    def dimensions(self) -> int:
-        return self._dimensions
+# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search


 class EmbeddingCache:
     """Cache for embeddings to avoid recomputation"""

     def __init__(self):
         self.cache = {}

     def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
         """Get cached embedding"""
         key = self._compute_key(text, provider, model)
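`_compute_key` itself is outside this hunk, but the newly added `hashlib` and `json` imports at the top of the file suggest a content-addressed cache key. A plausible standalone sketch (our assumption, not shown in this diff):

    import hashlib

    def compute_cache_key(text: str, provider: str, model: str) -> str:
        # Same (provider, model, text) always maps to the same key.
        raw = f"{provider}:{model}:{text}".encode("utf-8")
        return hashlib.md5(raw).hexdigest()

    # compute_cache_key("hello", "openai", "text-embedding-3-small")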
@@ -156,20 +126,23 @@ def create_embedding_provider(
     """
     Factory function to create embedding provider

+    Only supports OpenAI embedding via REST API.
+    If initialization fails, caller should fall back to keyword-only search.
+
     Args:
-        provider: Provider name ("openai" or "local")
-        model: Model name (provider-specific)
-        api_key: API key for remote providers
-        api_base: API base URL for remote providers
+        provider: Provider name (only "openai" is supported)
+        model: Model name (default: text-embedding-3-small)
+        api_key: OpenAI API key (required)
+        api_base: API base URL (default: https://api.openai.com/v1)

     Returns:
         EmbeddingProvider instance

+    Raises:
+        ValueError: If provider is not "openai" or api_key is missing
     """
-    if provider == "openai":
-        model = model or "text-embedding-3-small"
-        return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
-    elif provider == "local":
-        model = model or "all-MiniLM-L6-v2"
-        return LocalEmbeddingProvider(model=model)
-    else:
-        raise ValueError(f"Unknown embedding provider: {provider}")
+    if provider != "openai":
+        raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
+
+    model = model or "text-embedding-3-small"
+    return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
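As the new docstring says, a failed factory call should degrade to keyword-only search rather than crash. A minimal calling sketch (the try/except wrapper is ours; only `create_embedding_provider` comes from the diff):

    from agent.memory.embedding import create_embedding_provider

    embedding_provider = None
    try:
        embedding_provider = create_embedding_provider(
            provider="openai",
            model="text-embedding-3-small",
            api_key="sk-...",  # placeholder; OpenAI-compatible endpoints work via api_base
        )
    except Exception as e:  # ValueError is the documented failure mode
        # Degrade gracefully: memory still works with keyword search only.
        print(f"Embedding disabled, falling back to keyword search: {e}")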
@@ -70,8 +70,9 @@ class MemoryManager:
         except Exception as e:
             # Embedding provider failed, but that's OK
             # We can still use keyword search and file operations
-            print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
-            print(f"ℹ️ Memory will work with keyword search only (no semantic search)")
+            from common.log import logger
+            logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
+            logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")

         # Initialize memory flush manager
         workspace_dir = self.config.get_workspace()

@@ -135,13 +136,19 @@ class MemoryManager:
         # Perform vector search (if embedding provider available)
         vector_results = []
         if self.embedding_provider:
-            query_embedding = self.embedding_provider.embed(query)
-            vector_results = self.storage.search_vector(
-                query_embedding=query_embedding,
-                user_id=user_id,
-                scopes=scopes,
-                limit=max_results * 2  # Get more candidates for merging
-            )
+            try:
+                from common.log import logger
+                query_embedding = self.embedding_provider.embed(query)
+                vector_results = self.storage.search_vector(
+                    query_embedding=query_embedding,
+                    user_id=user_id,
+                    scopes=scopes,
+                    limit=max_results * 2  # Get more candidates for merging
+                )
+                logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
+            except Exception as e:
+                from common.log import logger
+                logger.warning(f"[MemoryManager] Vector search failed: {e}")

         # Perform keyword search
         keyword_results = self.storage.search_keyword(

@@ -150,6 +157,8 @@ class MemoryManager:
             scopes=scopes,
             limit=max_results * 2
         )
+        from common.log import logger
+        logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")

         # Merge results
         merged = self._merge_results(

@@ -356,30 +365,30 @@ class MemoryManager:

     def should_flush_memory(
         self,
-        current_tokens: int,
-        context_window: int = 128000,
-        reserve_tokens: int = 20000,
-        soft_threshold: int = 4000
+        current_tokens: int = 0
     ) -> bool:
         """
         Check if memory flush should be triggered

+        An independent flush trigger that does not depend on the model's context window.
+        Uses the thresholds from config: flush_token_threshold and flush_turn_threshold
+
         Args:
             current_tokens: Current session token count
-            context_window: Model's context window size (default: 128K)
-            reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
-            soft_threshold: Trigger N tokens before threshold (default: 4K)

         Returns:
             True if memory flush should run
         """
         return self.flush_manager.should_flush(
             current_tokens=current_tokens,
-            context_window=context_window,
-            reserve_tokens=reserve_tokens,
-            soft_threshold=soft_threshold
+            token_threshold=self.config.flush_token_threshold,
+            turn_threshold=self.config.flush_turn_threshold
         )

+    def increment_turn(self):
+        """Increment the conversation turn counter (one user message + one AI reply counts as one turn)"""
+        self.flush_manager.increment_turn()
+
     async def execute_memory_flush(
         self,
         agent_executor,

@@ -46,14 +46,32 @@ class MemoryStorage:
     def __init__(self, db_path: Path):
         self.db_path = db_path
         self.conn: Optional[sqlite3.Connection] = None
+        self.fts5_available = False  # Track FTS5 availability
         self._init_db()

+    def _check_fts5_support(self) -> bool:
+        """Check if SQLite has FTS5 support"""
+        try:
+            self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
+            self.conn.execute("DROP TABLE IF EXISTS fts5_test")
+            return True
+        except sqlite3.OperationalError as e:
+            if "no such module: fts5" in str(e):
+                return False
+            raise
+
     def _init_db(self):
         """Initialize database with schema"""
         try:
             self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
             self.conn.row_factory = sqlite3.Row

+            # Check FTS5 support
+            self.fts5_available = self._check_fts5_support()
+            if not self.fts5_available:
+                from common.log import logger
+                logger.warning("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
+
             # Check database integrity
             try:
                 result = self.conn.execute("PRAGMA integrity_check").fetchone()
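The FTS5 probe above can be run standalone to check whether a given Python build ships SQLite with FTS5 compiled in (a self-contained version of `_check_fts5_support`):

    import sqlite3

    def sqlite_has_fts5() -> bool:
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("CREATE VIRTUAL TABLE fts5_test USING fts5(test)")
            return True
        except sqlite3.OperationalError as e:
            if "no such module: fts5" in str(e):
                return False
            raise
        finally:
            conn.close()

    print(sqlite_has_fts5())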
@@ -125,43 +143,44 @@ class MemoryStorage:
                 ON chunks(path, hash)
             """)

-            # Create FTS5 virtual table for keyword search
-            # Use default unicode61 tokenizer (stable and compatible)
-            # For CJK support, we'll use LIKE queries as fallback
-            self.conn.execute("""
-                CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
-                    text,
-                    id UNINDEXED,
-                    user_id UNINDEXED,
-                    path UNINDEXED,
-                    source UNINDEXED,
-                    scope UNINDEXED,
-                    content='chunks',
-                    content_rowid='rowid'
-                )
-            """)
-
-            # Create triggers to keep FTS in sync
-            self.conn.execute("""
-                CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
-                    INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
-                    VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
-                END
-            """)
-
-            self.conn.execute("""
-                CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
-                    DELETE FROM chunks_fts WHERE rowid = old.rowid;
-                END
-            """)
-
-            self.conn.execute("""
-                CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
-                    UPDATE chunks_fts SET text = new.text, id = new.id,
-                        user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
-                    WHERE rowid = new.rowid;
-                END
-            """)
+            # Create FTS5 virtual table for keyword search (only if supported)
+            if self.fts5_available:
+                # Use default unicode61 tokenizer (stable and compatible)
+                # For CJK support, we'll use LIKE queries as fallback
+                self.conn.execute("""
+                    CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
+                        text,
+                        id UNINDEXED,
+                        user_id UNINDEXED,
+                        path UNINDEXED,
+                        source UNINDEXED,
+                        scope UNINDEXED,
+                        content='chunks',
+                        content_rowid='rowid'
+                    )
+                """)
+
+                # Create triggers to keep FTS in sync
+                self.conn.execute("""
+                    CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
+                        INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
+                        VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
+                    END
+                """)
+
+                self.conn.execute("""
+                    CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
+                        DELETE FROM chunks_fts WHERE rowid = old.rowid;
+                    END
+                """)
+
+                self.conn.execute("""
+                    CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
+                        UPDATE chunks_fts SET text = new.text, id = new.id,
+                            user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
+                        WHERE rowid = new.rowid;
+                    END
+                """)

             # Create files metadata table
             self.conn.execute("""

@@ -301,21 +320,22 @@ class MemoryStorage:
         Keyword search using FTS5 + LIKE fallback

         Strategy:
-        1. Try FTS5 search first (good for English and word-based languages)
-        2. If no results and query contains CJK characters, use LIKE search
+        1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
+        2. If no FTS5 or no results and query contains CJK: Use LIKE search
         """
         if scopes is None:
             scopes = ["shared"]
             if user_id:
                 scopes.append("user")

-        # Try FTS5 search first
-        fts_results = self._search_fts5(query, user_id, scopes, limit)
-        if fts_results:
-            return fts_results
+        # Try FTS5 search first (if available)
+        if self.fts5_available:
+            fts_results = self._search_fts5(query, user_id, scopes, limit)
+            if fts_results:
+                return fts_results

-        # Fallback to LIKE search for CJK characters
-        if MemoryStorage._contains_cjk(query):
+        # Fallback to LIKE search (always for CJK, or if FTS5 not available)
+        if not self.fts5_available or MemoryStorage._contains_cjk(query):
             return self._search_like(query, user_id, scopes, limit)

         return []
@@ -41,46 +41,42 @@ class MemoryFlushManager:
         # Tracking
         self.last_flush_token_count: Optional[int] = None
         self.last_flush_timestamp: Optional[datetime] = None
+        self.turn_count: int = 0  # conversation turn counter

     def should_flush(
         self,
-        current_tokens: int,
-        context_window: int,
-        reserve_tokens: int = 20000,
-        soft_threshold: int = 4000
+        current_tokens: int = 0,
+        token_threshold: int = 50000,
+        turn_threshold: int = 20
     ) -> bool:
         """
         Determine if memory flush should be triggered

-        Similar to clawdbot's shouldRunMemoryFlush logic:
-        threshold = contextWindow - reserveTokens - softThreshold
+        An independent trigger that does not depend on the model's context window:
+        - Token threshold: flush once the session reaches 50K tokens
+        - Turn threshold: flush after 20 conversation turns

         Args:
             current_tokens: Current session token count
-            context_window: Model's context window size
-            reserve_tokens: Reserve tokens for compaction overhead
-            soft_threshold: Trigger flush N tokens before threshold
+            token_threshold: Token threshold to trigger flush (default: 50K)
+            turn_threshold: Turn threshold to trigger flush (default: 20)

         Returns:
             True if flush should run
         """
-        if current_tokens <= 0:
-            return False
+        # Check the token threshold
+        if current_tokens > 0 and current_tokens >= token_threshold:
+            # Avoid duplicate flushes
+            if self.last_flush_token_count is not None:
+                if current_tokens <= self.last_flush_token_count + 5000:
+                    return False
+            return True

-        threshold = max(0, context_window - reserve_tokens - soft_threshold)
-        if threshold <= 0:
-            return False
+        # Check the turn threshold
+        if self.turn_count >= turn_threshold:
+            return True

-        # Check if we've crossed the threshold
-        if current_tokens < threshold:
-            return False
-
-        # Avoid duplicate flush in same compaction cycle
-        if self.last_flush_token_count is not None:
-            if current_tokens <= self.last_flush_token_count + soft_threshold:
-                return False
-
-        return True
+        return False

     def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
         """
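The new trigger semantics in a nutshell (constructor arguments are omitted here because they are not part of this diff):

    mgr = MemoryFlushManager(...)             # hypothetical instance

    mgr.should_flush(current_tokens=60000)    # True: 50K token threshold crossed
    mgr.last_flush_token_count = 58000
    mgr.should_flush(current_tokens=60000)    # False: within 5K of the last flush
    mgr.turn_count = 20
    mgr.should_flush(current_tokens=0)        # True: 20-turn threshold reached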
@@ -130,7 +126,12 @@ class MemoryFlushManager:
             f"Pre-compaction memory flush. "
             f"Store durable memories now (use memory/{today}.md for daily notes; "
             f"create memory/ if needed). "
-            f"If nothing to store, reply with NO_REPLY."
+            f"\n\n"
+            f"重要提示:\n"
+            f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
+            f"  如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
+            f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
+            f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
         )

     def create_flush_system_prompt(self) -> str:

@@ -142,6 +143,20 @@ class MemoryFlushManager:
         return (
             "Pre-compaction memory flush turn. "
             "The session is near auto-compaction; capture durable memories to disk. "
+            "\n\n"
+            "记忆写入原则:\n"
+            "1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
+            "   - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
+            "   - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
+            "\n"
+            "2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
+            "   - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
+            "\n"
+            "3. 判断标准:\n"
+            "   - 这个信息未来会经常用到吗?→ MEMORY.md\n"
+            "   - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
+            "   - 这是临时性的、不重要的内容吗?→ 不记录\n"
+            "\n"
             "You may reply, but usually NO_REPLY is correct."
         )

@@ -180,6 +195,7 @@ class MemoryFlushManager:
             # Track flush
             self.last_flush_token_count = current_tokens
             self.last_flush_timestamp = datetime.now()
+            self.turn_count = 0  # reset the turn counter

             return True

@@ -187,6 +203,10 @@ class MemoryFlushManager:
             print(f"Memory flush failed: {e}")
             return False

+    def increment_turn(self):
+        """Increment the conversation turn counter"""
+        self.turn_count += 1
+
     def get_status(self) -> dict:
         """Get memory flush status"""
         return {
@@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
     tool_map = {}
     tool_descriptions = {
         "read": "读取文件内容",
-        "write": "创建或覆盖文件",
-        "edit": "精确编辑文件内容",
+        "write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)。注意:单次 write 内容不要超过 10KB,超大文件请分步创建",
+        "edit": "精确编辑文件(追加、修改、删除部分内容)",
         "ls": "列出目录内容",
         "grep": "在文件中搜索内容",
         "find": "按照模式查找文件",

@@ -237,11 +237,13 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
         "叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。",
         "",
         "完成标准:",
-        "- 确保用户的需求得到实际解决,而不仅仅是制定计划",
-        "- 当任务需要多次工具调用时,持续推进直到完成",
+        "- 确保用户的需求得到实际解决,而不仅仅是制定计划。",
+        "- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题",
         "- 每次工具调用后,评估是否已获得足够信息来推进或完成任务",
         "- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求",
+        "",
+        "**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。",
         "",
     ])

     return lines

@@ -305,17 +307,21 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
         "",
         "在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
         "",
-        "1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
-        "2. 然后使用 `memory_get` 只拉取需要的行",
-        "3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
+        "1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
+        "2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
+        "3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
         "",
         "**记忆文件结构**:",
-        "- `MEMORY.md`: 长期记忆,包含重要的背景信息",
-        "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
+        "- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
+        "- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
         "",
-        "**使用原则**:",
-        "- 自然使用记忆,就像你本来就知道",
-        "- 不要主动提起或列举记忆,除非用户明确询问",
+        "**写入记忆**:",
+        "- 追加内容 → `edit` 工具,oldText 留空",
+        "- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
+        "- 新建文件 → `write` 工具",
+        "- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
+        "",
+        "**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
         "",
     ]

@@ -385,8 +391,8 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool) -> List[str]:
         "",
         "**交流规范**:",
         "",
-        "- 在所有对话中,无需提及技术细节(如 SOUL.md、USER.md 等文件名,工具名称,配置等),除非用户明确询问",
-        "- 用自然表达如「我已记住」而非「已更新 SOUL.md」",
+        "- 在对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问",
+        "- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」",
         "",
     ]
@@ -64,7 +64,7 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
         _create_template_if_missing(agents_path, _get_agents_template())
         _create_template_if_missing(memory_path, _get_memory_template())

-    logger.info(f"[Workspace] Initialized workspace at: {workspace_dir}")
+    logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")

     return WorkspaceFiles(
         soul_path=soul_path,

@@ -270,14 +270,9 @@ def _get_agents_template() -> str:
 2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
 3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)

-**重要**:
-- 爱好(唱歌、篮球等)→ MEMORY.md,不是 USER.md
-- 近期计划(下周要做什么)→ MEMORY.md,不是 USER.md
-- USER.md 只存放不会变的基本信息
-
 ## 安全

-- 永远不要泄露私人数据
+- 永远不要泄露秘钥等私人数据
 - 不要在未经询问的情况下运行破坏性命令
 - 当有疑问时,先问
@@ -1,5 +1,6 @@
 import json
 import time
+import threading

 from common.log import logger
 from agent.protocol.models import LLMRequest, LLMModel

@@ -43,6 +44,7 @@ class Agent:
         self.output_mode = output_mode
         self.last_usage = None  # Store last API response usage info
         self.messages = []  # Unified message history for stream mode
+        self.messages_lock = threading.Lock()  # Lock for thread-safe message operations
         self.memory_manager = memory_manager  # Memory manager for auto memory flush
         self.workspace_dir = workspace_dir  # Workspace directory
         self.enable_skills = enable_skills  # Skills enabled flag

@@ -57,7 +59,7 @@ class Agent:
             try:
                 from agent.skills import SkillManager
                 self.skill_manager = SkillManager(workspace_dir=workspace_dir)
-                logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
+                logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
             except Exception as e:
                 logger.warning(f"Failed to initialize SkillManager: {e}")

@@ -335,7 +337,8 @@ class Agent:
         """
         # Clear history if requested
         if clear_history:
-            self.messages = []
+            with self.messages_lock:
+                self.messages = []

         # Get model to use
         if not self.model:

@@ -344,7 +347,17 @@ class Agent:
         # Get full system prompt with skills
         full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)

-        # Create stream executor with agent's message history
+        # Create a copy of messages for this execution to avoid concurrent modification
+        # Record the original length to track which messages are new
+        with self.messages_lock:
+            messages_copy = self.messages.copy()
+            original_length = len(self.messages)
+
+        # Get max_context_turns from config
+        from config import conf
+        max_context_turns = conf().get("agent_max_context_turns", 30)
+
+        # Create stream executor with copied message history
         executor = AgentStreamExecutor(
             agent=self,
             model=self.model,

@@ -352,14 +365,21 @@ class Agent:
             tools=self.tools,
             max_turns=self.max_steps,
             on_event=on_event,
-            messages=self.messages  # Pass agent's message history
+            messages=messages_copy,  # Pass copied message history
+            max_context_turns=max_context_turns
         )

         # Execute
         response = executor.run_stream(user_message)

-        # Update agent's message history from executor
-        self.messages = executor.messages
+        # Append only the NEW messages from this execution (thread-safe)
+        # This allows concurrent requests to both contribute to history
+        with self.messages_lock:
+            new_messages = executor.messages[original_length:]
+            self.messages.extend(new_messages)

         # Store executor reference for agent_bridge to access files_to_send
         self.stream_executor = executor

         # Execute all post-process tools
         self._execute_post_process_tools()
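The snapshot-then-merge pattern above is what makes concurrent runs safe: each execution streams against a private copy of the history, and only its own new messages are appended back. Stripped to its essence (names ours, not from the codebase):

    import threading

    history, lock = [], threading.Lock()

    def run_once(produced):
        with lock:
            snapshot = history.copy()            # executor gets a private copy
            base_len = len(snapshot)
        snapshot.extend(produced)                # ...the streaming run appends to the copy...
        with lock:
            history.extend(snapshot[base_len:])  # merge back only this run's additions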
@@ -7,9 +7,9 @@
 import json
 import time
 from typing import List, Dict, Any, Optional, Callable

-from common.log import logger
 from agent.protocol.models import LLMRequest, LLMModel
 from agent.tools.base_tool import BaseTool, ToolResult
+from common.log import logger


 class AgentStreamExecutor:

@@ -31,7 +31,8 @@ class AgentStreamExecutor:
         tools: List[BaseTool],
         max_turns: int = 50,
         on_event: Optional[Callable] = None,
-        messages: Optional[List[Dict]] = None
+        messages: Optional[List[Dict]] = None,
+        max_context_turns: int = 30
     ):
         """
         Initialize stream executor

@@ -44,6 +45,7 @@ class AgentStreamExecutor:
             max_turns: Maximum number of turns
             on_event: Event callback function
             messages: Optional existing message history (for persistent conversations)
+            max_context_turns: Maximum number of conversation turns to keep in context
         """
         self.agent = agent
         self.model = model

@@ -52,12 +54,16 @@ class AgentStreamExecutor:
         self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
         self.max_turns = max_turns
         self.on_event = on_event
+        self.max_context_turns = max_context_turns

         # Message history - use provided messages or create new list
         self.messages = messages if messages is not None else []

         # Tool failure tracking for retry protection
         self.tool_failure_history = []  # List of (tool_name, args_hash, success) tuples

+        # Track files to send (populated by read tool)
+        self.files_to_send = []  # List of file metadata dicts
+
     def _emit_event(self, event_type: str, data: dict = None):
         """Emit event"""

@@ -78,12 +84,15 @@ class AgentStreamExecutor:
         args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
         return hashlib.md5(args_str.encode()).hexdigest()[:8]

-    def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str]:
+    def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str, bool]:
         """
         Check if tool has failed too many times consecutively

         Returns:
-            (should_stop, reason)
+            (should_stop, reason, is_critical)
+            - should_stop: Whether to stop tool execution
+            - reason: Reason for stopping
+            - is_critical: Whether to abort entire conversation (True for 8+ failures)
         """
         args_hash = self._hash_args(args)

@@ -99,7 +108,7 @@ class AgentStreamExecutor:
                 break  # Different tool or args, stop counting

         if same_args_failures >= 3:
-            return True, f"Tool '{tool_name}' with same arguments failed {same_args_failures} times consecutively. Stopping to prevent infinite loop."
+            return True, f"工具 '{tool_name}' 使用相同参数连续失败 {same_args_failures} 次,停止执行以防止无限循环", False

         # Count consecutive failures for same tool (any args)
         same_tool_failures = 0

@@ -112,10 +121,15 @@ class AgentStreamExecutor:
             else:
                 break  # Different tool, stop counting

-        if same_tool_failures >= 6:
-            return True, f"Tool '{tool_name}' failed {same_tool_failures} times consecutively (with any arguments). Stopping to prevent infinite loop."
+        # Hard stop at 8 failures - abort with critical message
+        if same_tool_failures >= 8:
+            return True, f"抱歉,我没能完成这个任务。可能是我理解有误或者当前方法不太合适。\n\n建议你:\n• 换个方式描述需求试试\n• 把任务拆分成更小的步骤\n• 或者换个思路来解决", True

-        return False, ""
+        # Warning at 6 failures
+        if same_tool_failures >= 6:
+            return True, f"工具 '{tool_name}' 连续失败 {same_tool_failures} 次(使用不同参数),停止执行以防止无限循环", False
+
+        return False, "", False

     def _record_tool_result(self, tool_name: str, args: dict, success: bool):
         """Record tool execution result for failure tracking"""
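The escalation now has three levels: 3 consecutive failures with identical arguments, 6 with any arguments, and a hard abort at 8. A condensed mirror of that decision table (ours, for illustration):

    def classify(same_args_failures: int, same_tool_failures: int):
        # Returns (should_stop, is_critical), mirroring _check_consecutive_failures.
        if same_args_failures >= 3:
            return True, False    # stop; LLM should change arguments or approach
        if same_tool_failures >= 8:
            return True, True     # critical: abort the whole conversation
        if same_tool_failures >= 6:
            return True, False    # stop; LLM may try a different method
        return False, False

    assert classify(3, 3) == (True, False)
    assert classify(0, 8) == (True, True)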
@@ -136,10 +150,7 @@ class AgentStreamExecutor:
             Final response text
         """
         # Log user message with model info
-        logger.info(f"{'='*50}")
-        logger.info(f"🤖 Model: {self.model.model}")
-        logger.info(f"👤 用户: {user_message}")
-        logger.info(f"{'='*50}")
+        logger.info(f"🤖 {self.model.model} | 👤 {user_message}")

         # Add user message (Claude format - use content blocks for consistency)
         self.messages.append({

@@ -160,54 +171,74 @@ class AgentStreamExecutor:
         try:
             while turn < self.max_turns:
                 turn += 1
-                logger.info(f"第 {turn} 轮")
+                logger.debug(f"第 {turn} 轮")
                 self._emit_event("turn_start", {"turn": turn})

                 # Check if memory flush is needed (before calling LLM)
+                # Uses the independent flush thresholds (50K tokens or 20 turns)
                 if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
                     usage = self.agent.last_usage
                     if usage and 'input_tokens' in usage:
                         current_tokens = usage.get('input_tokens', 0)
-                        context_window = self.agent._get_model_context_window()
-                        # Use configured reserve_tokens or calculate based on context window
-                        reserve_tokens = self.agent._get_context_reserve_tokens()
-                        # Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens)
-                        soft_threshold = 10000  # Trigger 10K tokens before limit

                         if self.agent.memory_manager.should_flush_memory(
-                            current_tokens=current_tokens,
-                            context_window=context_window,
-                            reserve_tokens=reserve_tokens,
-                            soft_threshold=soft_threshold
+                            current_tokens=current_tokens
                         ):
                             self._emit_event("memory_flush_start", {
                                 "current_tokens": current_tokens,
-                                "threshold": context_window - reserve_tokens - soft_threshold
+                                "turn_count": self.agent.memory_manager.flush_manager.turn_count
                             })

                             # TODO: Execute memory flush in background
                             # This would require async support
-                            logger.info(f"Memory flush recommended at {current_tokens} tokens")
+                            logger.info(
+                                f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")

-                # Call LLM
-                assistant_msg, tool_calls = self._call_llm_stream()
+                # Call LLM (enable retry_on_empty for better reliability)
+                assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True)
                 final_response = assistant_msg

                 # No tool calls, end loop
                 if not tool_calls:
                     # Check whether the LLM returned an empty response
                     if not assistant_msg:
-                        logger.warning(f"[Agent] LLM returned empty response (no content and no tool calls)")
+                        logger.warning(f"[Agent] LLM returned empty response after retry (no content and no tool calls)")
                         logger.info(f"[Agent] This usually happens when LLM thinks the task is complete after tool execution")

-                        # Generate a generic, friendly fallback message
-                        final_response = (
-                            "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
-                        )
-                        logger.info(f"Generated fallback response for empty LLM output")
+                        # If tools were called earlier, explicitly ask the LLM for a text reply
+                        if turn > 1:
+                            logger.info(f"[Agent] Requesting explicit response from LLM...")
+
+                            # Append a message that explicitly asks for a user-facing reply
+                            self.messages.append({
+                                "role": "user",
+                                "content": [{
+                                    "type": "text",
+                                    "text": "请向用户说明刚才工具执行的结果或回答用户的问题。"
+                                }]
+                            })
+
+                            # Call the LLM once more
+                            assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=False)
+                            final_response = assistant_msg
+
+                            # Fall back only if the response is still empty
+                            if not assistant_msg and not tool_calls:
+                                logger.warning(f"[Agent] Still empty after explicit request")
+                                final_response = (
+                                    "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
+                                )
+                                logger.info(f"Generated fallback response for empty LLM output")
+                        else:
+                            # Empty reply on the very first turn - fall back directly
+                            final_response = (
+                                "抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
+                            )
+                            logger.info(f"Generated fallback response for empty LLM output")
                     else:
                         logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")

-                    logger.info(f"✅ 完成 (无工具调用)")
+                    logger.debug(f"✅ 完成 (无工具调用)")
                     self._emit_event("turn_end", {
                         "turn": turn,
                         "has_tool_calls": False
@@ -233,6 +264,20 @@ class AgentStreamExecutor:
                     result = self._execute_tool(tool_call)
                     tool_results.append(result)

+                    # Check if this is a file to send (from read tool)
+                    if result.get("status") == "success" and isinstance(result.get("result"), dict):
+                        result_data = result.get("result")
+                        if result_data.get("type") == "file_to_send":
+                            # Store file metadata for later sending
+                            self.files_to_send.append(result_data)
+                            logger.info(f"📎 检测到待发送文件: {result_data.get('file_name', result_data.get('path'))}")
+
+                    # Check for critical error - abort entire conversation
+                    if result.get("status") == "critical_error":
+                        logger.error(f"💥 检测到严重错误,终止对话")
+                        final_response = result.get('result', '任务执行失败')
+                        return final_response
+
                     # Log tool result in compact format
                     status_emoji = "✅" if result.get("status") == "success" else "❌"
                     result_data = result.get('result', '')

@@ -305,11 +350,37 @@ class AgentStreamExecutor:
                     })

             if turn >= self.max_turns:
-                logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
-                if not final_response:
-                    final_response = (
-                        "抱歉,我在处理你的请求时遇到了一些困难,尝试了多次仍未能完成。"
-                        "请尝试简化你的问题,或换一种方式描述。"
-                    )
+                logger.warning(f"⚠️ 已达到最大决策步数限制: {self.max_turns}")
+
+                # Force model to summarize without tool calls
+                logger.info(f"[Agent] Requesting summary from LLM after reaching max steps...")
+
+                # Add a system message to force summary
+                self.messages.append({
+                    "role": "user",
+                    "content": [{
+                        "type": "text",
+                        "text": f"你已经执行了{turn}个决策步骤,达到了单次运行的最大步数限制。请总结一下你目前的执行过程和结果,告诉用户当前的进展情况。不要再调用工具,直接用文字回复。"
+                    }]
+                })
+
+                # Call LLM one more time to get summary (without retry to avoid loops)
+                try:
+                    summary_response, summary_tools = self._call_llm_stream(retry_on_empty=False)
+                    if summary_response:
+                        final_response = summary_response
+                        logger.info(f"💭 Summary: {summary_response[:150]}{'...' if len(summary_response) > 150 else ''}")
+                    else:
+                        # Fallback if model still doesn't respond
+                        final_response = (
+                            f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
+                            "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
+                        )
+                except Exception as e:
+                    logger.warning(f"Failed to get summary from LLM: {e}")
+                    final_response = (
+                        f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
+                        "任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
+                    )

         except Exception as e:

@@ -318,9 +389,13 @@ class AgentStreamExecutor:
             raise

         finally:
-            logger.info(f"🏁 完成({turn}轮)")
+            logger.debug(f"🏁 完成({turn}轮)")
             self._emit_event("agent_end", {"final_response": final_response})

+            # Count one turn at the end of each exchange (user message + AI reply = 1 turn)
+            if self.agent.memory_manager:
+                self.agent.memory_manager.increment_turn()
+
         return final_response

     def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
@@ -380,6 +455,7 @@ class AgentStreamExecutor:
         # Streaming response
         full_content = ""
        tool_calls_buffer = {}  # {index: {id, name, arguments}}
+        stop_reason = None  # Track why the stream stopped

         try:
             stream = self.model.call_stream(request)

@@ -392,21 +468,47 @@ class AgentStreamExecutor:
                 if isinstance(error_data, dict):
                     error_msg = error_data.get("message", chunk.get("message", "Unknown error"))
                     error_code = error_data.get("code", "")
+                    error_type = error_data.get("type", "")
                 else:
                     error_msg = chunk.get("message", str(error_data))
                     error_code = ""
+                    error_type = ""

                 status_code = chunk.get("status_code", "N/A")
-                logger.error(f"API Error: {error_msg} (Status: {status_code}, Code: {error_code})")
-                logger.error(f"Full error chunk: {chunk}")
-
-                # Raise exception with full error message for retry logic
-                raise Exception(f"{error_msg} (Status: {status_code})")
+                # Log error with all available information
+                logger.error(f"🔴 Stream API Error:")
+                logger.error(f"  Message: {error_msg}")
+                logger.error(f"  Status Code: {status_code}")
+                logger.error(f"  Error Code: {error_code}")
+                logger.error(f"  Error Type: {error_type}")
+                logger.error(f"  Full chunk: {chunk}")
+
+                # Check if this is a context overflow error (keyword-based, works for all models)
+                # Don't rely on specific status codes as different providers use different codes
+                error_msg_lower = error_msg.lower()
+                is_overflow = any(keyword in error_msg_lower for keyword in [
+                    'context length exceeded', 'maximum context length', 'prompt is too long',
+                    'context overflow', 'context window', 'too large', 'exceeds model context',
+                    'request_too_large', 'request exceeds the maximum size', 'tokens exceed'
+                ])
+
+                if is_overflow:
+                    # Mark as context overflow for special handling
+                    raise Exception(f"[CONTEXT_OVERFLOW] {error_msg} (Status: {status_code})")
+                else:
+                    # Raise exception with full error message for retry logic
+                    raise Exception(f"{error_msg} (Status: {status_code}, Code: {error_code}, Type: {error_type})")

             # Parse chunk
             if isinstance(chunk, dict) and "choices" in chunk:
                 choice = chunk["choices"][0]
                 delta = choice.get("delta", {})

+                # Capture finish_reason if present
+                finish_reason = choice.get("finish_reason")
+                if finish_reason:
+                    stop_reason = finish_reason
+
                 # Handle text content
                 if "content" in delta and delta["content"]:
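Because providers signal overflow with different status codes, the check is purely keyword-based and easy to test in isolation (the keyword list below is copied from the hunk above):

    OVERFLOW_KEYWORDS = [
        'context length exceeded', 'maximum context length', 'prompt is too long',
        'context overflow', 'context window', 'too large', 'exceeds model context',
        'request_too_large', 'request exceeds the maximum size', 'tokens exceed',
    ]

    def is_context_overflow(error_msg: str) -> bool:
        msg = error_msg.lower()
        return any(keyword in msg for keyword in OVERFLOW_KEYWORDS)

    assert is_context_overflow("Prompt is too long: 210000 tokens > 200000 maximum")
    assert not is_context_overflow("Rate limit reached for requests")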
@@ -437,9 +539,46 @@ class AgentStreamExecutor:
                         tool_calls_buffer[index]["arguments"] += func["arguments"]

         except Exception as e:
-            error_str = str(e).lower()
+            error_str = str(e)
+            error_str_lower = error_str.lower()
+
+            # Check if error is context overflow (non-retryable, needs session reset)
+            # Method 1: Check for special marker (set in stream error handling above)
+            is_context_overflow = '[context_overflow]' in error_str_lower
+
+            # Method 2: Fallback to keyword matching for non-stream errors
+            if not is_context_overflow:
+                is_context_overflow = any(keyword in error_str_lower for keyword in [
+                    'context length exceeded', 'maximum context length', 'prompt is too long',
+                    'context overflow', 'context window', 'too large', 'exceeds model context',
+                    'request_too_large', 'request exceeds the maximum size'
+                ])
+
+            # Check if error is message format error (incomplete tool_use/tool_result pairs)
+            # This happens when previous conversation had tool failures
+            is_message_format_error = any(keyword in error_str_lower for keyword in [
+                'tool_use', 'tool_result', 'without', 'immediately after',
+                'corresponding', 'must have', 'each'
+            ]) and 'status: 400' in error_str_lower
+
+            if is_context_overflow or is_message_format_error:
+                error_type = "context overflow" if is_context_overflow else "message format error"
+                logger.error(f"💥 {error_type} detected: {e}")
+                # Clear message history to recover
+                logger.warning("🔄 Clearing conversation history to recover")
+                self.messages.clear()
+                # Raise special exception with user-friendly message
+                if is_context_overflow:
+                    raise Exception(
+                        "抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。"
+                    )
+                else:
+                    raise Exception(
+                        "抱歉,之前的对话出现了问题。我已清空历史记录,请重新发送你的消息。"
+                    )

             # Check if error is retryable (timeout, connection, rate limit, server busy, etc.)
-            is_retryable = any(keyword in error_str for keyword in [
+            is_retryable = any(keyword in error_str_lower for keyword in [
                 'timeout', 'timed out', 'connection', 'network',
                 'rate limit', 'overloaded', 'unavailable', 'busy', 'retry',
                 '429', '500', '502', '503', '504', '512'

@@ -469,15 +608,19 @@ class AgentStreamExecutor:
             try:
                 arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
             except json.JSONDecodeError as e:
-                logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
+                args_preview = tc['arguments'][:200] if len(tc['arguments']) > 200 else tc['arguments']
+                logger.error(f"Failed to parse tool arguments for {tc['name']}")
+                logger.error(f"Arguments length: {len(tc['arguments'])} chars")
+                logger.error(f"Arguments preview: {args_preview}...")
+                logger.error(f"JSON decode error: {e}")

                 # Return a clear error message to the LLM instead of empty dict
                 # This helps the LLM understand what went wrong
                 tool_calls.append({
                     "id": tc["id"],
                     "name": tc["name"],
                     "arguments": {},
-                    "_parse_error": f"Invalid JSON in tool arguments: {tc['arguments'][:100]}... Error: {str(e)}"
+                    "_parse_error": f"Invalid JSON in tool arguments: {args_preview}... Error: {str(e)}. Tip: For large content, consider splitting into smaller chunks or using a different approach."
                 })
                 continue
@@ -489,11 +632,12 @@ class AgentStreamExecutor:
         # Check for empty response and retry once if enabled
         if retry_on_empty and not full_content and not tool_calls:
-            logger.warning(f"⚠️ LLM returned empty response, retrying once...")
+            logger.warning(f"⚠️ LLM returned empty response (stop_reason: {stop_reason}), retrying once...")
             self._emit_event("message_end", {
                 "content": "",
                 "tool_calls": [],
-                "empty_retry": True
+                "empty_retry": True,
+                "stop_reason": stop_reason
             })
             # Retry without retry flag to avoid infinite loop
             return self._call_llm_stream(

@@ -560,16 +704,25 @@ class AgentStreamExecutor:
             return result

         # Check for consecutive failures (retry protection)
-        should_stop, stop_reason = self._check_consecutive_failures(tool_name, arguments)
+        should_stop, stop_reason, is_critical = self._check_consecutive_failures(tool_name, arguments)
         if should_stop:
             logger.error(f"🛑 {stop_reason}")
             self._record_tool_result(tool_name, arguments, False)
-            # Return the error to the LLM so it can try another approach
-            result = {
-                "status": "error",
-                "result": f"{stop_reason}\n\nThis approach is not working. Please try a completely different method or ask the user for more information/clarification.",
-                "execution_time": 0
-            }
+
+            if is_critical:
+                # Critical failure - abort entire conversation
+                result = {
+                    "status": "critical_error",
+                    "result": stop_reason,
+                    "execution_time": 0
+                }
+            else:
+                # Normal failure - let LLM try different approach
+                result = {
+                    "status": "error",
+                    "result": f"{stop_reason}\n\n当前方法行不通,请尝试完全不同的方法或向用户询问更多信息。",
+                    "execution_time": 0
+                }
             return result

         self._emit_event("tool_execution_start", {
@@ -656,52 +809,174 @@ class AgentStreamExecutor:
                 logger.warning(f"⚠️ Removing incomplete tool_use message from history")
                 self.messages.pop()

+    def _identify_complete_turns(self) -> List[Dict]:
+        """
+        Identify complete conversation turns
+
+        A complete turn consists of:
+        1. The user message (text)
+        2. The AI reply (which may contain tool_use)
+        3. Tool results (tool_result, if any)
+        4. Follow-up AI replies (if any)
+
+        Returns:
+            List of turns, each turn is a dict with 'messages' list
+        """
+        turns = []
+        current_turn = {'messages': []}
+
+        for msg in self.messages:
+            role = msg.get('role')
+            content = msg.get('content', [])
+
+            if role == 'user':
+                # Check whether this is a user query (not a tool result)
+                is_user_query = False
+                if isinstance(content, list):
+                    is_user_query = any(
+                        block.get('type') == 'text'
+                        for block in content
+                        if isinstance(block, dict)
+                    )
+                elif isinstance(content, str):
+                    is_user_query = True
+
+                if is_user_query:
+                    # Start a new turn
+                    if current_turn['messages']:
+                        turns.append(current_turn)
+                    current_turn = {'messages': [msg]}
+                else:
+                    # Tool result - belongs to the current turn
+                    current_turn['messages'].append(msg)
+            else:
+                # AI reply - belongs to the current turn
+                current_turn['messages'].append(msg)
+
+        # Append the final turn
+        if current_turn['messages']:
+            turns.append(current_turn)
+
+        return turns
+
+    def _estimate_turn_tokens(self, turn: Dict) -> int:
+        """Estimate the token count of a single turn"""
+        return sum(
+            self.agent._estimate_message_tokens(msg)
+            for msg in turn['messages']
+        )
+
     def _trim_messages(self):
         """
-        Trim message history to stay within context limits.
-        Uses agent's context management configuration.
+        Smart message-history trimming that keeps conversations intact
+
+        Trims in units of complete turns, which guarantees:
+        1. The history is never cut in the middle of an exchange
+        2. Tool-call chains (tool_use + tool_result) stay complete
+        3. Every kept turn is whole (user message + AI reply + tool calls)
         """
         if not self.messages or not self.agent:
             return

-        # Get context window and reserve tokens from agent
+        # Step 1: Identify complete turns
+        turns = self._identify_complete_turns()
+
+        if not turns:
+            return
+
+        # Step 2: Turn limit - keep the most recent N turns
+        if len(turns) > self.max_context_turns:
+            removed_turns = len(turns) - self.max_context_turns
+            turns = turns[-self.max_context_turns:]  # keep the most recent turns
+
+            logger.info(
+                f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns},"
+                f"移除最早的 {removed_turns} 轮完整对话"
+            )
+
+        # Step 3: Token limit - still keeping turns whole
+        # Get context window from agent (based on model)
         context_window = self.agent._get_model_context_window()
-        reserve_tokens = self.agent._get_context_reserve_tokens()
-        max_tokens = context_window - reserve_tokens

-        # Estimate current tokens
-        current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
+        # Use configured max_context_tokens if available
+        if hasattr(self.agent, 'max_context_tokens') and self.agent.max_context_tokens:
+            max_tokens = self.agent.max_context_tokens
+        else:
+            # Reserve 10% for response generation
+            reserve_tokens = int(context_window * 0.1)
+            max_tokens = context_window - reserve_tokens

-        # Add system prompt tokens
+        # Estimate system prompt tokens
         system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
-        current_tokens += system_tokens
+        available_tokens = max_tokens - system_tokens

-        # If under limit, no need to trim
-        if current_tokens <= max_tokens:
+        # Calculate current tokens
+        current_tokens = sum(self._estimate_turn_tokens(turn) for turn in turns)
+
+        # If under limit, reconstruct messages and return
+        if current_tokens + system_tokens <= max_tokens:
+            # Reconstruct message list from turns
+            new_messages = []
+            for turn in turns:
+                new_messages.extend(turn['messages'])
+
+            old_count = len(self.messages)
+            self.messages = new_messages
+
+            # Log if we removed messages due to turn limit
+            if old_count > len(self.messages):
+                logger.info(f"  重建消息列表: {old_count} -> {len(self.messages)} 条消息")
             return

-        # Keep messages from newest, accumulating tokens
-        available_tokens = max_tokens - system_tokens
-        kept_messages = []
+        # Token limit exceeded - keep complete turns, newest first
+        logger.info(
+            f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens},"
+            f"将按完整轮次移除最早的对话"
+        )
+
+        # Accumulate backwards from the newest turn (keeping each turn whole)
+        kept_turns = []
         accumulated_tokens = 0

-        for msg in reversed(self.messages):
-            msg_tokens = self.agent._estimate_message_tokens(msg)
-            if accumulated_tokens + msg_tokens <= available_tokens:
-                kept_messages.insert(0, msg)
-                accumulated_tokens += msg_tokens
+        min_turns = 3  # try to keep at least 3 turns, but not at any cost (avoid blowing the token limit)
+
+        for i, turn in enumerate(reversed(turns)):
+            turn_tokens = self._estimate_turn_tokens(turn)
+            turns_from_end = i + 1
+
+            # Does this turn still fit?
+            if accumulated_tokens + turn_tokens <= available_tokens:
+                kept_turns.insert(0, turn)
+                accumulated_tokens += turn_tokens
+            else:
+                # Over the limit
+                # If we have not kept enough turns yet and this is one of the last chances, try to keep it
+                if len(kept_turns) < min_turns and turns_from_end <= min_turns:
+                    # Give up if the overshoot is severe (more than 20%)
+                    overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens
+                    if overflow_ratio < 0.2:  # allow at most 20% overshoot
+                        kept_turns.insert(0, turn)
+                        accumulated_tokens += turn_tokens
+                        logger.debug(f"  为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%")
+                        continue
+                # Stop keeping earlier turns
+                break

+        # Rebuild the message list
+        new_messages = []
+        for turn in kept_turns:
+            new_messages.extend(turn['messages'])
+
         old_count = len(self.messages)
-        self.messages = kept_messages
+        old_turn_count = len(turns)
+        self.messages = new_messages
         new_count = len(self.messages)
+        new_turn_count = len(kept_turns)

         if old_count > new_count:
             logger.info(
-                f"Context trimmed: {old_count} -> {new_count} messages "
-                f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
-                f"limit: {max_tokens})"
+                f"  移除了 {old_turn_count - new_turn_count} 轮对话 "
+                f"({old_count} -> {new_count} 条消息,"
+                f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)"
             )

     def _prepare_messages(self) -> List[Dict[str, Any]]:
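To see what `_identify_complete_turns` treats as one turn, consider a single tool round-trip (hypothetical history, using the Claude-style content blocks shown above):

    messages = [
        {"role": "user", "content": [{"type": "text", "text": "今天天气如何?"}]},
        {"role": "assistant", "content": [{"type": "tool_use", "id": "t1", "name": "weather", "input": {}}]},
        {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "t1", "content": "sunny"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "晴天,适合出门。"}]},
    ]
    # All four messages form ONE turn: the second "user" message carries only a
    # tool_result block (no "text" block), so it does not start a new turn.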
@@ -70,33 +70,27 @@ def should_include_skill(
    entry: SkillEntry,
    config: Optional[Dict] = None,
    current_platform: Optional[str] = None,
    lenient: bool = True,
) -> bool:
    """
    Determine if a skill should be included based on requirements.

    Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode:
    - In lenient mode (default): Only check explicit disable and platform, ignore missing requirements
    - In strict mode: Check all requirements (binary, env vars, config)
    Simple rule: Skills are auto-enabled if their requirements are met.
    - Has required API keys → enabled
    - Missing API keys → disabled
    - Wrong keys → enabled but will fail at runtime (LLM will handle error)

    :param entry: SkillEntry to check
    :param config: Configuration dictionary
    :param config: Configuration dictionary (currently unused, reserved for future)
    :param current_platform: Current platform (default: auto-detect)
    :param lenient: If True, ignore missing requirements and load all skills (default: True)
    :return: True if skill should be included
    """
    metadata = entry.metadata
    skill_name = entry.skill.name
    skill_config = get_skill_config(config, skill_name)

    # Always check if skill is explicitly disabled in config
    if skill_config and skill_config.get('enabled') is False:
        return False

    # No metadata = always include (no requirements)
    if not metadata:
        return True

    # Always check platform requirements (can't work on wrong platform)
    # Check platform requirements (can't work on wrong platform)
    if metadata.os:
        platform_name = current_platform or resolve_runtime_platform()
        # Map common platform names
@@ -114,12 +108,7 @@ def should_include_skill(
    if metadata.always:
        return True

    # In lenient mode, skip requirement checks and load all skills
    # Skills will fail gracefully at runtime if requirements are missing
    if lenient:
        return True

    # Strict mode: Check all requirements
    # Check requirements
    if metadata.requires:
        # Check required binaries (all must be present)
        required_bins = metadata.requires.get('bins', [])
@@ -133,29 +122,13 @@ def should_include_skill(
        if not has_any_binary(any_bins):
            return False

        # Check environment variables (with config fallback)
        # Check environment variables (API keys)
        # Simple rule: All required env vars must be set
        required_env = metadata.requires.get('env', [])
        if required_env:
            for env_name in required_env:
                # Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv)
                if has_env_var(env_name):
                    continue
                if skill_config:
                    # Check skill config env dict
                    skill_env = skill_config.get('env', {})
                    if isinstance(skill_env, dict) and env_name in skill_env:
                        continue
                    # Check skill config apiKey (if this is the primaryEnv)
                    if metadata.primary_env == env_name and skill_config.get('apiKey'):
                        continue
                # Requirement not satisfied
                return False

        # Check config paths
        required_config = metadata.requires.get('config', [])
        if required_config and config:
            for config_path in required_config:
                if not is_config_path_truthy(config, config_path):
                if not has_env_var(env_name):
                    # Missing required API key → disable skill
                    return False

    return True
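The new rule above reduces env gating to "every required variable is set". A minimal sketch of that rule, assuming `has_env_var` simply checks for a non-empty variable (the helper's exact definition is not shown in this diff):

```python
import os

def has_env_var(name: str) -> bool:
    # Assumed semantics: set and non-empty counts as satisfied
    return bool(os.environ.get(name, "").strip())

required_env = ["BOCHA_API_KEY"]  # as read from metadata.requires['env']
include = all(has_env_var(n) for n in required_env)
# True  -> skill auto-enabled; False -> skill excluded until the key is configured
```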
@@ -34,6 +34,7 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
        lines.append(f" <name>{_escape_xml(skill.name)}</name>")
        lines.append(f" <description>{_escape_xml(skill.description)}</description>")
        lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
        lines.append(f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>")
        lines.append(" </skill>")

    lines.append("</available_skills>")
@@ -23,7 +23,22 @@ def parse_frontmatter(content: str) -> Dict[str, Any]:

    frontmatter_text = match.group(1)

    # Simple YAML-like parsing (supports key: value format)
    # Try to use PyYAML for proper YAML parsing
    try:
        import yaml
        frontmatter = yaml.safe_load(frontmatter_text)
        if not isinstance(frontmatter, dict):
            frontmatter = {}
        return frontmatter
    except ImportError:
        # Fallback to simple parsing if PyYAML not available
        pass
    except Exception:
        # If YAML parsing fails, fall back to simple parsing
        pass

    # Simple YAML-like parsing (supports key: value format only)
    # This is a fallback for when PyYAML is not available
    for line in frontmatter_text.split('\n'):
        line = line.strip()
        if not line or line.startswith('#'):
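To make the two parsing paths concrete: for the usual flat SKILL.md frontmatter both should produce the same dictionary. An illustrative sketch (the file contents and the frontmatter regex are assumed here, not taken from this diff):

```python
import re
import yaml  # optional dependency; the fallback path above handles its absence

content = """---
name: bocha-search
description: Web search via Bocha AI
---
Skill body...
"""
match = re.search(r"^---\n(.*?)\n---", content, re.DOTALL)  # frontmatter block (regex assumed)
print(yaml.safe_load(match.group(1)))
# {'name': 'bocha-search', 'description': 'Web search via Bocha AI'}
```

Without PyYAML, the line-based "key: value" fallback parses the same pairs.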
@@ -72,10 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
    if not isinstance(metadata_raw, dict):
        return None

    # Support both 'moltbot' and 'cow' keys for compatibility
    meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow')
    if not meta_obj or not isinstance(meta_obj, dict):
        return None
    # Use metadata_raw directly (COW format)
    meta_obj = metadata_raw

    # Parse install specs
    install_specs = []
@@ -137,6 +137,10 @@ class SkillLoader:
        name = frontmatter.get('name', parent_dir_name)
        description = frontmatter.get('description', '')

        # Special handling for linkai-agent: dynamically load apps from config.json
        if name == 'linkai-agent':
            description = self._load_linkai_agent_description(skill_dir, description)

        if not description or not description.strip():
            diagnostics.append(f"Skill {name} has no description: {file_path}")
            return LoadSkillsResult(skills=[], diagnostics=diagnostics)
@@ -161,6 +165,45 @@ class SkillLoader:

        return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)

    def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
        """
        Dynamically load LinkAI agent description from config.json

        :param skill_dir: Skill directory
        :param default_description: Default description from SKILL.md
        :return: Dynamic description with app list
        """
        import json

        config_path = os.path.join(skill_dir, "config.json")
        template_path = os.path.join(skill_dir, "config.json.template")

        # Try to load config.json or fallback to template
        config_file = config_path if os.path.exists(config_path) else template_path

        if not os.path.exists(config_file):
            return default_description

        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)

            apps = config.get("apps", [])
            if not apps:
                return default_description

            # Build dynamic description with app details
            app_descriptions = "; ".join([
                f"{app['app_name']}({app['app_code']}: {app['app_description']})"
                for app in apps
            ])

            return f"Call LinkAI apps/workflows. {app_descriptions}"

        except Exception as e:
            logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
            return default_description

    def load_all_skills(
        self,
        managed_dir: Optional[str] = None,
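The dynamic description is just the per-app fields joined with "; ". Reduced to the format string used above (the sample app entry is invented for illustration):

```python
apps = [{"app_name": "Translator", "app_code": "tr01", "app_description": "text translation"}]
app_descriptions = "; ".join(
    f"{app['app_name']}({app['app_code']}: {app['app_description']})" for app in apps
)
print(f"Call LinkAI apps/workflows. {app_descriptions}")
# Call LinkAI apps/workflows. Translator(tr01: text translation)
```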
@@ -216,7 +259,7 @@ class SkillLoader:
            for diag in all_diagnostics[:5]:  # Log first 5
                logger.debug(f" - {diag}")

        logger.info(f"Loaded {len(skill_map)} skills from all sources")
        logger.debug(f"Loaded {len(skill_map)} skills from all sources")

        return skill_map

@@ -59,7 +59,7 @@ class SkillManager:
            extra_dirs=self.extra_dirs,
        )

        logger.info(f"SkillManager: Loaded {len(self.skills)} skills")
        logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")

    def get_skill(self, name: str) -> Optional[SkillEntry]:
        """
@@ -82,32 +82,24 @@ class SkillManager:
        self,
        skill_filter: Optional[List[str]] = None,
        include_disabled: bool = False,
        check_requirements: bool = False,  # Changed default to False for lenient loading
        lenient: bool = True,  # New parameter for lenient mode
    ) -> List[SkillEntry]:
        """
        Filter skills based on criteria.

        By default (lenient=True), all skills are loaded regardless of missing requirements.
        Skills will fail gracefully at runtime if requirements are not met.
        Simple rule: Skills are auto-enabled if requirements are met.
        - Has required API keys → included
        - Missing API keys → excluded

        :param skill_filter: List of skill names to include (None = all)
        :param include_disabled: Whether to include skills with disable_model_invocation=True
        :param check_requirements: Whether to check skill requirements (default: False)
        :param lenient: If True, ignore missing requirements (default: True)
        :return: Filtered list of skill entries
        """
        from agent.skills.config import should_include_skill

        entries = list(self.skills.values())

        # Check requirements (platform, explicit disable, etc.)
        # In lenient mode, only checks platform and explicit disable
        if check_requirements or not lenient:
            entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)]
        else:
            # Lenient mode: only check explicit disable and platform
            entries = [e for e in entries if should_include_skill(e, self.config, lenient=True)]
        # Check requirements (platform, binaries, env vars)
        entries = [e for e in entries if should_include_skill(e, self.config)]

        # Apply skill filter
        if skill_filter is not None:
@@ -2,55 +2,57 @@
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager

# Import basic tools (no external dependencies)
from agent.tools.calculator.calculator import Calculator

# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
from agent.tools.grep.grep import Grep
from agent.tools.find.find import Find
from agent.tools.ls.ls import Ls
from agent.tools.send.send import Send

# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool

# Import web tools
from agent.tools.web_fetch.web_fetch import WebFetch

# Import tools with optional dependencies
def _import_optional_tools():
    """Import tools that have optional dependencies"""
    from common.log import logger
    tools = {}

    # Google Search (requires requests)
    # EnvConfig Tool (requires python-dotenv)
    try:
        from agent.tools.google_search.google_search import GoogleSearch
        tools['GoogleSearch'] = GoogleSearch
    except ImportError:
        pass
        from agent.tools.env_config.env_config import EnvConfig
        tools['EnvConfig'] = EnvConfig
    except ImportError as e:
        logger.error(
            f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n"
            f" To enable environment variable management, run:\n"
            f" pip install python-dotenv>=1.0.0"
        )
    except Exception as e:
        logger.error(f"[Tools] EnvConfig tool failed to load: {e}")

    # File Save (may have dependencies)
    # Scheduler Tool (requires croniter)
    try:
        from agent.tools.file_save.file_save import FileSave
        tools['FileSave'] = FileSave
    except ImportError:
        pass
        from agent.tools.scheduler.scheduler_tool import SchedulerTool
        tools['SchedulerTool'] = SchedulerTool
    except ImportError as e:
        logger.error(
            f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n"
            f" To enable scheduled tasks, run:\n"
            f" pip install croniter>=2.0.0"
        )
    except Exception as e:
        logger.error(f"[Tools] Scheduler tool failed to load: {e}")

    # Terminal (basic, should work)
    try:
        from agent.tools.terminal.terminal import Terminal
        tools['Terminal'] = Terminal
    except ImportError:
        pass

    return tools

# Load optional tools
_optional_tools = _import_optional_tools()
EnvConfig = _optional_tools.get('EnvConfig')
SchedulerTool = _optional_tools.get('SchedulerTool')
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
@@ -74,28 +76,24 @@ def _import_browser_tool():


# Dynamically set BrowserTool
BrowserTool = _import_browser_tool()
# BrowserTool = _import_browser_tool()

# Export all tools (including optional ones that might be None)
__all__ = [
    'BaseTool',
    'ToolManager',
    'Calculator',
    'Read',
    'Write',
    'Edit',
    'Bash',
    'Grep',
    'Find',
    'Ls',
    'Send',
    'MemorySearchTool',
    'MemoryGetTool',
    'WebFetch',
    'EnvConfig',
    'SchedulerTool',
    # Optional tools (may be None if dependencies not available)
    'GoogleSearch',
    'FileSave',
    'Terminal',
    'BrowserTool'
    # 'BrowserTool'
]

"""
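The pattern above degrades gracefully: an optional tool that fails to import is exported as None instead of crashing startup. Reduced to its core:

```python
try:
    from agent.tools.scheduler.scheduler_tool import SchedulerTool
except ImportError:
    SchedulerTool = None  # callers are expected to check for None before registering
```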
@@ -3,12 +3,14 @@ Bash tool - Execute bash commands
"""

import os
import sys
import subprocess
import tempfile
from typing import Dict, Any

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
from common.log import logger


class Bash(BaseTool):
@@ -60,6 +62,12 @@ IMPORTANT SAFETY GUIDELINES:
        if not command:
            return ToolResult.fail("Error: command parameter is required")

        # Security check: Prevent accessing sensitive config files
        if "~/.cow/.env" in command or "~/.cow" in command:
            return ToolResult.fail(
                "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
            )

        # Optional safety check - only warn about extremely dangerous commands
        if self.safety_mode:
            warning = self._get_safety_warning(command)
@@ -68,7 +76,31 @@ IMPORTANT SAFETY GUIDELINES:
                    f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")

        try:
            # Execute command
            # Prepare environment with .env file variables
            env = os.environ.copy()

            # Load environment variables from ~/.cow/.env if it exists
            env_file = os.path.expanduser("~/.cow/.env")
            if os.path.exists(env_file):
                try:
                    from dotenv import dotenv_values
                    env_vars = dotenv_values(env_file)
                    env.update(env_vars)
                    logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
                except ImportError:
                    logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
                except Exception as e:
                    logger.debug(f"[Bash] Failed to load .env: {e}")

            # Debug logging
            logger.debug(f"[Bash] CWD: {self.cwd}")
            logger.debug(f"[Bash] Command: {command[:500]}")
            logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}")
            logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}")
            logger.debug(f"[Bash] Python executable: {sys.executable}")
            logger.debug(f"[Bash] Process UID: {os.getuid()}")

            # Execute command with inherited environment variables
            result = subprocess.run(
                command,
                shell=True,
@@ -76,8 +108,50 @@ IMPORTANT SAFETY GUIDELINES:
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
                timeout=timeout,
                env=env
            )

            logger.debug(f"[Bash] Exit code: {result.returncode}")
            logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
            logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")

            # Workaround for exit code 126 with no output
            if result.returncode == 126 and not result.stdout and not result.stderr:
                logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
                # Try using argument list instead of shell=True
                import shlex
                try:
                    parts = shlex.split(command)
                    if len(parts) > 0:
                        logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
                        retry_result = subprocess.run(
                            parts,
                            cwd=self.cwd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            text=True,
                            timeout=timeout,
                            env=env
                        )
                        logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")

                        # If retry succeeded, use retry result
                        if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
                            result = retry_result
                        else:
                            # Both attempts failed - check if this is openai-image-vision skill
                            if 'openai-image-vision' in command or 'vision.sh' in command:
                                # Create a mock result with helpful error message
                                from types import SimpleNamespace
                                result = SimpleNamespace(
                                    returncode=1,
                                    stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
                                    stderr=''
                                )
                                logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
                except Exception as retry_err:
                    logger.warning(f"[Bash] Retry failed: {retry_err}")

            # Combine stdout and stderr
            output = result.stdout
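The .env merge above is a plain dictionary update over a copy of the process environment, so file values win for duplicate keys. A minimal sketch, assuming python-dotenv is installed:

```python
import os
import subprocess
from dotenv import dotenv_values

env = os.environ.copy()
env.update(dotenv_values(os.path.expanduser("~/.cow/.env")))  # .env overrides inherited vars
subprocess.run("echo $OPENAI_API_KEY", shell=True, env=env, timeout=30)
```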
@@ -1,58 +0,0 @@
import math

from agent.tools.base_tool import BaseTool, ToolResult


class Calculator(BaseTool):
    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression, it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        try:
            # Get the expression
            expression = args["expression"]

            # Create a safe local environment containing only basic math functions
            safe_locals = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "pi": math.pi,
                "e": math.e,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "floor": math.floor,
                "ceil": math.ceil
            }

            # Safely evaluate the expression
            result = eval(expression, {"__builtins__": {}}, safe_locals)

            return ToolResult.success({
                "result": result,
                "expression": expression
            })
        except Exception as e:
            return ToolResult.success({
                "error": str(e),
                "expression": args.get("expression", "")
            })
@@ -22,7 +22,7 @@ class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."
    description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."

    params: dict = {
        "type": "object",
@@ -33,7 +33,7 @@ class Edit(BaseTool):
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
                "description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
            },
            "newText": {
                "type": "string",
@@ -89,34 +89,45 @@ class Edit(BaseTool):
        normalized_old_text = normalize_to_lf(old_text)
        normalized_new_text = normalize_to_lf(new_text)

        # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
        match_result = fuzzy_find_text(normalized_content, normalized_old_text)

        if not match_result.found:
            return ToolResult.fail(
                f"Error: Could not find the exact text in {path}. "
                "The old text must match exactly including all whitespace and newlines."
        # Special case: empty oldText means append to end of file
        if not old_text or not old_text.strip():
            # Append mode: add newText to the end
            # Add newline before newText if file doesn't end with one
            if normalized_content and not normalized_content.endswith('\n'):
                new_content = normalized_content + '\n' + normalized_new_text
            else:
                new_content = normalized_content + normalized_new_text
            base_content = normalized_content  # For verification
        else:
            # Normal edit mode: find and replace
            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_result.index + match_result.match_length:]
            )

        # Calculate occurrence count (use fuzzy normalized content for consistency)
        fuzzy_content = normalize_for_fuzzy_match(normalized_content)
        fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
        occurrences = fuzzy_content.count(fuzzy_old_text)

        if occurrences > 1:
            return ToolResult.fail(
                f"Error: Found {occurrences} occurrences of the text in {path}. "
                "The text must be unique. Please provide more context to make it unique."
            )

        # Execute replacement (use matched text position)
        base_content = match_result.content_for_replacement
        new_content = (
            base_content[:match_result.index] +
            normalized_new_text +
            base_content[match_result.index + match_result.match_length:]
        )

        # Verify replacement actually changed content
        if base_content == new_content:
            return ToolResult.fail(
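The append rule above, in isolation: an empty oldText switches the tool into append mode, and a newline is inserted only when the file doesn't already end with one.

```python
content = "line 1"   # normalized file content
new_text = "line 2"
if content and not content.endswith('\n'):
    result = content + '\n' + new_text
else:
    result = content + new_text
assert result == "line 1\nline 2"
```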
agent/tools/env_config/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from agent.tools.env_config.env_config import EnvConfig

__all__ = ['EnvConfig']
agent/tools/env_config/env_config.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""
|
||||
Environment Configuration Tool - Manage API keys and environment variables
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# API Key 知识库:常见的环境变量及其描述
|
||||
API_KEY_REGISTRY = {
|
||||
# AI 模型服务
|
||||
"OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
|
||||
"GEMINI_API_KEY": "Google Gemini API 密钥",
|
||||
"CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
|
||||
"LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
|
||||
# 搜索服务
|
||||
"BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
|
||||
}
|
||||
|
||||
class EnvConfig(BaseTool):
|
||||
"""Tool for managing environment variables (API keys, etc.)"""
|
||||
|
||||
name: str = "env_config"
|
||||
description: str = (
|
||||
"Manage API keys and skill configurations securely. "
|
||||
"Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
|
||||
"view configured keys, or manage skill settings. "
|
||||
"Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
|
||||
"Values are automatically masked for security. Changes take effect immediately via hot reload."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: 'set', 'get', 'list', 'delete'",
|
||||
"enum": ["set", "get", "list", "delete"]
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Environment variable key name. Common keys:\n"
|
||||
"- OPENAI_API_KEY: OpenAI API (GPT models)\n"
|
||||
"- OPENAI_API_BASE: OpenAI API base URL\n"
|
||||
"- CLAUDE_API_KEY: Anthropic Claude API\n"
|
||||
"- GEMINI_API_KEY: Google Gemini API\n"
|
||||
"- LINKAI_API_KEY: LinkAI platform\n"
|
||||
"- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
|
||||
"Use exact key names (case-sensitive, all uppercase with underscores)"
|
||||
)
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Value to set for the environment variable (for 'set' action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
# Store env config in ~/.cow directory (outside workspace for security)
|
||||
self.env_dir = os.path.expanduser("~/.cow")
|
||||
self.env_path = os.path.join(self.env_dir, '.env')
|
||||
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
|
||||
# Don't create .env file in __init__ to avoid issues during tool discovery
|
||||
# It will be created on first use in execute()
|
||||
|
||||
def _ensure_env_file(self):
|
||||
"""Ensure the .env file exists"""
|
||||
# Create ~/.cow directory if it doesn't exist
|
||||
os.makedirs(self.env_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(self.env_path):
|
||||
Path(self.env_path).touch()
|
||||
logger.info(f"[EnvConfig] Created .env file at {self.env_path}")
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask sensitive parts of a value for logging"""
|
||||
if not value or len(value) <= 10:
|
||||
return "***"
|
||||
return f"{value[:6]}***{value[-4:]}"
|
||||
|
||||
def _read_env_file(self) -> Dict[str, str]:
|
||||
"""Read all key-value pairs from .env file"""
|
||||
env_vars = {}
|
||||
if os.path.exists(self.env_path):
|
||||
with open(self.env_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
# Parse KEY=VALUE
|
||||
match = re.match(r'^([^=]+)=(.*)$', line)
|
||||
if match:
|
||||
key, value = match.groups()
|
||||
env_vars[key.strip()] = value.strip()
|
||||
return env_vars
|
||||
|
||||
def _write_env_file(self, env_vars: Dict[str, str]):
|
||||
"""Write all key-value pairs to .env file"""
|
||||
with open(self.env_path, 'w', encoding='utf-8') as f:
|
||||
f.write("# Environment variables for agent skills\n")
|
||||
f.write("# Auto-managed by env_config tool\n\n")
|
||||
for key, value in sorted(env_vars.items()):
|
||||
f.write(f"{key}={value}\n")
|
||||
|
||||
def _reload_env(self):
|
||||
"""Reload environment variables from .env file"""
|
||||
env_vars = self._read_env_file()
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = value
|
||||
logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")
|
||||
|
||||
def _refresh_skills(self):
|
||||
"""Refresh skills after environment variable changes"""
|
||||
if self.agent_bridge:
|
||||
try:
|
||||
# Reload .env file
|
||||
self._reload_env()
|
||||
|
||||
# Refresh skills in all agent instances
|
||||
refreshed = self.agent_bridge.refresh_all_skills()
|
||||
logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute environment configuration operation
|
||||
|
||||
:param args: Contains action, key, and value parameters
|
||||
:return: Result of the operation
|
||||
"""
|
||||
# Ensure .env file exists on first use
|
||||
self._ensure_env_file()
|
||||
|
||||
action = args.get("action")
|
||||
key = args.get("key")
|
||||
value = args.get("value")
|
||||
|
||||
try:
|
||||
if action == "set":
|
||||
if not key or not value:
|
||||
return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Update the key
|
||||
env_vars[key] = value
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Update current process env
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully set {key}",
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
elif action == "get":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'get' action.")
|
||||
|
||||
# Check in file first, then in current env
|
||||
env_vars = self._read_env_file()
|
||||
value = env_vars.get(key) or os.getenv(key)
|
||||
|
||||
# Get description from registry
|
||||
description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
"description": description,
|
||||
"exists": True
|
||||
})
|
||||
else:
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"description": description,
|
||||
"exists": False,
|
||||
"message": f"Environment variable '{key}' is not set"
|
||||
})
|
||||
|
||||
elif action == "list":
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Build detailed variable list with descriptions
|
||||
variables_with_info = {}
|
||||
for key, value in env_vars.items():
|
||||
variables_with_info[key] = {
|
||||
"value": self._mask_value(value),
|
||||
"description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
}
|
||||
|
||||
logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")
|
||||
|
||||
if not env_vars:
|
||||
return ToolResult.success({
|
||||
"message": "No environment variables configured",
|
||||
"variables": {},
|
||||
"note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"message": f"Found {len(env_vars)} environment variable(s)",
|
||||
"variables": variables_with_info
|
||||
})
|
||||
|
||||
elif action == "delete":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'delete' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
if key not in env_vars:
|
||||
return ToolResult.success({
|
||||
"message": f"Environment variable '{key}' was not set",
|
||||
"key": key
|
||||
})
|
||||
|
||||
# Remove the key
|
||||
del env_vars[key]
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Remove from current process env
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
logger.info(f"[EnvConfig] Deleted {key}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully deleted {key}",
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
else:
|
||||
return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"EnvConfig tool error: {str(e)}")
|
||||
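A hedged usage sketch (constructing the tool directly; inside the agent it is normally invoked through the tool manager rather than instantiated by hand):

```python
tool = EnvConfig()
tool.execute({"action": "set", "key": "BOCHA_API_KEY", "value": "sk-example-1234567890"})
tool.execute({"action": "list"})    # values come back masked, e.g. "sk-exa***7890"
tool.execute({"action": "delete", "key": "BOCHA_API_KEY"})
```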
@@ -1,3 +0,0 @@
from .find import Find

__all__ = ['Find']
@@ -1,177 +0,0 @@
"""
Find tool - Search for files by glob pattern
"""

import os
import glob as glob_module
from typing import Dict, Any, List

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES


DEFAULT_LIMIT = 1000


class Find(BaseTool):
    """Tool for finding files by pattern"""

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search

        :param args: Search parameters
        :return: Search results or error
        """
        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")

        try:
            # Load .gitignore patterns
            ignore_patterns = self._load_gitignore(absolute_path)

            # Search for files
            results = []
            search_pattern = os.path.join(absolute_path, pattern)

            # Use glob with recursive support
            for file_path in glob_module.glob(search_pattern, recursive=True):
                # Skip if matches ignore patterns
                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue

                # Get relative path
                relative_path = os.path.relpath(file_path, absolute_path)

                # Add trailing slash for directories
                if os.path.isdir(file_path):
                    relative_path += '/'

                results.append(relative_path)

                if len(results) >= limit:
                    break

            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})

            # Sort results
            results.sort()

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            result_limit_reached = len(results) >= limit
            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })

        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """Load .gitignore patterns from directory"""
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')

        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except:
                pass

        # Add common ignore patterns
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])

        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """Check if file should be ignored based on patterns"""
        relative_path = os.path.relpath(file_path, base_path)

        for pattern in patterns:
            # Simple pattern matching
            if pattern in relative_path:
                return True

            # Check if it's a directory pattern
            if pattern.endswith('/'):
                if relative_path.startswith(pattern.rstrip('/')):
                    return True

        return False
@@ -1,3 +0,0 @@
from .grep import Grep

__all__ = ['Grep']
@@ -1,248 +0,0 @@
"""
Grep tool - Search file contents for patterns
Uses ripgrep (rg) for fast searching
"""

import os
import re
import subprocess
import json
from typing import Dict, Any, List, Optional

from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import (
    truncate_head, truncate_line, format_size,
    DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
)


DEFAULT_LIMIT = 100


class Grep(BaseTool):
    """Tool for searching file contents"""

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Find ripgrep executable"""
        try:
            result = subprocess.run(['which', 'rg'], capture_output=True, text=True)
            if result.returncode == 0:
                return result.stdout.strip()
        except:
            pass
        return None

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]

        if ignore_case:
            cmd.append('--ignore-case')

        if literal:
            cmd.append('--fixed-strings')

        if glob:
            cmd.extend(['--glob', glob])

        cmd.extend([pattern, absolute_path])

        try:
            # Execute ripgrep
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            # Parse JSON output
            matches = []
            match_count = 0

            for line in result.stdout.splitlines():
                if not line.strip():
                    continue

                try:
                    event = json.loads(line)
                    if event.get('type') == 'match':
                        data = event.get('data', {})
                        file_path = data.get('path', {}).get('text')
                        line_number = data.get('line_number')

                        if file_path and line_number:
                            matches.append({
                                'file': file_path,
                                'line': line_number
                            })
                            match_count += 1

                            if match_count >= limit:
                                break
                except json.JSONDecodeError:
                    continue

            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})

            # Format output with context
            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)

            for match in matches:
                file_path = match['file']
                line_number = match['line']

                # Format file path
                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)

                # Read file and get context
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        file_lines = f.read().split('\n')

                    # Calculate context range
                    start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                    end = min(len(file_lines), line_number + context) if context > 0 else line_number

                    # Format lines with context
                    for i in range(start, end):
                        line_text = file_lines[i].replace('\r', '')

                        # Truncate long lines
                        truncated_text, was_truncated = truncate_line(line_text)
                        if was_truncated:
                            lines_truncated = True

                        # Format output
                        current_line = i + 1
                        if current_line == line_number:
                            output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                        else:
                            output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

                except Exception:
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
@@ -50,6 +50,13 @@ class Ls(BaseTool):
        # Resolve path
        absolute_path = self._resolve_path(path)

        # Security check: Prevent accessing sensitive config directory
        env_config_dir = os.path.expanduser("~/.cow")
        if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
            return ToolResult.fail(
                "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
            )

        if not os.path.exists(absolute_path):
            # Provide helpful hint if using relative path
            if not os.path.isabs(path) and not path.startswith('~'):
@@ -4,8 +4,6 @@ Memory get tool
Allows agents to read specific sections from memory files
"""

from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool


@@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool):
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
                "description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
            },
            "start_line": {
                "type": "integer",
@@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool):
        workspace_dir = self.memory_manager.config.get_workspace()

        # Auto-prepend memory/ if not present and not absolute path
        if not path.startswith('memory/') and not path.startswith('/'):
        # Exception: MEMORY.md is in the root directory
        if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
            path = f'memory/{path}'

        file_path = workspace_dir / path

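The path rule above, applied to a few sample inputs:

```python
def normalize(path: str) -> str:
    # Mirrors the condition in the diff: MEMORY.md stays at the workspace root
    if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
        path = f'memory/{path}'
    return path

assert normalize('MEMORY.md') == 'MEMORY.md'
assert normalize('2026-01-01.md') == 'memory/2026-01-01.md'
assert normalize('memory/2026-01-01.md') == 'memory/2026-01-01.md'
```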
@@ -15,7 +15,7 @@ class Read(BaseTool):
    """Tool for reading file contents"""

    name: str = "read"
    description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."
    description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."

    params: dict = {
        "type": "object",
@@ -26,7 +26,7 @@ class Read(BaseTool):
        },
        "offset": {
            "type": "integer",
            "description": "Line number to start reading from (1-indexed, optional)"
            "description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
        },
        "limit": {
            "type": "integer",
@@ -39,10 +39,25 @@ class Read(BaseTool):
    def __init__(self, config: dict = None):
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Supported image formats
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
        # Supported PDF format

        # File type categories
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
        self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
        self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
        self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
        self.pdf_extensions = {'.pdf'}

        # Readable text formats (will be read with truncation)
        self.text_extensions = {
            '.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
            '.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
            '.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
            '.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
            '.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
            '.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
            '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'  # Office documents
        }

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
@@ -61,6 +76,13 @@ class Read(BaseTool):
        # Resolve path
        absolute_path = self._resolve_path(path)

        # Security check: Prevent reading sensitive config files
        env_config_path = os.path.expanduser("~/.cow/.env")
        if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
            return ToolResult.fail(
                "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
            )

        # Check if file exists
        if not os.path.exists(absolute_path):
            # Provide helpful hint if using relative path
@@ -78,16 +100,25 @@ class Read(BaseTool):

        # Check file type
        file_ext = Path(absolute_path).suffix.lower()
        file_size = os.path.getsize(absolute_path)

        # Check if image
        # Check if image - return metadata for sending
        if file_ext in self.image_extensions:
            return self._read_image(absolute_path, file_ext)

        # Check if video/audio/binary/archive - return metadata only
        if file_ext in self.video_extensions:
            return self._return_file_metadata(absolute_path, "video", file_size)
        if file_ext in self.audio_extensions:
            return self._return_file_metadata(absolute_path, "audio", file_size)
        if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
            return self._return_file_metadata(absolute_path, "binary", file_size)

        # Check if PDF
        if file_ext in self.pdf_extensions:
            return self._read_pdf(absolute_path, path, offset, limit)

        # Read text file
        # Read text file (with truncation for large files)
        return self._read_text(absolute_path, path, offset, limit)

    def _resolve_path(self, path: str) -> str:
@@ -103,25 +134,56 @@ class Read(BaseTool):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
        """
        Return file metadata for non-readable files (video, audio, binary, etc.)

        :param absolute_path: Absolute path to the file
        :param file_type: Type of file (video, audio, binary, etc.)
        :param file_size: File size in bytes
        :return: File metadata
        """
        file_name = Path(absolute_path).name
        file_ext = Path(absolute_path).suffix.lower()

        # Determine MIME type
        mime_types = {
            # Video
            '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime',
            '.mkv': 'video/x-matroska', '.webm': 'video/webm',
            # Audio
            '.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg',
            '.m4a': 'audio/mp4', '.flac': 'audio/flac',
            # Binary
            '.zip': 'application/zip', '.tar': 'application/x-tar',
            '.gz': 'application/gzip', '.rar': 'application/x-rar-compressed',
        }
        mime_type = mime_types.get(file_ext, 'application/octet-stream')

        result = {
            "type": f"{file_type}_metadata",
            "file_type": file_type,
            "path": absolute_path,
            "file_name": file_name,
            "mime_type": mime_type,
            "size": file_size,
            "size_formatted": format_size(file_size),
            "message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
        }

        return ToolResult.success(result)

    def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
        """
        Read image file
        Read image file - always return metadata only (images should be sent, not read into context)

        :param absolute_path: Absolute path to the image file
        :param file_ext: File extension
        :return: Result containing image information
        :return: Result containing image metadata for sending
        """
        try:
            # Read image file
            with open(absolute_path, 'rb') as f:
                image_data = f.read()

            # Get file size
            file_size = len(image_data)

            # Return image information (actual image data can be base64 encoded when needed)
            import base64
            base64_data = base64.b64encode(image_data).decode('utf-8')
            file_size = os.path.getsize(absolute_path)

            # Determine MIME type
            mime_type_map = {
@@ -133,12 +195,15 @@ class Read(BaseTool):
            }
            mime_type = mime_type_map.get(file_ext, 'image/jpeg')

            # Return metadata for images (NOT file_to_send - use send tool to actually send)
            result = {
                "type": "image",
                "type": "image_metadata",
                "file_type": "image",
                "path": absolute_path,
                "mime_type": mime_type,
                "size": file_size,
                "size_formatted": format_size(file_size),
                "data": base64_data  # Base64 encoded image data
                "message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
            }

            return ToolResult.success(result)
@@ -157,21 +222,49 @@ class Read(BaseTool):
        :return: File content or error message
        """
        try:
            # Check file size first
            file_size = os.path.getsize(absolute_path)
            MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB

            if file_size > MAX_FILE_SIZE:
                # File too large, return metadata only
                return ToolResult.success({
                    "type": "file_to_send",
                    "file_type": "document",
                    "path": absolute_path,
                    "size": file_size,
                    "size_formatted": format_size(file_size),
                    "message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
                })

            # Read file
            with open(absolute_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Truncate content if too long (20K characters max for model context)
            MAX_CONTENT_CHARS = 20 * 1024  # 20K characters
            content_truncated = False
            if len(content) > MAX_CONTENT_CHARS:
                content = content[:MAX_CONTENT_CHARS]
                content_truncated = True

            all_lines = content.split('\n')
            total_file_lines = len(all_lines)

            # Apply offset (if specified)
            start_line = 0
            if offset is not None:
                start_line = max(0, offset - 1)  # Convert to 0-indexed
                if start_line >= total_file_lines:
                    return ToolResult.fail(
                        f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                    )
                if offset < 0:
                    # Negative offset: read from end
                    # -20 means "last 20 lines" → start from (total - 20)
                    start_line = max(0, total_file_lines + offset)
                else:
                    # Positive offset: read from start (1-indexed)
                    start_line = max(0, offset - 1)  # Convert to 0-indexed
                    if start_line >= total_file_lines:
                        return ToolResult.fail(
                            f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                        )

            start_line_display = start_line + 1  # For display (1-indexed)

@@ -191,6 +284,10 @@ class Read(BaseTool):
            output_text = ""
            details = {}

            # Add truncation warning if content was truncated
            if content_truncated:
                output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"

            if truncation.first_line_exceeds_limit:
                # First line exceeds 30KB limit
                first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
287
agent/tools/scheduler/README.md
Normal file
@@ -0,0 +1,287 @@
# Scheduler Tool

## Overview

The scheduler tool lets the Agent create, manage, and execute scheduled tasks. It supports:

- ⏰ **Timed reminders**: send a message at a specified time
- 🔄 **Recurring tasks**: repeat at a fixed interval or on a cron expression
- 🔧 **Dynamic tool calls**: run another tool on a schedule and send its result (e.g. search today's news, check the weather)
- 📋 **Task management**: list, enable, disable, and delete tasks

## Installing dependencies

```bash
pip install "croniter>=2.0.0"
```

## Usage

### 1. Creating a scheduled task

The Agent can create scheduled tasks from natural language. Two task types are supported:

#### 1.1 Static message task

Sends a predefined message:

**Example conversation:**
```
User: Remind me about the meeting at 9 a.m. every day
Agent: [calls the scheduler tool]
action: create
name: Daily meeting reminder
message: Time for the meeting!
schedule_type: cron
schedule_value: 0 9 * * *
```

#### 1.2 Dynamic tool-call task

Runs a tool on a schedule and sends the result:

**Example conversation:**
```
User: Read my daily schedule for me at 8 a.m. every morning
Agent: [calls the scheduler tool]
action: create
name: Daily schedule
tool_call:
  tool_name: read
  tool_params:
    file_path: ~/cow/schedule.txt
  result_prefix: 📅 Today's schedule
schedule_type: cron
schedule_value: 0 8 * * *
```

**Tool-call parameters:**
- `tool_name`: name of the tool to invoke (built-in tools such as `bash`, `read`, `write`)
- `tool_params`: the tool's parameters (as a dictionary)
- `result_prefix`: optional text prepended to the result

**Note:** to use skills (such as bocha-search), invoke the skill script through the `bash` tool, as sketched below.
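For example, a tool-call task that runs a skill via `bash` might be configured like this; the skill script path is hypothetical and depends on where skills are installed in your deployment:

```
action: create
name: Morning news briefing
tool_call:
  tool_name: bash
  tool_params:
    command: python ~/cow/skills/bocha-search/search.py "今日新闻"   # hypothetical skill path
  result_prefix: 📰 News briefing
schedule_type: cron
schedule_value: 0 8 * * *
```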
### 2. Supported schedule types

#### Cron expressions (`cron`)
Standard cron syntax:

```
0 9 * * *      # every day at 9:00
0 */2 * * *    # every 2 hours
30 8 * * 1-5   # weekdays at 8:30
0 0 1 * *      # on the 1st of every month
```

#### Fixed interval (`interval`)
An interval in seconds:

```
3600     # every hour
86400    # every day
1800     # every 30 minutes
```

#### One-shot (`once`)
A specific time in ISO format:

```
2024-12-25T09:00:00
2024-12-31T23:59:59
```
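Each schedule type maps to a simple next-run rule (the same logic appears in `SchedulerService._calculate_next_run` later in this diff). A minimal standalone sketch:

```python
from datetime import datetime, timedelta
from typing import Optional

from croniter import croniter


def next_run(schedule: dict, now: datetime) -> Optional[datetime]:
    """Mirror of the next-run rules for the three schedule types above."""
    if schedule["type"] == "cron":
        # croniter yields the first fire time strictly after `now`
        return croniter(schedule["expression"], now).get_next(datetime)
    if schedule["type"] == "interval":
        return now + timedelta(seconds=schedule["seconds"])
    if schedule["type"] == "once":
        run_at = datetime.fromisoformat(schedule["run_at"])
        return run_at if run_at > now else None  # one-shot tasks never re-fire
    return None


now = datetime.now()
print(next_run({"type": "cron", "expression": "0 9 * * *"}, now))
print(next_run({"type": "interval", "seconds": 3600}, now))
```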
### 3. Listing tasks

```
User: Show my scheduled tasks
Agent: [calls the scheduler tool]
action: list
```

### 4. Viewing task details

```
User: Show the details of task abc123
Agent: [calls the scheduler tool]
action: get
task_id: abc123
```

### 5. Deleting a task

```
User: Delete task abc123
Agent: [calls the scheduler tool]
action: delete
task_id: abc123
```

### 6. Enabling / disabling a task

```
User: Pause task abc123
Agent: [calls the scheduler tool]
action: disable
task_id: abc123

User: Resume task abc123
Agent: [calls the scheduler tool]
action: enable
task_id: abc123
```

## Task storage

Tasks are persisted to a JSON file:
```
~/cow/scheduler/tasks.json
```

Task data structures:

**Static message task:**
```json
{
  "id": "abc123",
  "name": "Daily reminder",
  "enabled": true,
  "created_at": "2024-01-01T10:00:00",
  "updated_at": "2024-01-01T10:00:00",
  "schedule": {
    "type": "cron",
    "expression": "0 9 * * *"
  },
  "action": {
    "type": "send_message",
    "content": "Time for the meeting!",
    "receiver": "wxid_xxx",
    "receiver_name": "张三",
    "is_group": false,
    "channel_type": "wechat"
  },
  "next_run_at": "2024-01-02T09:00:00",
  "last_run_at": "2024-01-01T09:00:00"
}
```

**Dynamic tool-call task:**
```json
{
  "id": "def456",
  "name": "Daily schedule",
  "enabled": true,
  "created_at": "2024-01-01T10:00:00",
  "updated_at": "2024-01-01T10:00:00",
  "schedule": {
    "type": "cron",
    "expression": "0 8 * * *"
  },
  "action": {
    "type": "tool_call",
    "tool_name": "read",
    "tool_params": {
      "file_path": "~/cow/schedule.txt"
    },
    "result_prefix": "📅 Today's schedule",
    "receiver": "wxid_xxx",
    "receiver_name": "张三",
    "is_group": false,
    "channel_type": "wechat"
  },
  "next_run_at": "2024-01-02T08:00:00"
}
```
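Since the store is plain JSON, it can be inspected directly. A small sketch, assuming the default path and the `{"version": ..., "tasks": {...}}` envelope that `TaskStore` writes (shown later in this PR):

```python
import json
from pathlib import Path

store = Path.home() / "cow" / "scheduler" / "tasks.json"
data = json.loads(store.read_text(encoding="utf-8"))

# Tasks are keyed by id inside the "tasks" envelope
for task_id, task in data.get("tasks", {}).items():
    flag = "on " if task.get("enabled", True) else "off"
    print(f"[{flag}] {task_id}: {task['name']} -> next {task.get('next_run_at', '?')}")
```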
## Background service

Scheduled tasks are managed by the background service `SchedulerService`:

- checks for due tasks every 30 seconds
- executes due tasks automatically
- computes the next run time
- records execution history and errors

The service starts automatically when the Agent initializes; no manual configuration is required.

## Receiver resolution

A scheduled task sends its output to **the conversation in which it was created**:

- created in a private chat: sent to that user
- created in a group chat: sent to that group
- the receiver information is saved automatically at creation time

## Common use cases

### 1. Daily reminder (static message)
```
User: Remind me to take my medicine at 8 a.m. every day
Agent: ✅ Task created
Task ID: a1b2c3d4
Schedule: every day at 8:00
Message: Time to take your medicine!
```

### 2. Weekday reminder (static message)
```
User: Remind me to leave work at 6 p.m. on weekdays
Agent: [creates cron: 0 18 * * 1-5]
Message: Time to clock out!
```

### 3. Countdown reminder (static message)
```
User: Remind me in one hour
Agent: [creates once: +1h]
```

### 4. Daily schedule push (dynamic tool call)
```
User: Read my daily schedule at 8 a.m. every morning
Agent: ✅ Task created
Task ID: schedule001
Schedule: every day at 8:00
Tool: read(file_path='~/cow/schedule.txt')
Prefix: 📅 Today's schedule
```

### 5. Scheduled file backup (dynamic tool call)
```
User: Back up my work files at 11 p.m. every night
Agent: [creates cron: 0 23 * * *]
Tool: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
Prefix: ✅ File backed up
```

### 6. Weekly report reminder (static message)
```
User: Remind me to write my weekly report at 5 p.m. every Friday
Agent: [creates cron: 0 17 * * 5]
Message: 📊 Time to write the weekly report!
```

### 7. Specific-date reminder (one-shot)
```
User: Remind me to say Merry Christmas at 9 a.m. on December 25
Agent: [creates once: 2024-12-25T09:00:00]
```

## Caveats

1. **Time zone**: the system's local time zone is used
2. **Precision**: the check interval is 30 seconds, so actual execution may drift by up to ±30 seconds
3. **Persistence**: tasks are stored on disk and restored automatically after a restart
4. **One-shot tasks**: disabled automatically after running, not deleted (they can be deleted manually)
5. **Error handling**: a failed run is logged and does not affect other tasks

## Implementation

- **TaskStore**: persistent task storage
- **SchedulerService**: background scheduling service
- **SchedulerTool**: the Agent-facing tool interface
- **Integration**: wiring into AgentBridge

## Dependencies

- `croniter`: cron expression parsing (lightweight, only ~50KB)
7
agent/tools/scheduler/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""
|
||||
Scheduler tool for managing scheduled tasks
|
||||
"""
|
||||
|
||||
from .scheduler_tool import SchedulerTool
|
||||
|
||||
__all__ = ["SchedulerTool"]
|
||||
447
agent/tools/scheduler/integration.py
Normal file
@@ -0,0 +1,447 @@
"""
|
||||
Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_task_store():
|
||||
"""Get the global task store instance"""
|
||||
return _task_store
|
||||
|
||||
|
||||
def get_scheduler_service():
|
||||
"""Get the global scheduler service instance"""
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
task_description = action.get("task_description")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, task_description)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
# Use chat_id for groups, open_id for private chats
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
# Keep isgroup as is, but set msg to None (no original message to reply to)
|
||||
# Feishu channel will detect this and send as new message instead of reply
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk channel setup
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
"""
|
||||
Attach scheduler components to a SchedulerTool instance
|
||||
|
||||
Args:
|
||||
tool: SchedulerTool instance
|
||||
context: Current context (optional)
|
||||
"""
|
||||
if _task_store:
|
||||
tool.task_store = _task_store
|
||||
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
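As a hedged sketch of how this module is meant to be driven (the AgentBridge call sites are not part of this diff; `agent_bridge` and `current_context` stand in for the live objects):

```python
from agent.tools.scheduler import SchedulerTool
from agent.tools.scheduler.integration import init_scheduler, attach_scheduler_to_tool

# Once at startup: build the TaskStore, start SchedulerService, register the callback
if init_scheduler(agent_bridge):  # agent_bridge: object exposing agent_reply(...)
    tool = SchedulerTool()
    # Per request: bind the shared store and the current conversation context
    attach_scheduler_to_tool(tool, context=current_context)  # current_context: bridge.context.Context
```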
220
agent/tools/scheduler/scheduler_service.py
Normal file
@@ -0,0 +1,220 @@
"""
|
||||
Background scheduler service for executing scheduled tasks
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional
|
||||
from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, task_store, execute_callback: Callable):
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
task_store: TaskStore instance
|
||||
execute_callback: Function to call when executing a task
|
||||
"""
|
||||
self.task_store = task_store
|
||||
self.execute_callback = execute_callback
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler service"""
|
||||
with self._lock:
|
||||
if self.running:
|
||||
logger.warning("[Scheduler] Service already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
with self._lock:
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=5)
|
||||
logger.info("[Scheduler] Service stopped")
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
"""Check for due tasks and execute them"""
|
||||
now = datetime.now()
|
||||
tasks = self.task_store.list_tasks(enabled_only=True)
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat(),
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
def _is_task_due(self, task: dict, now: datetime) -> bool:
|
||||
"""
|
||||
Check if a task is due to run
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
now: Current datetime
|
||||
|
||||
Returns:
|
||||
True if task should run now
|
||||
"""
|
||||
next_run_str = task.get("next_run_at")
|
||||
if not next_run_str:
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat()
|
||||
})
|
||||
return False
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except:
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
"""
|
||||
Calculate next run time for a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
from_time: Calculate from this time
|
||||
|
||||
Returns:
|
||||
Next run datetime or None for one-time tasks
|
||||
"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
# Cron expression
|
||||
expression = schedule.get("expression")
|
||||
if not expression:
|
||||
return None
|
||||
|
||||
try:
|
||||
cron = croniter(expression, from_time)
|
||||
return cron.get_next(datetime)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
|
||||
return None
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Interval in seconds
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return from_time + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
# One-time task at specific time
|
||||
run_at_str = schedule.get("run_at")
|
||||
if not run_at_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
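A minimal standalone exercise of the service, using an in-memory stand-in for the store (the real `TaskStore` follows in task_store.py); the stub and the print callback are illustrative only, and the private check method is called directly to avoid waiting on the 30-second loop:

```python
from datetime import datetime, timedelta
# assumes SchedulerService from scheduler_service.py above is importable

class StubStore:
    """In-memory stand-in exposing the two methods SchedulerService calls."""
    def __init__(self, tasks):
        self.tasks = {t["id"]: t for t in tasks}
    def list_tasks(self, enabled_only=False):
        return [t for t in self.tasks.values() if t.get("enabled", True) or not enabled_only]
    def update_task(self, task_id, updates):
        self.tasks[task_id].update(updates)

due = (datetime.now() - timedelta(seconds=1)).isoformat()  # already due
store = StubStore([{
    "id": "demo", "name": "demo", "enabled": True,
    "schedule": {"type": "interval", "seconds": 60},
    "next_run_at": due,
}])

service = SchedulerService(store, lambda task: print("fired:", task["id"]))
service._check_and_execute_tasks()  # fires "demo" and reschedules it 60s out
print(store.tasks["demo"]["next_run_at"])
```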
442
agent/tools/scheduler/scheduler_tool.py
Normal file
@@ -0,0 +1,442 @@
"""
|
||||
Scheduler tool for creating and managing scheduled tasks
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from croniter import croniter
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerTool(BaseTool):
|
||||
"""
|
||||
Tool for managing scheduled tasks (reminders, notifications, etc.)
|
||||
"""
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务。支持固定消息和AI任务两种类型。\n\n"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
"- 管理:action='delete/enable/disable', task_id='任务ID'\n\n"
|
||||
"调度类型:\n"
|
||||
"- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
|
||||
"- interval: 固定间隔(秒),如3600表示每小时\n"
|
||||
"- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n"
|
||||
"注意:'X秒后'用once+相对时间,'每X秒'用interval"
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "get", "delete", "enable", "disable"],
|
||||
"description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "任务ID (用于 get/delete/enable/disable 操作)"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "任务名称 (用于 create 操作)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "固定消息内容 (与ai_task二选一)"
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),如'搜索今日新闻'、'查询天气'"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
"enum": ["cron", "interval", "once"],
|
||||
"description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
|
||||
},
|
||||
"schedule_value": {
|
||||
"type": "string",
|
||||
"description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
super().__init__()
|
||||
self.config = config or {}
|
||||
|
||||
# Will be set by agent bridge
|
||||
self.task_store = None
|
||||
self.current_context = None
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""
|
||||
Execute scheduler operations
|
||||
|
||||
Args:
|
||||
params: Dictionary containing:
|
||||
- action: Operation type (create/list/get/delete/enable/disable)
|
||||
- Other parameters depending on action
|
||||
|
||||
Returns:
|
||||
ToolResult object
|
||||
"""
|
||||
# Extract parameters
|
||||
action = params.get("action")
|
||||
kwargs = params
|
||||
|
||||
if not self.task_store:
|
||||
return ToolResult.fail("错误: 定时任务系统未初始化")
|
||||
|
||||
try:
|
||||
if action == "create":
|
||||
result = self._create_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "list":
|
||||
result = self._list_tasks(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "get":
|
||||
result = self._get_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "delete":
|
||||
result = self._delete_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "enable":
|
||||
result = self._enable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "disable":
|
||||
result = self._disable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
else:
|
||||
return ToolResult.fail(f"未知操作: {action}")
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Error: {e}")
|
||||
return ToolResult.fail(f"操作失败: {str(e)}")
|
||||
|
||||
def _create_task(self, **kwargs) -> str:
|
||||
"""Create a new scheduled task"""
|
||||
name = kwargs.get("name")
|
||||
message = kwargs.get("message")
|
||||
ai_task = kwargs.get("ai_task")
|
||||
schedule_type = kwargs.get("schedule_type")
|
||||
schedule_value = kwargs.get("schedule_value")
|
||||
|
||||
# Validate required fields
|
||||
if not name:
|
||||
return "错误: 缺少任务名称 (name)"
|
||||
|
||||
# Check that exactly one of message/ai_task is provided
|
||||
if not message and not ai_task:
|
||||
return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一"
|
||||
if message and ai_task:
|
||||
return "错误: message 和 ai_task 只能提供其中一个"
|
||||
|
||||
if not schedule_type:
|
||||
return "错误: 缺少调度类型 (schedule_type)"
|
||||
if not schedule_value:
|
||||
return "错误: 缺少调度值 (schedule_value)"
|
||||
|
||||
# Validate schedule
|
||||
schedule = self._parse_schedule(schedule_type, schedule_value)
|
||||
if not schedule:
|
||||
return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
|
||||
|
||||
# Get context info for receiver
|
||||
if not self.current_context:
|
||||
return "错误: 无法获取当前对话上下文"
|
||||
|
||||
context = self.current_context
|
||||
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
"type": "send_message",
|
||||
"content": message,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
"type": "agent_task",
|
||||
"task_description": ai_task,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
msg = context.kwargs.get("msg")
|
||||
if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
|
||||
action["dingtalk_sender_staff_id"] = msg.sender_staff_id
|
||||
|
||||
task_data = {
|
||||
"id": task_id,
|
||||
"name": name,
|
||||
"enabled": True,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"schedule": schedule,
|
||||
"action": action
|
||||
}
|
||||
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task_data)
|
||||
if next_run:
|
||||
task_data["next_run_at"] = next_run.isoformat()
|
||||
|
||||
# Save task
|
||||
self.task_store.add_task(task_data)
|
||||
|
||||
# Format response
|
||||
schedule_desc = self._format_schedule_description(schedule)
|
||||
receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
|
||||
|
||||
if message:
|
||||
content_desc = f"💬 固定消息: {message}"
|
||||
else:
|
||||
content_desc = f"🤖 AI任务: {ai_task}"
|
||||
|
||||
return (
|
||||
f"✅ 定时任务创建成功\n\n"
|
||||
f"📋 任务ID: {task_id}\n"
|
||||
f"📝 名称: {name}\n"
|
||||
f"⏰ 调度: {schedule_desc}\n"
|
||||
f"👤 接收者: {receiver_desc}\n"
|
||||
f"{content_desc}\n"
|
||||
f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
|
||||
)
|
||||
|
||||
def _list_tasks(self, **kwargs) -> str:
|
||||
"""List all tasks"""
|
||||
tasks = self.task_store.list_tasks()
|
||||
|
||||
if not tasks:
|
||||
return "📋 暂无定时任务"
|
||||
|
||||
lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
|
||||
|
||||
for task in tasks:
|
||||
status = "✅" if task.get("enabled", True) else "❌"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
|
||||
|
||||
lines.append(
|
||||
f"{status} [{task['id']}] {task['name']}\n"
|
||||
f" ⏰ {schedule_desc} | 下次: {next_run_str}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_task(self, **kwargs) -> str:
|
||||
"""Get task details"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
status = "启用" if task.get("enabled", True) else "禁用"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
action = task.get("action", {})
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
|
||||
last_run = task.get("last_run_at")
|
||||
last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
|
||||
|
||||
return (
|
||||
f"📋 任务详情\n\n"
|
||||
f"ID: {task['id']}\n"
|
||||
f"名称: {task['name']}\n"
|
||||
f"状态: {status}\n"
|
||||
f"调度: {schedule_desc}\n"
|
||||
f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
|
||||
f"消息: {action.get('content')}\n"
|
||||
f"下次执行: {next_run_str}\n"
|
||||
f"上次执行: {last_run_str}\n"
|
||||
f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
def _delete_task(self, **kwargs) -> str:
|
||||
"""Delete a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.delete_task(task_id)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
|
||||
|
||||
def _enable_task(self, **kwargs) -> str:
|
||||
"""Enable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, True)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
|
||||
|
||||
def _disable_task(self, **kwargs) -> str:
|
||||
"""Disable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, False)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
|
||||
|
||||
def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
|
||||
"""Parse and validate schedule configuration"""
|
||||
try:
|
||||
if schedule_type == "cron":
|
||||
# Validate cron expression
|
||||
croniter(schedule_value)
|
||||
return {"type": "cron", "expression": schedule_value}
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Parse interval in seconds
|
||||
seconds = int(schedule_value)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return {"type": "interval", "seconds": seconds}
|
||||
|
||||
elif schedule_type == "once":
|
||||
# Parse datetime - support both relative and absolute time
|
||||
|
||||
# Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
|
||||
if schedule_value.startswith("+"):
|
||||
import re
|
||||
match = re.match(r'\+(\d+)([smhd])', schedule_value)
|
||||
if match:
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
from datetime import timedelta
|
||||
now = datetime.now()
|
||||
|
||||
if unit == 's': # seconds
|
||||
target_time = now + timedelta(seconds=amount)
|
||||
elif unit == 'm': # minutes
|
||||
target_time = now + timedelta(minutes=amount)
|
||||
elif unit == 'h': # hours
|
||||
target_time = now + timedelta(hours=amount)
|
||||
elif unit == 'd': # days
|
||||
target_time = now + timedelta(days=amount)
|
||||
else:
|
||||
return None
|
||||
|
||||
return {"type": "once", "run_at": target_time.isoformat()}
|
||||
else:
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_next_run(self, task: dict) -> Optional[datetime]:
|
||||
"""Calculate next run time for a task"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
now = datetime.now()
|
||||
|
||||
if schedule_type == "cron":
|
||||
expression = schedule.get("expression")
|
||||
cron = croniter(expression, now)
|
||||
return cron.get_next(datetime)
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
from datetime import timedelta
|
||||
return now + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at_str = schedule.get("run_at")
|
||||
return datetime.fromisoformat(run_at_str)
|
||||
|
||||
return None
|
||||
|
||||
def _format_schedule_description(self, schedule: dict) -> str:
|
||||
"""Format schedule as human-readable description"""
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
expr = schedule.get("expression", "")
|
||||
# Try to provide friendly description
|
||||
if expr == "0 9 * * *":
|
||||
return "每天 9:00"
|
||||
elif expr == "0 */1 * * *":
|
||||
return "每小时"
|
||||
elif expr == "*/30 * * * *":
|
||||
return "每30分钟"
|
||||
else:
|
||||
return f"Cron: {expr}"
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds >= 86400:
|
||||
days = seconds // 86400
|
||||
return f"每 {days} 天"
|
||||
elif seconds >= 3600:
|
||||
hours = seconds // 3600
|
||||
return f"每 {hours} 小时"
|
||||
elif seconds >= 60:
|
||||
minutes = seconds // 60
|
||||
return f"每 {minutes} 分钟"
|
||||
else:
|
||||
return f"每 {seconds} 秒"
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at = schedule.get("run_at", "")
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
|
||||
def _get_receiver_name(self, context: Context) -> str:
|
||||
"""Get receiver name from context"""
|
||||
try:
|
||||
msg = context.get("msg")
|
||||
if msg:
|
||||
if context.get("isgroup"):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except:
|
||||
pass
|
||||
return "未知"
|
||||
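A quick exercise of the tool interface; the store path is illustrative, and `ToolResult` is assumed to expose `status`/`result` attributes as the WebFetch README at the end of this diff shows:

```python
from agent.tools.scheduler.scheduler_tool import SchedulerTool
from agent.tools.scheduler.task_store import TaskStore

tool = SchedulerTool(config={"channel_type": "wechat"})
tool.task_store = TaskStore("/tmp/cow-demo/tasks.json")  # illustrative path

# Read-only actions work without a conversation context
res = tool.execute({"action": "list"})
print(res.result)  # "📋 暂无定时任务" on a fresh store

# 'create' additionally needs tool.current_context (a live bridge Context),
# which attach_scheduler_to_tool() binds per request.
```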
200
agent/tools/scheduler/task_store.py
Normal file
@@ -0,0 +1,200 @@
"""
|
||||
Task storage management for scheduler
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TaskStore:
|
||||
"""
|
||||
Manages persistent storage of scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, store_path: str = None):
|
||||
"""
|
||||
Initialize task store
|
||||
|
||||
Args:
|
||||
store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
|
||||
"""
|
||||
if store_path is None:
|
||||
# Default to ~/cow/scheduler/tasks.json
|
||||
home = os.path.expanduser("~")
|
||||
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
|
||||
|
||||
self.store_path = store_path
|
||||
self.lock = threading.Lock()
|
||||
self._ensure_store_dir()
|
||||
|
||||
def _ensure_store_dir(self):
|
||||
"""Ensure the storage directory exists"""
|
||||
store_dir = os.path.dirname(self.store_path)
|
||||
os.makedirs(store_dir, exist_ok=True)
|
||||
|
||||
def load_tasks(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Load all tasks from storage
|
||||
|
||||
Returns:
|
||||
Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
if not os.path.exists(self.store_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(self.store_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("tasks", {})
|
||||
except Exception as e:
|
||||
print(f"Error loading tasks: {e}")
|
||||
return {}
|
||||
|
||||
def save_tasks(self, tasks: Dict[str, dict]):
|
||||
"""
|
||||
Save all tasks to storage
|
||||
|
||||
Args:
|
||||
tasks: Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
try:
|
||||
# Create backup
|
||||
if os.path.exists(self.store_path):
|
||||
backup_path = f"{self.store_path}.bak"
|
||||
try:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
data = {
|
||||
"version": 1,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"tasks": tasks
|
||||
}
|
||||
|
||||
with open(self.store_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving tasks: {e}")
|
||||
raise
|
||||
|
||||
def add_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Add a new task
|
||||
|
||||
Args:
|
||||
task: Task data dictionary
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_id = task.get("id")
|
||||
|
||||
if not task_id:
|
||||
raise ValueError("Task must have an 'id' field")
|
||||
|
||||
if task_id in tasks:
|
||||
raise ValueError(f"Task with id '{task_id}' already exists")
|
||||
|
||||
tasks[task_id] = task
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def update_task(self, task_id: str, updates: dict) -> bool:
|
||||
"""
|
||||
Update an existing task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
updates: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
# Update fields
|
||||
tasks[task_id].update(updates)
|
||||
tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def delete_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
del tasks[task_id]
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get a specific task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Task data or None if not found
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
return tasks.get(task_id)
|
||||
|
||||
def list_tasks(self, enabled_only: bool = False) -> List[dict]:
|
||||
"""
|
||||
List all tasks
|
||||
|
||||
Args:
|
||||
enabled_only: If True, only return enabled tasks
|
||||
|
||||
Returns:
|
||||
List of task dictionaries
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_list = list(tasks.values())
|
||||
|
||||
if enabled_only:
|
||||
task_list = [t for t in task_list if t.get("enabled", True)]
|
||||
|
||||
# Sort by next_run_at
|
||||
task_list.sort(key=lambda t: t.get("next_run_at", float('inf')))
|
||||
|
||||
return task_list
|
||||
|
||||
def enable_task(self, task_id: str, enabled: bool = True) -> bool:
|
||||
"""
|
||||
Enable or disable a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
enabled: True to enable, False to disable
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
return self.update_task(task_id, {"enabled": enabled})
|
||||
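A short round-trip through the store (the path is illustrative; any writable location works):

```python
from agent.tools.scheduler.task_store import TaskStore

store = TaskStore("/tmp/cow-demo/tasks.json")  # illustrative path
store.add_task({
    "id": "demo1",
    "name": "demo",
    "enabled": True,
    "schedule": {"type": "interval", "seconds": 3600},
    "action": {"type": "send_message", "content": "hi"},
})
store.update_task("demo1", {"next_run_at": "2024-01-02T09:00:00"})
print(store.get_task("demo1")["updated_at"])
print([t["id"] for t in store.list_tasks(enabled_only=True)])
store.delete_task("demo1")
```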
3
agent/tools/send/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .send import Send

__all__ = ['Send']
159
agent/tools/send/send.py
Normal file
@@ -0,0 +1,159 @@
"""
|
||||
Send tool - Send files to the user
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Optional message to accompany the file"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
# Supported file types
|
||||
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
|
||||
self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file send operation
|
||||
|
||||
:param args: Contains file path and optional message
|
||||
:return: File metadata for channel to send
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
message = args.get("message", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable
|
||||
if not os.access(absolute_path, os.R_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable: {path}")
|
||||
|
||||
# Get file info
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
file_name = Path(absolute_path).name
|
||||
|
||||
# Determine file type
|
||||
if file_ext in self.image_extensions:
|
||||
file_type = "image"
|
||||
mime_type = self._get_image_mime_type(file_ext)
|
||||
elif file_ext in self.video_extensions:
|
||||
file_type = "video"
|
||||
mime_type = self._get_video_mime_type(file_ext)
|
||||
elif file_ext in self.audio_extensions:
|
||||
file_type = "audio"
|
||||
mime_type = self._get_audio_mime_type(file_ext)
|
||||
elif file_ext in self.document_extensions:
|
||||
file_type = "document"
|
||||
mime_type = self._get_document_mime_type(file_ext)
|
||||
else:
|
||||
file_type = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# Return file_to_send metadata
|
||||
result = {
|
||||
"type": "file_to_send",
|
||||
"file_type": file_type,
|
||||
"path": absolute_path,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
path = os.path.expanduser(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _get_image_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for image"""
|
||||
mime_map = {
|
||||
'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png', '.gif': 'image/gif',
|
||||
'.webp': 'image/webp', '.bmp': 'image/bmp',
|
||||
'.svg': 'image/svg+xml', '.ico': 'image/x-icon'
|
||||
}
|
||||
return mime_map.get(ext, 'image/jpeg')
|
||||
|
||||
def _get_video_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for video"""
|
||||
mime_map = {
|
||||
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
|
||||
'.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
|
||||
'.webm': 'video/webm', '.flv': 'video/x-flv'
|
||||
}
|
||||
return mime_map.get(ext, 'video/mp4')
|
||||
|
||||
def _get_audio_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for audio"""
|
||||
mime_map = {
|
||||
'.mp3': 'audio/mpeg', '.wav': 'audio/wav',
|
||||
'.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
|
||||
'.flac': 'audio/flac', '.aac': 'audio/aac'
|
||||
}
|
||||
return mime_map.get(ext, 'audio/mpeg')
|
||||
|
||||
def _get_document_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for document"""
|
||||
mime_map = {
|
||||
'.pdf': 'application/pdf',
|
||||
'.doc': 'application/msword',
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.xls': 'application/vnd.ms-excel',
|
||||
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'.ppt': 'application/vnd.ms-powerpoint',
|
||||
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'.txt': 'text/plain',
|
||||
'.md': 'text/markdown'
|
||||
}
|
||||
return mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human-readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f}{unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f}TB"
|
||||
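Exercising the tool directly; it returns `file_to_send` metadata only, and the channel layer performs the actual delivery. The temporary file is illustrative:

```python
import os
import tempfile

from agent.tools.send.send import Send

# Create a throwaway document to "send"
workdir = tempfile.mkdtemp()
with open(os.path.join(workdir, "report.pdf"), "wb") as f:
    f.write(b"%PDF-1.4 demo")

tool = Send(config={"cwd": workdir})
res = tool.execute({"path": "report.pdf", "message": "Here is the report"})
meta = res.result
print(meta["type"], meta["mime_type"], meta["size_formatted"])
# -> file_to_send application/pdf 13.0B
```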
@@ -1,212 +0,0 @@
# WebFetch Tool

A free web-fetching tool: no API key required. It fetches a page directly and extracts the readable text.

## Features

- ✅ **Completely free** - no API key of any kind required
- 🌐 **Smart extraction** - automatically extracts the main content of a page
- 📝 **Format conversion** - supports HTML → Markdown/Text
- 🚀 **High performance** - built-in request retries and timeout control
- 🎯 **Graceful degradation** - prefers Readability, falls back to basic extraction

## Installing dependencies

### Core functionality (required)
```bash
pip install requests
```

### Enhanced functionality (recommended)
```bash
# Install readability-lxml for better content extraction
pip install readability-lxml

# Install html2text for better Markdown conversion
pip install html2text
```

## Usage

### 1. In code

```python
from agent.tools.web_fetch import WebFetch

# Create a tool instance
tool = WebFetch()

# Fetch a page (returns Markdown by default)
result = tool.execute({
    "url": "https://example.com"
})

# Fetch and convert to plain text
result = tool.execute({
    "url": "https://example.com",
    "extract_mode": "text",
    "max_chars": 5000
})

if result.status == "success":
    data = result.result
    print(f"Title: {data['title']}")
    print(f"Content: {data['text']}")
```

### 2. In an Agent

The tool is loaded into the Agent's tool list automatically:

```python
from agent.tools import WebFetch

tools = [
    WebFetch(),
    # ... other tools
]

agent = create_agent(tools=tools)
```

### 3. Via Skills

Create a skill file `skills/web-fetch/SKILL.md`:

```markdown
---
name: web-fetch
emoji: 🌐
always: true
---

# Web content fetching

Use the web_fetch tool to fetch web page content.

## When to use

- You need to read the content of a web page
- You need to extract the body of an article
- You need to gather information from a page

## Example

<example>
User: Take a look at https://example.com and tell me what it is about
Assistant: <tool_use name="web_fetch">
<url>https://example.com</url>
<extract_mode>markdown</extract_mode>
</tool_use>
</example>
```

## Parameters

| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| `url` | string | ✅ | - | URL to fetch (http/https) |
| `extract_mode` | string | ❌ | `markdown` | Extraction mode: `markdown` or `text` |
| `max_chars` | integer | ❌ | `50000` | Maximum number of characters to return (minimum 100) |
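
Out-of-range arguments are normalized rather than rejected (see the validation in `execute` in `web_fetch.py`, further down this diff); a quick sketch of the effective behavior:

```python
tool = WebFetch()

# An extract_mode outside {"markdown", "text"} silently falls back to
# "markdown", and a max_chars below the minimum of 100 is reset to the
# 50000 default.
result = tool.execute({
    "url": "https://example.com",
    "extract_mode": "html",  # unsupported -> treated as "markdown"
    "max_chars": 10          # below the minimum -> treated as 50000
})
```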

## Return value

```python
{
    "url": "https://example.com",       # final URL (after redirects)
    "status": 200,                      # HTTP status code
    "content_type": "text/html",        # content type
    "title": "Example Domain",          # page title
    "extractor": "readability",         # extractor: readability/basic/raw
    "extract_mode": "markdown",         # extraction mode
    "text": "# Example Domain\n\n...",  # extracted text
    "length": 1234,                     # text length
    "truncated": false,                 # whether the text was truncated
    "warning": "..."                    # warning message (if any)
}
```

## Comparison with other search tools

| Tool | API key | Function | Cost |
|------|---------|----------|------|
| `web_fetch` | ❌ Not required | Fetches the content of a given URL | Free |
| `web_search` (Brave) | ✅ Required | Search engine queries | Free tier available |
| `web_search` (Perplexity) | ✅ Required | AI search + citations | Paid |
| `browser` | ❌ Not required | Full browser automation | Free but resource-heavy |
| `google_search` | ✅ Required | Google Search API | Paid |

## Technical details

### Content extraction strategy

1. **Readability mode** (preferred)
   - Uses Mozilla's Readability algorithm
   - Automatically identifies the main article body
   - Filters out noise such as ads and navigation bars

2. **Basic mode** (fallback)
   - Simple HTML tag cleanup
   - Regex-based text extraction
   - Suitable for simple pages

3. **Raw mode**
   - Used for non-HTML content
   - Returns the raw content as-is (the three tiers chain together, as the sketch after this list shows)

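A minimal sketch of that tiered dispatch, condensed from the `_extract_content` implementation later in this diff (non-HTML responses go straight to raw; Readability runs when importable, with a runtime fallback to the basic extractor if it fails):

```python
def _extract(self, html, content_type, **kw):
    # Tier 3: non-HTML responses are returned as-is ("raw" extractor)
    if "text/html" not in content_type.lower():
        return {"extractor": "raw", "text": html}
    # Tier 1: prefer Readability when readability-lxml is importable
    if self.readability_available:
        try:
            return self._extract_with_readability(html, **kw)
        except Exception:
            pass  # fall through to the basic extractor on any parse failure
    # Tier 2: basic regex-based tag stripping
    return self._extract_basic(html, **kw)
```
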
### Error handling

The tool handles the following automatically:

- ✅ HTTP redirects (at most 3)
- ✅ Request timeouts (default: 20 seconds)
- ✅ Automatic retries on network errors
- ✅ Fallback when content extraction fails (each case surfaces as a failed ToolResult; see the sketch after this list)
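
Callers see these conditions as a failed `ToolResult` rather than a raised exception; a minimal handling sketch (only the `status` field is documented above, so the exact shape of the failure payload is an assumption here):

```python
result = tool.execute({"url": "https://example.com/very-slow-page"})
if result.status != "success":
    # Timeouts, redirect loops, and network errors all land here as
    # "Error: ..." failure results instead of exceptions.
    print("fetch failed:", result)
else:
    print(result.result["text"][:200])
```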

## Testing

Run the test script:

```bash
cd agent/tools/web_fetch
python test_web_fetch.py
```

## Configuration options

A config dict can be passed when creating the tool:

```python
tool = WebFetch(config={
    "timeout": 30,       # request timeout (seconds)
    "max_redirects": 3,  # maximum number of redirects
    "user_agent": "..."  # custom User-Agent
})
```

## FAQ

### Q: Why is readability-lxml recommended?

A: readability-lxml gives noticeably better extraction quality; it can:
- Automatically identify the main article body
- Filter out ads and navigation bars
- Preserve the article structure

The tool works without it, but extraction quality degrades.

### Q: How does this differ from clawdbot's web_fetch?

A: This implementation borrows clawdbot's design; the main differences:
- Written in Python (clawdbot is TypeScript)
- Some advanced features are simplified (e.g. the Firecrawl integration)
- The core free functionality is retained
- Easier to integrate into an existing project

### Q: Can it fetch pages that require login?

A: Not in the current version. To fetch pages behind a login, use the `browser` tool instead.

## References

- [Mozilla Readability](https://github.com/mozilla/readability)
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
@@ -1,3 +0,0 @@
from .web_fetch import WebFetch

__all__ = ['WebFetch']
@@ -1,47 +0,0 @@
#!/bin/bash

# Dependency installation script for the WebFetch tool

echo "=================================="
echo "WebFetch 工具依赖安装"
echo "=================================="
echo ""

# Check the Python version
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python 版本: $python_version"
echo ""

# Install the core dependency
echo "📦 安装基础依赖..."
python3 -m pip install requests

# Check whether it succeeded
if [ $? -eq 0 ]; then
    echo "✅ requests 安装成功"
else
    echo "❌ requests 安装失败"
    exit 1
fi

echo ""

# Install the recommended dependencies
echo "📦 安装推荐依赖(提升内容提取质量)..."
python3 -m pip install readability-lxml html2text

# Check whether they succeeded
if [ $? -eq 0 ]; then
    echo "✅ readability-lxml 和 html2text 安装成功"
else
    echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
fi

echo ""
echo "=================================="
echo "安装完成!"
echo "=================================="
echo ""
echo "运行测试:"
echo " python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""
@@ -1,365 +0,0 @@
"""
Web Fetch tool - Fetch and extract readable content from URLs
Supports HTML to Markdown/Text conversion using Mozilla's Readability
"""

import os
import re
from typing import Dict, Any, Optional
from urllib.parse import urlparse
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger


class WebFetch(BaseTool):
    """Tool for fetching and extracting readable content from web pages"""

    name: str = "web_fetch"
    description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."

    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP or HTTPS URL to fetch"
            },
            "extract_mode": {
                "type": "string",
                "description": "Extraction mode: 'markdown' (default) or 'text'",
                "enum": ["markdown", "text"],
                "default": "markdown"
            },
            "max_chars": {
                "type": "integer",
                "description": "Maximum characters to return (default: 50000)",
                "minimum": 100,
                "default": 50000
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        self.timeout = self.config.get("timeout", 20)
        self.max_redirects = self.config.get("max_redirects", 3)
        self.user_agent = self.config.get(
            "user_agent",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
        )

        # Setup session with retry strategy
        self.session = self._create_session()

        # Check if readability-lxml is available
        self.readability_available = self._check_readability()

    def _create_session(self) -> requests.Session:
        """Create a requests session with retry strategy"""
        session = requests.Session()

        # Retry strategy - handles failed requests, not redirects
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"]
        )

        # HTTPAdapter handles retries; requests handles redirects via allow_redirects
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Set max redirects on session
        session.max_redirects = self.max_redirects

        return session

    def _check_readability(self) -> bool:
        """Check if readability-lxml is available"""
        try:
            from readability import Document
            return True
        except ImportError:
            logger.warning(
                "readability-lxml not installed. Install with: pip install readability-lxml\n"
                "Falling back to basic HTML extraction."
            )
            return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute web fetch operation

        :param args: Contains url, extract_mode, and max_chars parameters
        :return: Extracted content or error message
        """
        url = args.get("url", "").strip()
        extract_mode = args.get("extract_mode", "markdown").lower()
        max_chars = args.get("max_chars", 50000)

        if not url:
            return ToolResult.fail("Error: url parameter is required")

        # Validate URL
        if not self._is_valid_url(url):
            return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")

        # Validate extract_mode
        if extract_mode not in ["markdown", "text"]:
            extract_mode = "markdown"

        # Validate max_chars
        if not isinstance(max_chars, int) or max_chars < 100:
            max_chars = 50000

        try:
            # Fetch the URL
            response = self._fetch_url(url)

            # Extract content
            result = self._extract_content(
                html=response.text,
                url=response.url,
                status_code=response.status_code,
                content_type=response.headers.get("content-type", ""),
                extract_mode=extract_mode,
                max_chars=max_chars
            )

            return ToolResult.success(result)

        except requests.exceptions.Timeout:
            return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
        except requests.exceptions.TooManyRedirects:
            return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
        except requests.exceptions.RequestException as e:
            return ToolResult.fail(f"Error fetching URL: {str(e)}")
        except Exception as e:
            logger.error(f"Web fetch error: {e}", exc_info=True)
            return ToolResult.fail(f"Error: {str(e)}")

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format"""
        try:
            result = urlparse(url)
            return result.scheme in ["http", "https"] and bool(result.netloc)
        except Exception:
            return False

    def _fetch_url(self, url: str) -> requests.Response:
        """
        Fetch URL with proper headers and error handling

        :param url: URL to fetch
        :return: Response object
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }

        # Note: requests handles redirects automatically; the redirect limit
        # is enforced via session.max_redirects (set in _create_session)
        response = self.session.get(
            url,
            headers=headers,
            timeout=self.timeout,
            allow_redirects=True
        )

        response.raise_for_status()
        return response

    def _extract_content(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """
        Extract readable content from HTML

        :param html: HTML content
        :param url: Original URL
        :param status_code: HTTP status code
        :param content_type: Content type header
        :param extract_mode: 'markdown' or 'text'
        :param max_chars: Maximum characters to return
        :return: Extracted content and metadata
        """
        # Check content type
        if "text/html" not in content_type.lower():
            # Non-HTML content
            text = html[:max_chars]
            truncated = len(html) > max_chars

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "extractor": "raw",
                "text": text,
                "length": len(text),
                "truncated": truncated,
                "message": f"Non-HTML content (type: {content_type})"
            }

        # Extract readable content from HTML
        if self.readability_available:
            return self._extract_with_readability(
                html, url, status_code, content_type, extract_mode, max_chars
            )
        else:
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_with_readability(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Extract content using Mozilla's Readability"""
        try:
            from readability import Document

            # Parse with Readability
            doc = Document(html)
            title = doc.title()
            content_html = doc.summary()

            # Convert to markdown or text
            if extract_mode == "markdown":
                text = self._html_to_markdown(content_html)
            else:
                text = self._html_to_text(content_html)

            # Truncate if needed
            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "title": title,
                "extractor": "readability",
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated
            }

        except Exception as e:
            logger.warning(f"Readability extraction failed: {e}")
            # Fallback to basic extraction
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_basic(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Basic HTML extraction without Readability"""
        # Extract title
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        title = title_match.group(1).strip() if title_match else "Untitled"

        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Clean up whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        # Truncate if needed
        truncated = len(text) > max_chars
        if truncated:
            text = text[:max_chars]

        return {
            "url": url,
            "status": status_code,
            "content_type": content_type,
            "title": title,
            "extractor": "basic",
            "extract_mode": extract_mode,
            "text": text,
            "length": len(text),
            "truncated": truncated,
            "warning": "Using basic extraction. Install readability-lxml for better results."
        }

    def _html_to_markdown(self, html: str) -> str:
        """Convert HTML to Markdown (basic implementation)"""
        try:
            # Try to use html2text if available
            import html2text
            h = html2text.HTML2Text()
            h.ignore_links = False
            h.ignore_images = False
            h.body_width = 0  # Don't wrap lines
            return h.handle(html)
        except ImportError:
            # Fallback to basic conversion
            return self._html_to_text(html)

    def _html_to_text(self, html: str) -> str:
        """Convert HTML to plain text"""
        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Convert common tags to text equivalents
        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
        text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)

        # Remove all other HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Decode HTML entities
        import html
        text = html.unescape(text)

        # Clean up whitespace
        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        text = text.strip()

        return text

    def close(self):
        """Close the session"""
        if hasattr(self, 'session'):
            self.session.close()
@@ -14,7 +14,7 @@ class Write(BaseTool):
    """Tool for writing file content"""

    name: str = "write"
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."

    params: dict = {
        "type": "object",

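The updated `write` description above codifies a skeleton-then-edit workflow for large files. A hypothetical sketch of an agent honoring it (the `edit` tool's parameter names below are illustrative assumptions, not taken from this diff):

```python
# Hypothetical tool calls honoring the 10KB-per-write guidance: write a
# small skeleton first, then grow each section with separate edits.
write_tool.execute({
    "path": "report.md",
    "content": "# Report\n\n## Intro\n\n## Findings\n\n## Appendix\n",
})
edit_tool.execute({
    "path": "report.md",
    "old": "## Findings\n",                                   # assumed parameter name
    "new": "## Findings\nFirst ~10KB chunk of content...\n",  # assumed parameter name
})
```
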
@@ -59,6 +59,23 @@ def run():
        os.environ["WECHATY_LOG"] = "warn"

    start_channel(channel_name)

    # Log a startup-success banner
    logger.info("")
    logger.info("=" * 50)
    if conf().get("agent", False):
        logger.info("✅ System started successfully!")
        logger.info("🐮 Cow Agent is running")
        logger.info(f"   Channel: {channel_name}")
        logger.info(f"   Model: {conf().get('model', 'unknown')}")
        logger.info(f"   Workspace: {conf().get('agent_workspace', '~/cow')}")
    else:
        logger.info("✅ System started successfully!")
        logger.info("🤖 ChatBot is running")
        logger.info(f"   Channel: {channel_name}")
        logger.info(f"   Model: {conf().get('model', 'unknown')}")
    logger.info("=" * 50)
    logger.info("")

    while True:
        time.sleep(1)

@@ -2,10 +2,11 @@
Agent Bridge - Integrates Agent system with existing COW bridge
"""

import os
from typing import Optional, List

from agent.protocol import Agent, LLMModel, LLMRequest
from bot.openai_compatible_bot import OpenAICompatibleBot
from models.openai_compatible_bot import OpenAICompatibleBot
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
@@ -180,6 +181,7 @@ class AgentBridge:
        self.agents = {}  # session_id -> Agent instance mapping
        self.default_agent = None  # For backward compatibility (no session_id)
        self.agent: Optional[Agent] = None
        self.scheduler_initialized = False

    def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
        """
        Create the super agent with COW integration
@@ -228,12 +230,7 @@ class AgentBridge:

        # Log skill loading details
        if agent.skill_manager:
            logger.info(f"[AgentBridge] SkillManager initialized:")
            logger.info(f"[AgentBridge]   - Managed dir: {agent.skill_manager.managed_skills_dir}")
            logger.info(f"[AgentBridge]   - Workspace dir: {agent.skill_manager.workspace_dir}")
            logger.info(f"[AgentBridge]   - Total skills: {len(agent.skill_manager.skills)}")
            for skill_name in agent.skill_manager.skills.keys():
                logger.info(f"[AgentBridge]     * {skill_name}")
            logger.debug(f"[AgentBridge] SkillManager initialized with {len(agent.skill_manager.skills)} skills")

        return agent

@@ -268,6 +265,21 @@ class AgentBridge:
        # Get workspace from config
        workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))

        # Migrate API keys from config.json to environment variables (if not already set)
        self._migrate_config_to_env(workspace_root)

        # Load environment variables from secure .env file location
        env_file = os.path.expanduser("~/.cow/.env")
        if os.path.exists(env_file):
            try:
                from dotenv import load_dotenv
                load_dotenv(env_file, override=True)
                logger.info(f"[AgentBridge] Loaded environment variables from {env_file}")
            except ImportError:
                logger.warning("[AgentBridge] python-dotenv not installed, skipping .env file loading")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to load .env file: {e}")

        # Initialize workspace and create template files
        from agent.prompt import ensure_workspace, load_context_files, PromptBuilder

@@ -283,14 +295,53 @@ class AgentBridge:
            from agent.memory import MemoryManager, MemoryConfig
            from agent.tools import MemorySearchTool, MemoryGetTool

            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",  # Use local embedding (no API key needed)
                embedding_model="all-MiniLM-L6-v2"
            )
            # Read the OpenAI settings from config.json
            openai_api_key = conf().get("open_ai_api_key", "")
            openai_api_base = conf().get("open_ai_api_base", "")

            # Create memory manager with the config
            memory_manager = MemoryManager(memory_config)
            # Try to initialize the OpenAI embedding provider
            embedding_provider = None
            if openai_api_key:
                try:
                    from agent.memory import create_embedding_provider
                    embedding_provider = create_embedding_provider(
                        provider="openai",
                        model="text-embedding-3-small",
                        api_key=openai_api_key,
                        api_base=openai_api_base or "https://api.openai.com/v1"
                    )
                    logger.info(f"[AgentBridge] OpenAI embedding initialized")
                except Exception as embed_error:
                    logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}")
                    logger.info(f"[AgentBridge] Using keyword-only search")
            else:
                logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search")

            # Create the memory config
            memory_config = MemoryConfig(workspace_root=workspace_root)

            # Create the memory manager
            memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)

            # Run one sync at startup so the database has data
            import asyncio
            try:
                # Try to run inside the current event loop
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    # The loop is already running: schedule a task
                    asyncio.create_task(memory_manager.sync())
                    logger.info("[AgentBridge] Memory sync scheduled")
                else:
                    # No running loop: execute directly
                    loop.run_until_complete(memory_manager.sync())
                    logger.info("[AgentBridge] Memory synced successfully")
            except RuntimeError:
                # No event loop at all: create a new one
                asyncio.run(memory_manager.sync())
                logger.info("[AgentBridge] Memory synced successfully")
            except Exception as e:
                logger.warning(f"[AgentBridge] Memory sync failed: {e}")

            # Create memory tools
            memory_tools = [
@@ -318,7 +369,15 @@ class AgentBridge:

        for tool_name in tool_manager.tool_classes.keys():
            try:
                tool = tool_manager.create_tool(tool_name)
                # Special handling for EnvConfig tool - pass agent_bridge reference
                if tool_name == "env_config":
                    from agent.tools import EnvConfig
                    tool = EnvConfig({
                        "agent_bridge": self  # Pass self reference for hot reload
                    })
                else:
                    tool = tool_manager.create_tool(tool_name)

                if tool:
                    # Apply workspace config to file operation tools
                    if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
@@ -326,12 +385,6 @@ class AgentBridge:
                        tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
                        if 'memory_manager' in file_config:
                            tool.memory_manager = file_config['memory_manager']
                    # Apply API key for bocha_search tool
                    elif tool_name == 'bocha_search':
                        bocha_api_key = conf().get("bocha_api_key", "")
                        if bocha_api_key:
                            tool.config = {"bocha_api_key": bocha_api_key}
                            tool.api_key = bocha_api_key
                    tools.append(tool)
                    logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
            except Exception as e:
@@ -342,6 +395,36 @@ class AgentBridge:
            tools.extend(memory_tools)
            logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")

        # Initialize scheduler service (once)
        if not self.scheduler_initialized:
            try:
                from agent.tools.scheduler.integration import init_scheduler
                if init_scheduler(self):
                    self.scheduler_initialized = True
                    logger.info("[AgentBridge] Scheduler service initialized")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to initialize scheduler: {e}")

        # Inject scheduler dependencies into SchedulerTool instances
        if self.scheduler_initialized:
            try:
                from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
                from agent.tools import SchedulerTool

                task_store = get_task_store()
                scheduler_service = get_scheduler_service()

                for tool in tools:
                    if isinstance(tool, SchedulerTool):
                        tool.task_store = task_store
                        tool.scheduler_service = scheduler_service
                        if not tool.config:
                            tool.config = {}
                        tool.config["channel_type"] = conf().get("channel_type", "unknown")
                        logger.debug("[AgentBridge] Injected scheduler dependencies into SchedulerTool")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies: {e}")

        logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")

        # Load context files (SOUL.md, USER.md, etc.)
@@ -381,14 +464,19 @@ class AgentBridge:

        logger.info("[AgentBridge] System prompt built successfully")

        # Get cost control parameters from config
        max_steps = conf().get("agent_max_steps", 20)
        max_context_tokens = conf().get("agent_max_context_tokens", 50000)

        # Create agent with configured tools and workspace
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            max_steps=max_steps,
            output_mode="logger",
            workspace_dir=workspace_root,  # Pass workspace to agent for skills loading
            enable_skills=True  # Enable skills auto-loading
            enable_skills=True,  # Enable skills auto-loading
            max_context_tokens=max_context_tokens
        )

        # Attach memory manager to agent if available
@@ -410,6 +498,24 @@ class AgentBridge:
        # Get workspace from config
        workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))

        # Migrate API keys from config.json to environment variables (if not already set)
        self._migrate_config_to_env(workspace_root)

        # Load environment variables from secure .env file location
        env_file = os.path.expanduser("~/.cow/.env")
        if os.path.exists(env_file):
            try:
                from dotenv import load_dotenv
                load_dotenv(env_file, override=True)
                logger.debug(f"[AgentBridge] Loaded environment variables from {env_file} for session {session_id}")
            except ImportError:
                logger.warning(f"[AgentBridge] python-dotenv not installed, skipping .env file loading for session {session_id}")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to load .env file for session {session_id}: {e}")

        # Migrate API keys from config.json to environment variables (if not already set)
        self._migrate_config_to_env(workspace_root)

        # Initialize workspace
        from agent.prompt import ensure_workspace, load_context_files, PromptBuilder

@@ -420,23 +526,65 @@ class AgentBridge:
        memory_tools = []

        try:
            from agent.memory import MemoryManager, MemoryConfig
            from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
            from agent.tools import MemorySearchTool, MemoryGetTool

            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",
                embedding_model="all-MiniLM-L6-v2"
            )
            # Read the OpenAI settings from config.json
            openai_api_key = conf().get("open_ai_api_key", "")
            openai_api_base = conf().get("open_ai_api_base", "")

            # Try to initialize the OpenAI embedding provider
            embedding_provider = None
            if openai_api_key:
                try:
                    embedding_provider = create_embedding_provider(
                        provider="openai",
                        model="text-embedding-3-small",
                        api_key=openai_api_key,
                        api_base=openai_api_base or "https://api.openai.com/v1"
                    )
                    logger.debug(f"[AgentBridge] OpenAI embedding initialized for session {session_id}")
                except Exception as embed_error:
                    logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}")
                    logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}")
            else:
                logger.debug(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}")

            # Create the memory config
            memory_config = MemoryConfig(workspace_root=workspace_root)

            # Create the memory manager
            memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)

            # Run one sync at startup so the database has data
            import asyncio
            try:
                # Try to run inside the current event loop
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    # The loop is already running: schedule a task
                    asyncio.create_task(memory_manager.sync())
                    logger.debug(f"[AgentBridge] Memory sync scheduled for session {session_id}")
                else:
                    # No running loop: execute directly
                    loop.run_until_complete(memory_manager.sync())
                    logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
            except RuntimeError:
                # No event loop at all: create a new one
                asyncio.run(memory_manager.sync())
                logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
            except Exception as sync_error:
                logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}")

            memory_manager = MemoryManager(memory_config)
            memory_tools = [
                MemorySearchTool(memory_manager),
                MemoryGetTool(memory_manager)
            ]

        except Exception as e:
            logger.debug(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
            logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
            import traceback
            logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}")

        # Load tools
        from agent.tools import ToolManager
@@ -458,17 +606,42 @@ class AgentBridge:
                        tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
                        if 'memory_manager' in file_config:
                            tool.memory_manager = file_config['memory_manager']
                    elif tool_name == 'bocha_search':
                        bocha_api_key = conf().get("bocha_api_key", "")
                        if bocha_api_key:
                            tool.config = {"bocha_api_key": bocha_api_key}
                            tool.api_key = bocha_api_key
                    tools.append(tool)
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to load tool {tool_name} for session {session_id}: {e}")

        if memory_tools:
            tools.extend(memory_tools)

        # Initialize scheduler service (once, if not already initialized)
        if not self.scheduler_initialized:
            try:
                from agent.tools.scheduler.integration import init_scheduler
                if init_scheduler(self):
                    self.scheduler_initialized = True
                    logger.debug(f"[AgentBridge] Scheduler service initialized for session {session_id}")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to initialize scheduler for session {session_id}: {e}")

        # Inject scheduler dependencies into SchedulerTool instances
        if self.scheduler_initialized:
            try:
                from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
                from agent.tools import SchedulerTool

                task_store = get_task_store()
                scheduler_service = get_scheduler_service()

                for tool in tools:
                    if isinstance(tool, SchedulerTool):
                        tool.task_store = task_store
                        tool.scheduler_service = scheduler_service
                        if not tool.config:
                            tool.config = {}
                        tool.config["channel_type"] = conf().get("channel_type", "unknown")
                        logger.debug(f"[AgentBridge] Injected scheduler dependencies for session {session_id}")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies for session {session_id}: {e}")

        # Load context files
        context_files = load_context_files(workspace_root)
@@ -478,7 +651,7 @@ class AgentBridge:
        try:
            from agent.skills import SkillManager
            skill_manager = SkillManager(workspace_dir=workspace_root)
            logger.info(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}")
            logger.debug(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}")
        except Exception as e:
            logger.warning(f"[AgentBridge] Failed to initialize SkillManager for session {session_id}: {e}")

@@ -543,15 +716,20 @@ class AgentBridge:
        if is_first:
            mark_conversation_started(workspace_root)

        # Get cost control parameters from config
        max_steps = conf().get("agent_max_steps", 20)
        max_context_tokens = conf().get("agent_max_context_tokens", 50000)

        # Create agent for this session
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            max_steps=max_steps,
            output_mode="logger",
            workspace_dir=workspace_root,
            skill_manager=skill_manager,
            enable_skills=True
            enable_skills=True,
            max_context_tokens=max_context_tokens
        )

        if memory_manager:
@@ -586,12 +764,52 @@ class AgentBridge:
            if not agent:
                return Reply(ReplyType.ERROR, "Failed to initialize super agent")

            # Use agent's run_stream method
            response = agent.run_stream(
                user_message=query,
                on_event=on_event,
                clear_history=clear_history
            )
            # Filter tools based on context
            original_tools = agent.tools
            filtered_tools = original_tools

            # If this is a scheduled task execution, exclude scheduler tool to prevent recursion
            if context and context.get("is_scheduled_task"):
                filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
                agent.tools = filtered_tools
                logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
            else:
                # Attach context to scheduler tool if present
                if context and agent.tools:
                    for tool in agent.tools:
                        if tool.name == "scheduler":
                            try:
                                from agent.tools.scheduler.integration import attach_scheduler_to_tool
                                attach_scheduler_to_tool(tool, context)
                            except Exception as e:
                                logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
                            break

            try:
                # Use agent's run_stream method
                response = agent.run_stream(
                    user_message=query,
                    on_event=on_event,
                    clear_history=clear_history
                )
            finally:
                # Restore original tools
                if context and context.get("is_scheduled_task"):
                    agent.tools = original_tools

            # Check if there are files to send (from read tool)
            if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
                files_to_send = agent.stream_executor.files_to_send
                if files_to_send:
                    # Send the first file (for now, handle one file at a time)
                    file_info = files_to_send[0]
                    logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")

                    # Clear files_to_send for next request
                    agent.stream_executor.files_to_send = []

                    # Return file reply based on file type
                    return self._create_file_reply(file_info, response, context)

            return Reply(ReplyType.TEXT, response)

@@ -599,6 +817,120 @@ class AgentBridge:
            logger.error(f"Agent reply error: {e}")
            return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")

    def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
        """
        Create a reply for sending files

        Args:
            file_info: File metadata from read tool
            text_response: Text response from agent
            context: Context object

        Returns:
            Reply object for file sending
        """
        file_type = file_info.get("file_type", "file")
        file_path = file_info.get("path")

        # For images, use IMAGE_URL type (channel will handle upload)
        if file_type == "image":
            # Convert local path to file:// URL for channel processing
            file_url = f"file://{file_path}"
            logger.info(f"[AgentBridge] Sending image: {file_url}")
            reply = Reply(ReplyType.IMAGE_URL, file_url)
            # Attach text message if present (for channels that support text+image)
            if text_response:
                reply.text_content = text_response  # Store accompanying text
            return reply

        # For documents (PDF, Excel, Word, PPT), use FILE type
        if file_type == "document":
            file_url = f"file://{file_path}"
            logger.info(f"[AgentBridge] Sending document: {file_url}")
            reply = Reply(ReplyType.FILE, file_url)
            reply.file_name = file_info.get("file_name", os.path.basename(file_path))
            return reply

        # For other files (video, audio), we need channel-specific handling
        # For now, return text with file info
        # TODO: Implement video/audio sending when channel supports it
        message = text_response or file_info.get("message", "文件已准备")
        message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
        return Reply(ReplyType.TEXT, message)

    def _migrate_config_to_env(self, workspace_root: str):
        """
        Migrate API keys from config.json to .env file if not already set

        Args:
            workspace_root: Workspace directory path (not used, kept for compatibility)
        """
        from config import conf
        import os

        # Mapping from config.json keys to environment variable names
        key_mapping = {
            "open_ai_api_key": "OPENAI_API_KEY",
            "open_ai_api_base": "OPENAI_API_BASE",
            "gemini_api_key": "GEMINI_API_KEY",
            "claude_api_key": "CLAUDE_API_KEY",
            "linkai_api_key": "LINKAI_API_KEY",
        }

        # Use fixed secure location for .env file
        env_file = os.path.expanduser("~/.cow/.env")

        # Read existing env vars from .env file
        existing_env_vars = {}
        if os.path.exists(env_file):
            try:
                with open(env_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#') and '=' in line:
                            key, _ = line.split('=', 1)
                            existing_env_vars[key.strip()] = True
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to read .env file: {e}")

        # Check which keys need to be migrated
        keys_to_migrate = {}
        for config_key, env_key in key_mapping.items():
            # Skip if already in .env file
            if env_key in existing_env_vars:
                continue

            # Get value from config.json
            value = conf().get(config_key, "")
            if value and value.strip():  # Only migrate non-empty values
                keys_to_migrate[env_key] = value.strip()

        # Log summary if there are keys to skip
        if existing_env_vars:
            logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")

        # Write new keys to .env file
        if keys_to_migrate:
            try:
                # Ensure ~/.cow directory and .env file exist
                env_dir = os.path.dirname(env_file)
                if not os.path.exists(env_dir):
                    os.makedirs(env_dir, exist_ok=True)
                if not os.path.exists(env_file):
                    open(env_file, 'a').close()

                # Append new keys
                with open(env_file, 'a', encoding='utf-8') as f:
                    f.write('\n# Auto-migrated from config.json\n')
                    for key, value in keys_to_migrate.items():
                        f.write(f'{key}={value}\n')
                        # Also set in current process
                        os.environ[key] = value

                logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
            except Exception as e:
                logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")

    def clear_session(self, session_id: str):
        """
        Clear a specific session's agent and conversation history
@@ -614,4 +946,43 @@ class AgentBridge:
        """Clear all agent sessions"""
        logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
        self.agents.clear()
        self.default_agent = None

    def refresh_all_skills(self) -> int:
        """
        Refresh skills in all agent instances after environment variable changes.
        This allows hot-reload of skills without restarting the agent.

        Returns:
            Number of agent instances refreshed
        """
        import os
        from dotenv import load_dotenv
        from config import conf

        # Reload environment variables from .env file
        workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
        env_file = os.path.join(workspace_root, '.env')

        if os.path.exists(env_file):
            load_dotenv(env_file, override=True)
            logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")

        refreshed_count = 0

        # Refresh default agent
        if self.default_agent and hasattr(self.default_agent, 'skill_manager'):
            self.default_agent.skill_manager.refresh_skills()
            refreshed_count += 1
            logger.info("[AgentBridge] Refreshed skills in default agent")

        # Refresh all session agents
        for session_id, agent in self.agents.items():
            if hasattr(agent, 'skill_manager'):
                agent.skill_manager.refresh_skills()
                refreshed_count += 1

        if refreshed_count > 0:
            logger.info(f"[AgentBridge] Refreshed skills in {refreshed_count} agent instance(s)")

        return refreshed_count
@@ -1,4 +1,4 @@
from bot.bot_factory import create_bot
from models.bot_factory import create_bot
from bridge.context import Context
from bridge.reply import Reply
from common import const

@@ -64,15 +64,22 @@ class ChatChannel(Channel):
                    check_contain(group_name, group_name_keyword_white_list),
                ]
            ):
                group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
                session_id = cmsg.actual_user_id
                if any(
                    [
                        group_name in group_chat_in_one_session,
                        "ALL_GROUP" in group_chat_in_one_session,
                    ]
                ):
                # Check global group_shared_session config first
                group_shared_session = conf().get("group_shared_session", True)
                if group_shared_session:
                    # All users in the group share the same session
                    session_id = group_id
                else:
                    # Check group-specific whitelist (legacy behavior)
                    group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
                    session_id = cmsg.actual_user_id
                    if any(
                        [
                            group_name in group_chat_in_one_session,
                            "ALL_GROUP" in group_chat_in_one_session,
                        ]
                    ):
                        session_id = group_id
            else:
                logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
                return None
@@ -166,11 +173,11 @@ class ChatChannel(Channel):
    def _handle(self, context: Context):
        if context is None or not context.content:
            return
        logger.debug("[chat_channel] ready to handle context: {}".format(context))
        logger.debug("[chat_channel] handling context: {}".format(context))
        # Build the reply
        reply = self._generate_reply(context)

        logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
        logger.debug("[chat_channel] decorating reply: {}".format(reply))

        # Decorate the reply
        if reply and reply.content:
@@ -188,7 +195,7 @@ class ChatChannel(Channel):
            )
            reply = e_context["reply"]
            if not e_context.is_pass():
                logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
                logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
                if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # text and image messages
                    context["channel"] = e_context["channel"]
                    reply = super().build_reply_content(context.content, context)
@@ -282,7 +289,100 @@ class ChatChannel(Channel):
            )
            reply = e_context["reply"]
            if not e_context.is_pass() and reply and reply.type:
                logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
                logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))

                # For text replies, try to extract and send any embedded images
                if reply.type == ReplyType.TEXT:
                    self._extract_and_send_images(reply, context)
                # For image replies that carry text, send the text first, then the image
                elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
                    # Send the text first
                    text_reply = Reply(ReplyType.TEXT, reply.text_content)
                    self._send(text_reply, context)
                    # Send the image after a short delay
                    time.sleep(0.3)
                    self._send(reply, context)
                else:
                    self._send(reply, context)

    def _extract_and_send_images(self, reply: Reply, context: Context):
        """
        Extract image/video URLs from a text reply and send them separately.
        Supported formats: [图片: /path/to/image.png], [视频: /path/to/video.mp4], , <img src="url">
        Sends at most 5 media files.
        """
        content = reply.content
        media_items = []  # [(url, type), ...]

        # Regex-extract media URLs in the various supported formats
        patterns = [
            (r'\[图片:\s*([^\]]+)\]', 'image'),  # [图片: /path/to/image.png]
            (r'\[视频:\s*([^\]]+)\]', 'video'),  # [视频: /path/to/video.mp4]
            (r'!\[.*?\]\(([^\)]+)\)', 'image'),  #  - defaults to image
            (r'<img[^>]+src=["\']([^"\']+)["\']', 'image'),  # <img src="url">
            (r'<video[^>]+src=["\']([^"\']+)["\']', 'video'),  # <video src="url">
            (r'https?://[^\s]+\.(?:jpg|jpeg|png|gif|webp)', 'image'),  # bare image URLs
            (r'https?://[^\s]+\.(?:mp4|avi|mov|wmv|flv)', 'video'),  # bare video URLs
        ]

        for pattern, media_type in patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                media_items.append((match, media_type))

        # De-duplicate (preserving order) and cap at 5 items
        seen = set()
        unique_items = []
        for url, mtype in media_items:
            if url not in seen:
                seen.add(url)
                unique_items.append((url, mtype))
        media_items = unique_items[:5]

        if media_items:
            logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")

            # Send the text first (leaving the original text untouched)
            logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
            self._send(reply, context)
            logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")

            # Then send the media files one by one
            for i, (url, media_type) in enumerate(media_items):
                try:
                    # Local file or remote URL?
                    if url.startswith(('http://', 'https://')):
                        # Remote resource
                        if media_type == 'video':
                            # Videos are sent as FILE replies
                            media_reply = Reply(ReplyType.FILE, url)
                            media_reply.file_name = os.path.basename(url)
                        else:
                            # Images use IMAGE_URL
                            media_reply = Reply(ReplyType.IMAGE_URL, url)
                    elif os.path.exists(url):
                        # Local file
                        if media_type == 'video':
                            # Videos: FILE reply, converted to a file:// URL
                            media_reply = Reply(ReplyType.FILE, f"file://{url}")
                            media_reply.file_name = os.path.basename(url)
                        else:
                            # Images: IMAGE_URL reply, converted to a file:// URL
                            media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
                    else:
                        logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
                        continue

                    # Send the media file (small delay to avoid rate limits)
                    if i > 0:
                        time.sleep(0.5)
                    self._send(media_reply, context)
                    logger.info(f"[chat_channel] Sent {media_type} {i+1}/{len(media_items)}: {url[:50]}...")

                except Exception as e:
                    logger.error(f"[chat_channel] Failed to send {media_type} {url}: {e}")
        else:
            # No media found: send the text normally
            self._send(reply, context)

    def _send(self, reply: Reply, context: Context, retry_cnt=0):

@@ -8,7 +8,9 @@ import copy
import json
# -*- coding=utf-8 -*-
import logging
import os
import time
import requests

import dingtalk_stream
from dingtalk_stream import AckMessage
@@ -101,22 +103,376 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
        self.logger = self.setup_logger()
        # Cache of recently seen message ids, used for idempotency control
        self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
        logger.info("[DingTalk] client_id={}, client_secret={} ".format(
        logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
            self.dingtalk_client_id, self.dingtalk_client_secret))
        # No group whitelist check or prefix required
        conf()["group_name_white_list"] = ["ALL_GROUP"]
        # No prefix required for direct chats
        conf()["single_chat_prefix"] = [""]
        # Access token cache
        self._access_token = None
        self._access_token_expires_at = 0
        # Robot code cache (extracted from incoming messages)
        self._robot_code = None

    def startup(self):
        credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
        client = dingtalk_stream.DingTalkStreamClient(credential)
        client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
        logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
        client.start_forever()

    def get_access_token(self):
        """
        Get the access_token for an internal enterprise app.
        Docs: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
        """
        current_time = time.time()

        # Return the cached token if it has not expired yet
        if self._access_token and current_time < self._access_token_expires_at:
            return self._access_token

        # Fetch a new access_token
        url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
        headers = {"Content-Type": "application/json"}
        data = {
            "appKey": self.dingtalk_client_id,
            "appSecret": self.dingtalk_client_secret
        }

        try:
            response = requests.post(url, headers=headers, json=data, timeout=10)
            result = response.json()

            if response.status_code == 200 and "accessToken" in result:
                self._access_token = result["accessToken"]
                # Tokens are valid for 2 hours; refresh 5 minutes early
                self._access_token_expires_at = current_time + result.get("expireIn", 7200) - 300
                logger.info("[DingTalk] Access token refreshed successfully")
                return self._access_token
            else:
                logger.error(f"[DingTalk] Failed to get access token: {result}")
                return None
        except Exception as e:
            logger.error(f"[DingTalk] Error getting access token: {e}")
            return None

    def send_single_message(self, user_id: str, content: str, robot_code: str) -> bool:
        """
        Send message to single user (private chat)
        API: https://open.dingtalk.com/document/orgapp/chatbots-send-one-on-one-chat-messages-in-batches
        """
        access_token = self.get_access_token()
        if not access_token:
            logger.error("[DingTalk] Failed to send single message: Access token not available.")
            return False

        if not robot_code:
            logger.error("[DingTalk] Cannot send single message: robot_code is required")
            return False

        url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
        headers = {
            "x-acs-dingtalk-access-token": access_token,
            "Content-Type": "application/json"
        }
        data = {
            "msgParam": json.dumps({"content": content}),
            "msgKey": "sampleText",
            "userIds": [user_id],
            "robotCode": robot_code
        }

        logger.info(f"[DingTalk] Sending single message to user {user_id} with robot_code {robot_code}")
        try:
            response = requests.post(url, headers=headers, json=data, timeout=10)
            result = response.json()

            if response.status_code == 200 and result.get("processQueryKey"):
                logger.info(f"[DingTalk] Single message sent successfully to {user_id}")
                return True
            else:
                logger.error(f"[DingTalk] Failed to send single message: {result}")
                return False
        except Exception as e:
            logger.error(f"[DingTalk] Error sending single message: {e}")
            return False

    def send_group_message(self, conversation_id: str, content: str, robot_code: str = None):
        """
        Proactively send a group message.
        Docs: https://open.dingtalk.com/document/orgapp/the-robot-sends-a-group-message

        Args:
            conversation_id: conversation id (openConversationId)
            content: message content
            robot_code: robot code; defaults to dingtalk_client_id
        """
        access_token = self.get_access_token()
        if not access_token:
            logger.error("[DingTalk] Cannot send group message: no access token")
            return False

        # Validate robot_code
        if not robot_code:
            logger.error("[DingTalk] Cannot send group message: robot_code is required")
            return False

        url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
        headers = {
            "x-acs-dingtalk-access-token": access_token,
            "Content-Type": "application/json"
        }
        data = {
            "msgParam": json.dumps({"content": content}),
            "msgKey": "sampleText",
            "openConversationId": conversation_id,
            "robotCode": robot_code
        }

        try:
            response = requests.post(url, headers=headers, json=data, timeout=10)
            result = response.json()

            if response.status_code == 200:
                logger.info(f"[DingTalk] Group message sent successfully to {conversation_id}")
                return True
            else:
                logger.error(f"[DingTalk] Failed to send group message: {result}")
                return False
        except Exception as e:
            logger.error(f"[DingTalk] Error sending group message: {e}")
            return False

def upload_media(self, file_path: str, media_type: str = "image") -> str:
|
||||
"""
|
||||
上传媒体文件到钉钉
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径或URL
|
||||
media_type: 媒体类型 (image, video, voice, file)
|
||||
|
||||
Returns:
|
||||
media_id,如果上传失败返回 None
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot upload media: no access token")
|
||||
return None
|
||||
|
||||
# 处理 file:// URL
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
# 如果是 HTTP URL,先下载
|
||||
if file_path.startswith("http://") or file_path.startswith("https://"):
|
||||
try:
|
||||
import uuid
|
||||
response = requests.get(file_path, timeout=(5, 60))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[DingTalk] Failed to download file from URL: {file_path}")
|
||||
return None
|
||||
|
||||
# 保存到临时文件
|
||||
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
temp_file = os.path.join(tmp_dir, file_name)
|
||||
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
file_path = temp_file
|
||||
logger.info(f"[DingTalk] Downloaded file to {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error downloading file: {e}")
|
||||
return None
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[DingTalk] File not found: {file_path}")
|
||||
return None
|
||||
|
||||
# 上传到钉钉
|
||||
# 钉钉上传媒体文件 API: https://open.dingtalk.com/document/orgapp/upload-media-files
|
||||
url = "https://oapi.dingtalk.com/media/upload"
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"type": media_type
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"media": (os.path.basename(file_path), f)}
|
||||
response = requests.post(url, params=params, files=files, timeout=(5, 60))
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode") == 0:
|
||||
media_id = result.get("media_id")
|
||||
logger.info(f"[DingTalk] Media uploaded successfully, media_id={media_id}")
|
||||
return media_id
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to upload media: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error uploading media: {e}")
|
||||
return None
|
||||

    def send_image_with_media_id(self, access_token: str, media_id: str, incoming_message, is_group: bool) -> bool:
        """
        Send an image message (using media_id)

        Args:
            access_token: access token
            media_id: media id
            incoming_message: DingTalk message object
            is_group: whether this is a group chat

        Returns:
            whether the send succeeded
        """
        headers = {
            "x-acs-dingtalk-access-token": access_token,
            'Content-Type': 'application/json'
        }

        msg_param = {
            "photoURL": media_id  # DingTalk image messages use the photoURL field
        }

        body = {
            "robotCode": incoming_message.robot_code,
            "msgKey": "sampleImageMsg",
            "msgParam": json.dumps(msg_param),
        }

        if is_group:
            # Group chat
            url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
            body["openConversationId"] = incoming_message.conversation_id
        else:
            # One-on-one chat
            url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
            body["userIds"] = [incoming_message.sender_staff_id]

        try:
            response = requests.post(url=url, headers=headers, json=body, timeout=10)
            result = response.json()

            logger.info(f"[DingTalk] Image send result: {response.text}")

            if response.status_code == 200:
                return True
            else:
                logger.error(f"[DingTalk] Send image error: {response.text}")
                return False
        except Exception as e:
            logger.error(f"[DingTalk] Send image exception: {e}")
            return False

    def send_image_message(self, receiver: str, media_id: str, is_group: bool, robot_code: str) -> bool:
        """
        Send an image message

        Args:
            receiver: receiver id (user_id or conversation_id)
            media_id: media id
            is_group: whether this is a group chat
            robot_code: robot code

        Returns:
            whether the send succeeded
        """
        access_token = self.get_access_token()
        if not access_token:
            logger.error("[DingTalk] Cannot send image: no access token")
            return False

        if not robot_code:
            logger.error("[DingTalk] Cannot send image: robot_code is required")
            return False

        if is_group:
            # Group-chat image
            url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
            headers = {
                "x-acs-dingtalk-access-token": access_token,
                "Content-Type": "application/json"
            }
            data = {
                "msgParam": json.dumps({"mediaId": media_id}),
                "msgKey": "sampleImageMsg",
                "openConversationId": receiver,
                "robotCode": robot_code
            }
        else:
            # One-on-one image
            url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
            headers = {
                "x-acs-dingtalk-access-token": access_token,
                "Content-Type": "application/json"
            }
            data = {
                "msgParam": json.dumps({"mediaId": media_id}),
                "msgKey": "sampleImageMsg",
                "userIds": [receiver],
                "robotCode": robot_code
            }

        try:
            response = requests.post(url, headers=headers, json=data, timeout=10)
            result = response.json()

            if response.status_code == 200:
                logger.info("[DingTalk] Image message sent successfully")
                return True
            else:
                logger.error(f"[DingTalk] Failed to send image message: {result}")
                return False
        except Exception as e:
            logger.error(f"[DingTalk] Error sending image message: {e}")
            return False
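
    # --- Illustrative sketch (not part of the diff): the upload-then-send flow
    # --- the helpers above implement. Path, conversation id and robot code are
    # --- placeholders.
    def _sketch_send_local_image(channel, conversation_id):
        media_id = channel.upload_media("file:///tmp/chart.png", media_type="image")
        if not media_id:
            return False
        # Group images go through groupMessages/send with msgKey "sampleImageMsg"
        return channel.send_image_message(
            receiver=conversation_id,   # openConversationId (placeholder)
            media_id=media_id,
            is_group=True,
            robot_code="ding_robot_x",  # placeholder
        )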

    def get_image_download_url(self, download_code: str) -> str:
        """
        Get the image download address
        Returns a special URL of the form dingtalk://download/{robot_code}:{download_code};
        download_image_file later resolves it via the new-style API
        """
        # Get robot_code
        if not hasattr(self, '_robot_code_cache'):
            self._robot_code_cache = None

        robot_code = self._robot_code_cache

        if not robot_code:
            logger.error("[DingTalk] robot_code not available for image download")
            return None

        # Return a special URL that carries both robot_code and download_code
        logger.info(f"[DingTalk] Successfully got image download URL for code: {download_code}")
        return f"dingtalk://download/{robot_code}:{download_code}"
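
    # --- Illustrative sketch (not part of the diff): how the pseudo-URL built
    # --- above is split apart again before calling the download API; assumes
    # --- robot_code itself contains no ':'.
    def _sketch_parse_download_url(url):
        assert url.startswith("dingtalk://download/")
        payload = url[len("dingtalk://download/"):]
        robot_code, download_code = payload.split(":", 1)
        return robot_code, download_code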

    async def process(self, callback: dingtalk_stream.CallbackMessage):
        try:
            incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)

            # Cache robot_code for later image downloads
            if hasattr(incoming_message, 'robot_code'):
                self._robot_code_cache = incoming_message.robot_code

            # Debug: dump the full event payload
            logger.debug("[DingTalk] ===== Incoming Message Debug =====")
            logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
            logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
            logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
            logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
            logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
            logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
            logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
            logger.debug("[DingTalk] =====================================")

            image_download_handler = self  # pass in the instance that owns the download helpers
            dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)

@@ -126,7 +482,8 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
            self.handle_single(dingtalk_msg)
            return AckMessage.STATUS_OK, 'OK'
        except Exception as e:
            logger.error(f"dingtalk process error={e}")
            logger.error(f"[DingTalk] process error: {e}")
            logger.exception(e)  # print the full stack trace
            return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'

    @time_checker
@@ -145,6 +502,43 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
            logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
        else:
            logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))

        # File-cache handling
        from channel.file_cache import get_file_cache
        file_cache = get_file_cache()

        # For one-on-one chats the session_id is simply the sender_id
        session_id = cmsg.from_user_id

        # A lone image message is cached instead of processed
        if cmsg.ctype == ContextType.IMAGE:
            if hasattr(cmsg, 'image_path') and cmsg.image_path:
                file_cache.add(session_id, cmsg.image_path, file_type='image')
                logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
                # Do not process a lone image directly; wait for the user's question
                return

        # For text messages, check whether cached files should be attached
        if cmsg.ctype == ContextType.TEXT:
            cached_files = file_cache.get(session_id)
            if cached_files:
                # Append the cached files to the text message
                file_refs = []
                for file_info in cached_files:
                    file_path = file_info['path']
                    file_type = file_info['type']
                    if file_type == 'image':
                        file_refs.append(f"[图片: {file_path}]")
                    elif file_type == 'video':
                        file_refs.append(f"[视频: {file_path}]")
                    else:
                        file_refs.append(f"[文件: {file_path}]")

                cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
                logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
                # Clear the cache
                file_cache.clear(session_id)

        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
        if context:
            self.produce(context)
@@ -166,6 +560,46 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
            logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
        else:
            logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))

        # File-cache handling
        from channel.file_cache import get_file_cache
        file_cache = get_file_cache()

        # session_id for group chats
        if conf().get("group_shared_session", True):
            session_id = cmsg.other_user_id  # conversation_id
        else:
            session_id = cmsg.from_user_id + "_" + cmsg.other_user_id

        # A lone image message is cached instead of processed
        if cmsg.ctype == ContextType.IMAGE:
            if hasattr(cmsg, 'image_path') and cmsg.image_path:
                file_cache.add(session_id, cmsg.image_path, file_type='image')
                logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
                # Do not process a lone image directly; wait for the user's question
                return

        # For text messages, check whether cached files should be attached
        if cmsg.ctype == ContextType.TEXT:
            cached_files = file_cache.get(session_id)
            if cached_files:
                # Append the cached files to the text message
                file_refs = []
                for file_info in cached_files:
                    file_path = file_info['path']
                    file_type = file_info['type']
                    if file_type == 'image':
                        file_refs.append(f"[图片: {file_path}]")
                    elif file_type == 'video':
                        file_refs.append(f"[视频: {file_path}]")
                    else:
                        file_refs.append(f"[文件: {file_path}]")

                cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
                logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
                # Clear the cache
                file_cache.clear(session_id)

        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
        context['no_need_at'] = True
        if context:
@@ -173,32 +607,228 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):


    def send(self, reply: Reply, context: Context):
        logger.info(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
        receiver = context["receiver"]
        isgroup = context.kwargs['msg'].is_group
        incoming_message = context.kwargs['msg'].incoming_message

        if conf().get("dingtalk_card_enabled"):
            logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
            def reply_with_text():
                self.reply_text(reply.content, incoming_message)
            def reply_with_at_text():
                self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
            def reply_with_ai_markdown():
                button_list, markdown_content = self.generate_button_markdown_content(context, reply)
                self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "", [incoming_message.sender_staff_id])

            if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
                if isgroup:
                    reply_with_ai_markdown()
                    reply_with_at_text()
                else:
                    reply_with_ai_markdown()

        # Check if msg exists (for scheduled tasks, msg might be None)
        msg = context.kwargs.get('msg')
        if msg is None:
            # Scheduled-task path: use the proactive send APIs
            is_group = context.get("isgroup", False)
            logger.info(f"[DingTalk] Sending scheduled task message to {receiver} (is_group={is_group})")

            # Use the cached robot_code or the configured value
            robot_code = self._robot_code or conf().get("dingtalk_robot_code")
            logger.info(f"[DingTalk] Using robot_code: {robot_code}, cached: {self._robot_code}, config: {conf().get('dingtalk_robot_code')}")

            if not robot_code:
                logger.error("[DingTalk] Cannot send scheduled task: robot_code not available. Please send at least one message to the bot first, or configure dingtalk_robot_code in config.json")
                return

            # Choose the API by chat type
            if is_group:
                success = self.send_group_message(receiver, reply.content, robot_code)
            else:
                # Other reply types are not supported yet
                reply_with_text()
        else:
            self.reply_text(reply.content, incoming_message)
            # One-on-one path: try to read dingtalk_sender_staff_id from the context
            sender_staff_id = context.get("dingtalk_sender_staff_id")
            if not sender_staff_id:
                logger.error("[DingTalk] Cannot send single chat scheduled message: sender_staff_id not available in context")
                return

            logger.info(f"[DingTalk] Sending single message to staff_id: {sender_staff_id}")
            success = self.send_single_message(sender_staff_id, reply.content, robot_code)

            if not success:
                logger.error("[DingTalk] Failed to send scheduled task message")
            return

        # Extract and cache robot_code from a normal message
        if hasattr(msg, 'robot_code'):
            robot_code = msg.robot_code
            if robot_code and robot_code != self._robot_code:
                self._robot_code = robot_code
                logger.info(f"[DingTalk] Cached robot_code: {robot_code}")

        isgroup = msg.is_group
        incoming_message = msg.incoming_message
        robot_code = self._robot_code or conf().get("dingtalk_robot_code")

        # Image and video sending
        if reply.type == ReplyType.IMAGE_URL:
            logger.info(f"[DingTalk] Sending image: {reply.content}")

            # If the reply carries extra text, send the text first
            if hasattr(reply, 'text_content') and reply.text_content:
                self.reply_text(reply.text_content, incoming_message)
                import time
                time.sleep(0.3)  # short delay so the text arrives first

            media_id = self.upload_media(reply.content, media_type="image")
            if media_id:
                # Send the image via the proactive API
                access_token = self.get_access_token()
                if access_token:
                    success = self.send_image_with_media_id(
                        access_token,
                        media_id,
                        incoming_message,
                        isgroup
                    )
                    if not success:
                        logger.error("[DingTalk] Failed to send image message")
                        self.reply_text("抱歉,图片发送失败", incoming_message)
                else:
                    logger.error("[DingTalk] Cannot get access token")
                    self.reply_text("抱歉,图片发送失败(无法获取token)", incoming_message)
            else:
                logger.error("[DingTalk] Failed to upload image")
                self.reply_text("抱歉,图片上传失败", incoming_message)
            return

        elif reply.type == ReplyType.FILE:
            # If the reply carries extra text, send the text first
            if hasattr(reply, 'text_content') and reply.text_content:
                self.reply_text(reply.text_content, incoming_message)
                import time
                time.sleep(0.3)  # short delay so the text arrives first

            # Decide whether this is a video file
            file_path = reply.content
            if file_path.startswith("file://"):
                file_path = file_path[7:]

            is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))

            access_token = self.get_access_token()
            if not access_token:
                logger.error("[DingTalk] Cannot get access token")
                self.reply_text("抱歉,文件发送失败(无法获取token)", incoming_message)
                return

            if is_video:
                logger.info(f"[DingTalk] Sending video: {reply.content}")
                media_id = self.upload_media(reply.content, media_type="video")
                if media_id:
                    # Send the video message
                    msg_param = {
                        "duration": "30",  # TODO: read the real video duration
                        "videoMediaId": media_id,
                        "videoType": "mp4",
                        "height": "400",
                        "width": "600",
                    }
                    success = self._send_file_message(
                        access_token,
                        incoming_message,
                        "sampleVideo",
                        msg_param,
                        isgroup
                    )
                    if not success:
                        self.reply_text("抱歉,视频发送失败", incoming_message)
                else:
                    logger.error("[DingTalk] Failed to upload video")
                    self.reply_text("抱歉,视频上传失败", incoming_message)
            else:
                # Other file types
                logger.info(f"[DingTalk] Sending file: {reply.content}")
                media_id = self.upload_media(reply.content, media_type="file")
                if media_id:
                    file_name = os.path.basename(file_path)
                    file_base, file_extension = os.path.splitext(file_name)
                    msg_param = {
                        "mediaId": media_id,
                        "fileName": file_name,
                        "fileType": file_extension[1:] if file_extension else "file"
                    }
                    success = self._send_file_message(
                        access_token,
                        incoming_message,
                        "sampleFile",
                        msg_param,
                        isgroup
                    )
                    if not success:
                        self.reply_text("抱歉,文件发送失败", incoming_message)
                else:
                    logger.error("[DingTalk] Failed to upload file")
                    self.reply_text("抱歉,文件上传失败", incoming_message)
            return

        # Text messages
        elif reply.type == ReplyType.TEXT:
            logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
            if conf().get("dingtalk_card_enabled"):
                logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
                def reply_with_text():
                    self.reply_text(reply.content, incoming_message)
                def reply_with_at_text():
                    self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
                def reply_with_ai_markdown():
                    button_list, markdown_content = self.generate_button_markdown_content(context, reply)
                    self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "", [incoming_message.sender_staff_id])

                if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
                    if isgroup:
                        reply_with_ai_markdown()
                        reply_with_at_text()
                    else:
                        reply_with_ai_markdown()
                else:
                    # Other reply types are not supported yet
                    reply_with_text()
            else:
                self.reply_text(reply.content, incoming_message)
            return

    def _send_file_message(self, access_token: str, incoming_message, msg_key: str, msg_param: dict, is_group: bool) -> bool:
        """
        Shared helper for sending file/video messages

        Args:
            access_token: access token
            incoming_message: DingTalk message object
            msg_key: message type (sampleFile, sampleVideo, sampleAudio)
            msg_param: message parameters
            is_group: whether this is a group chat

        Returns:
            whether the send succeeded
        """
        headers = {
            "x-acs-dingtalk-access-token": access_token,
            'Content-Type': 'application/json'
        }

        body = {
            "robotCode": incoming_message.robot_code,
            "msgKey": msg_key,
            "msgParam": json.dumps(msg_param),
        }

        if is_group:
            # Group chat
            url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
            body["openConversationId"] = incoming_message.conversation_id
        else:
            # One-on-one chat
            url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
            body["userIds"] = [incoming_message.sender_staff_id]

        try:
            response = requests.post(url=url, headers=headers, json=body, timeout=10)
            result = response.json()

            logger.info(f"[DingTalk] File send result: {response.text}")

            if response.status_code == 200:
                return True
            else:
                logger.error(f"[DingTalk] Send file error: {response.text}")
                return False
        except Exception as e:
            logger.error(f"[DingTalk] Send file exception: {e}")
            return False

    def generate_button_markdown_content(self, context, reply):
        image_url = context.kwargs.get("image_url")
@@ -1,4 +1,5 @@
import os
import re

import requests
from dingtalk_stream import ChatbotMessage
@@ -8,6 +9,7 @@ from channel.chat_message import ChatMessage
# -*- coding=utf-8 -*-
from common.log import logger
from common.tmp_dir import TmpDir
from config import conf


class DingTalkMessage(ChatMessage):
@@ -22,6 +24,7 @@ class DingTalkMessage(ChatMessage):
        self.create_time = event.create_at
        self.image_content = event.image_content
        self.rich_text_content = event.rich_text_content
        self.robot_code = event.robot_code  # robot code
        if event.conversation_type == "1":
            self.is_group = False
        else:
@@ -36,15 +39,67 @@ class DingTalkMessage(ChatMessage):
            self.content = event.extensions['content']['recognition'].strip()
            self.ctype = ContextType.TEXT
        elif (self.message_type == 'picture') or (self.message_type == 'richText'):
            self.ctype = ContextType.IMAGE
            # Handle DingTalk picture or rich-text messages
            image_list = event.get_image_list()
            if len(image_list) > 0:

            if self.message_type == 'picture' and len(image_list) > 0:
                # Single picture message: download into the workspace for the file cache
                self.ctype = ContextType.IMAGE
                download_code = image_list[0]
                download_url = image_download_handler.get_image_download_url(download_code)
                self.content = download_image_file(download_url, TmpDir().path())

                # Download into the workspace tmp directory
                workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
                tmp_dir = os.path.join(workspace_root, "tmp")
                os.makedirs(tmp_dir, exist_ok=True)

                image_path = download_image_file(download_url, tmp_dir)
                if image_path:
                    self.content = image_path
                    self.image_path = image_path  # keep the image path for caching
                    logger.info(f"[DingTalk] Downloaded single image to {image_path}")
                else:
                    self.content = "[图片下载失败]"
                    self.image_path = None

            elif self.message_type == 'richText' and len(image_list) > 0:
                # Rich-text message: download every image and append them to the text
                self.ctype = ContextType.TEXT

                # Download into the workspace tmp directory
                workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
                tmp_dir = os.path.join(workspace_root, "tmp")
                os.makedirs(tmp_dir, exist_ok=True)

                # Extract the text parts of the rich text
                text_content = ""
                if self.rich_text_content:
                    # rich_text_content is a RichTextContent object; pull the text out of it
                    text_list = event.get_text_list()
                    if text_list:
                        text_content = "".join(text_list).strip()

                # Download all images
                image_paths = []
                for download_code in image_list:
                    download_url = image_download_handler.get_image_download_url(download_code)
                    image_path = download_image_file(download_url, tmp_dir)
                    if image_path:
                        image_paths.append(image_path)

                # Build the message content: text + image paths
                content_parts = []
                if text_content:
                    content_parts.append(text_content)
                for img_path in image_paths:
                    content_parts.append(f"[图片: {img_path}]")

                self.content = "\n".join(content_parts) if content_parts else "[富文本消息]"
                logger.info(f"[DingTalk] Received richText with {len(image_paths)} image(s): {self.content}")
            else:
                logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty")
                self.ctype = ContextType.IMAGE
                self.content = "[未找到图片]"
                logger.debug(f"[DingTalk] messageType: {self.message_type}, imageList isEmpty")

        if self.is_group:
            self.from_user_id = event.conversation_id
@@ -58,27 +113,131 @@ class DingTalkMessage(ChatMessage):


def download_image_file(image_url, temp_dir):
    headers = {
        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
    }
    # Proxy settings
    # self.proxies
    # , proxies=self.proxies
    response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
    if response.status_code == 200:

        # Build the file name
        file_name = image_url.split("/")[-1].split("?")[0]

        # Create the temp directory if it does not exist
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir)

        # Save the file into the temp directory
        file_path = os.path.join(temp_dir, file_name)
        with open(file_path, 'wb') as file:
            file.write(response.content)
        return file_path
    """
    Download an image file
    Two forms are supported:
    1. plain HTTP(S) URLs
    2. DingTalk downloadCode: dingtalk://download/{download_code}
    """
    # Create the temp directory if it does not exist
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    # Handle a DingTalk downloadCode
    if image_url.startswith("dingtalk://download/"):
        download_code = image_url.replace("dingtalk://download/", "")
        logger.info(f"[DingTalk] Downloading image with downloadCode: {download_code[:20]}...")

        # The access_token should come from outside; as a temporary scheme,
        # read dingtalk_client_id and dingtalk_client_secret from config
        from config import conf
        client_id = conf().get("dingtalk_client_id")
        client_secret = conf().get("dingtalk_client_secret")

        if not client_id or not client_secret:
            logger.error("[DingTalk] Missing dingtalk_client_id or dingtalk_client_secret")
            return None

        # Split robot_code and download_code apart
        parts = download_code.split(":", 1)
        if len(parts) != 2:
            logger.error(f"[DingTalk] Invalid download_code format (expected robot_code:download_code): {download_code[:50]}")
            return None

        robot_code, actual_download_code = parts

        # Get an access_token (new-style API)
        token_url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
        token_headers = {
            "Content-Type": "application/json"
        }
        token_body = {
            "appKey": client_id,
            "appSecret": client_secret
        }

        try:
            token_response = requests.post(token_url, json=token_body, headers=token_headers, timeout=10)

            if token_response.status_code == 200:
                token_data = token_response.json()
                access_token = token_data.get("accessToken")

                if not access_token:
                    logger.error(f"[DingTalk] Failed to get access token: {token_data}")
                    return None

                # Get the download URL (new-style API)
                download_api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
                download_headers = {
                    "x-acs-dingtalk-access-token": access_token,
                    "Content-Type": "application/json"
                }
                download_body = {
                    "downloadCode": actual_download_code,
                    "robotCode": robot_code
                }

                download_response = requests.post(download_api_url, json=download_body, headers=download_headers, timeout=10)

                if download_response.status_code == 200:
                    download_data = download_response.json()
                    download_url = download_data.get("downloadUrl")

                    if not download_url:
                        logger.error(f"[DingTalk] No downloadUrl in response: {download_data}")
                        return None

                    # Download the actual image from downloadUrl
                    image_response = requests.get(download_url, stream=True, timeout=60)

                    if image_response.status_code == 200:
                        # Build the file name from a hash of the download_code to avoid special characters
                        import hashlib
                        file_hash = hashlib.md5(actual_download_code.encode()).hexdigest()[:16]
                        file_name = f"{file_hash}.png"
                        file_path = os.path.join(temp_dir, file_name)

                        with open(file_path, 'wb') as file:
                            file.write(image_response.content)

                        logger.info(f"[DingTalk] Image downloaded successfully: {file_path}")
                        return file_path
                    else:
                        logger.error(f"[DingTalk] Failed to download image from URL: {image_response.status_code}")
                        return None
                else:
                    logger.error(f"[DingTalk] Failed to get download URL: {download_response.status_code}, {download_response.text}")
                    return None
            else:
                logger.error(f"[DingTalk] Failed to get access token: {token_response.status_code}, {token_response.text}")
                return None
        except Exception as e:
            logger.error(f"[DingTalk] Exception downloading image: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

    # Plain HTTP(S) URL
    else:
        logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
        return None
        headers = {
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
        }

        try:
            response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
            if response.status_code == 200:
                # Build the file name
                file_name = image_url.split("/")[-1].split("?")[0]

                # Save the file into the temp directory
                file_path = os.path.join(temp_dir, file_name)
                with open(file_path, 'wb') as file:
                    file.write(response.content)
                return file_path
            else:
                logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
                return None
        except Exception as e:
            logger.error(f"[Dingtalk] Exception downloading image: {e}")
            return None
@@ -55,7 +55,7 @@ class FeiShuChanel(ChatChannel):
        super().__init__()
        # Stash of recently seen message ids, used for idempotency control
        self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
        logger.info("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
        logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
            self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
        # No group whitelist check or prefix needed
        conf()["group_name_white_list"] = ["ALL_GROUP"]
@@ -74,7 +74,7 @@ class FeiShuChanel(ChatChannel):

    def _startup_webhook(self):
        """Start the HTTP server that receives events (webhook mode)"""
        logger.info("[FeiShu] Starting in webhook mode...")
        logger.debug("[FeiShu] Starting in webhook mode...")
        urls = (
            '/', 'channel.feishu.feishu_channel.FeishuController'
        )
@@ -84,7 +84,7 @@ class FeiShuChanel(ChatChannel):

    def _startup_websocket(self):
        """Start the long-lived connection that receives events (websocket mode)"""
        logger.info("[FeiShu] Starting in websocket mode...")
        logger.debug("[FeiShu] Starting in websocket mode...")

        # Create the event handler
        def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
@@ -118,7 +118,7 @@ class FeiShuChanel(ChatChannel):
        # Start the client in a new thread to avoid blocking the main thread
        def start_client():
            try:
                logger.info("[FeiShu] Websocket client starting...")
                logger.debug("[FeiShu] Websocket client starting...")
                ws_client.start()
            except Exception as e:
                logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
@@ -127,7 +127,7 @@ class FeiShuChanel(ChatChannel):
        ws_thread.start()

        # Keep the main thread alive
        logger.info("[FeiShu] Websocket mode started, waiting for events...")
        logger.info("[FeiShu] ✅ Websocket connected, ready to receive messages")
        ws_thread.join()

    def _handle_message_event(self, event: dict):
@@ -173,6 +173,48 @@ class FeiShuChanel(ChatChannel):
        if not feishu_msg:
            return

        # File-cache handling
        from channel.file_cache import get_file_cache
        file_cache = get_file_cache()

        # Work out the session_id (used as the cache key)
        if is_group:
            if conf().get("group_shared_session", True):
                session_id = msg.get("chat_id")  # group-shared session
            else:
                session_id = feishu_msg.from_user_id + "_" + msg.get("chat_id")
        else:
            session_id = feishu_msg.from_user_id

        # A lone image message is cached instead of processed
        if feishu_msg.ctype == ContextType.IMAGE:
            if hasattr(feishu_msg, 'image_path') and feishu_msg.image_path:
                file_cache.add(session_id, feishu_msg.image_path, file_type='image')
                logger.info(f"[FeiShu] Image cached for session {session_id}, waiting for user query...")
                # Do not process a lone image directly; wait for the user's question
                return

        # For text messages, check whether cached files should be attached
        if feishu_msg.ctype == ContextType.TEXT:
            cached_files = file_cache.get(session_id)
            if cached_files:
                # Append the cached files to the text message
                file_refs = []
                for file_info in cached_files:
                    file_path = file_info['path']
                    file_type = file_info['type']
                    if file_type == 'image':
                        file_refs.append(f"[图片: {file_path}]")
                    elif file_type == 'video':
                        file_refs.append(f"[视频: {file_path}]")
                    else:
                        file_refs.append(f"[文件: {file_path}]")

                feishu_msg.content = feishu_msg.content + "\n" + "\n".join(file_refs)
                logger.info(f"[FeiShu] Attached {len(cached_files)} cached file(s) to user query")
                # Clear the cache
                file_cache.clear(session_id)

        context = self._compose_context(
            feishu_msg.ctype,
            feishu_msg.content,
@@ -183,7 +225,7 @@ class FeiShuChanel(ChatChannel):
        )
        if context:
            self.produce(context)
        logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
        logger.debug(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")

    def send(self, reply: Reply, context: Context):
        msg = context.get("msg")
@@ -197,32 +239,69 @@ class FeiShuChanel(ChatChannel):
            "Content-Type": "application/json",
        }
        msg_type = "text"
        logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
        logger.debug(f"[FeiShu] sending reply, type={context.type}, content={reply.content[:100]}...")
        reply_content = reply.content
        content_key = "text"
        if reply.type == ReplyType.IMAGE_URL:
            # Image upload
            reply_content = self._upload_image_url(reply.content, access_token)
            if not reply_content:
                logger.warning("[FeiShu] upload file failed")
                logger.warning("[FeiShu] upload image failed")
                return
            msg_type = "image"
            content_key = "image_key"
        if is_group:
            # Reply directly inside the group chat
        elif reply.type == ReplyType.FILE:
            # Decide whether this is a video file
            file_path = reply.content
            if file_path.startswith("file://"):
                file_path = file_path[7:]

            is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))

            if is_video:
                # Videos use the media type; upload to obtain file_key and duration
                video_info = self._upload_video_url(reply.content, access_token)
                if not video_info or not video_info.get('file_key'):
                    logger.warning("[FeiShu] upload video failed")
                    return

                # The media type needs a special content format
                msg_type = "media"
                # Note: media content does not use content_key but a complete JSON object
                reply_content = {
                    "file_key": video_info['file_key'],
                    "duration": video_info.get('duration', 0)  # video duration in milliseconds
                }
                content_key = None  # media does not use a single key
            else:
                # Other files use the file type
                file_key = self._upload_file_url(reply.content, access_token)
                if not file_key:
                    logger.warning("[FeiShu] upload file failed")
                    return
                reply_content = file_key
                msg_type = "file"
                content_key = "file_key"

        # Check if we can reply to an existing message (need msg_id)
        can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id

        if can_reply:
            # Reply to an existing message inside the group chat
            url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
            data = {
                "msg_type": msg_type,
                "content": json.dumps({content_key: reply_content})
                "content": json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
            }
            res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
        else:
            # Send a new message (private chat, or group chat without msg_id, e.g. scheduled tasks)
            url = "https://open.feishu.cn/open-apis/im/v1/messages"
            params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
            data = {
                "receive_id": context.get("receiver"),
                "msg_type": msg_type,
                "content": json.dumps({content_key: reply_content})
                "content": json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
            }
            res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
        res = res.json()
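
    # --- Illustrative sketch (not part of the diff): the three content shapes
    # --- the send() branch above produces for the Feishu IM API.
    def _sketch_feishu_content(reply_content, content_key):
        import json
        # text/image/file use a single-key object; "media" serializes the dict whole
        if content_key is None:  # msg_type == "media"
            return json.dumps(reply_content)             # {"file_key": ..., "duration": ...}
        return json.dumps({content_key: reply_content})  # {"text": ...} / {"image_key": ...} / {"file_key": ...}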
@@ -255,7 +334,34 @@ class FeiShuChanel(ChatChannel):


    def _upload_image_url(self, img_url, access_token):
        logger.debug(f"[WX] start download image, img_url={img_url}")
        logger.debug(f"[FeiShu] start process image, img_url={img_url}")

        # Check if it's a local file path (file:// protocol)
        if img_url.startswith("file://"):
            local_path = img_url[7:]  # Remove "file://" prefix
            logger.info(f"[FeiShu] uploading local file: {local_path}")

            if not os.path.exists(local_path):
                logger.error(f"[FeiShu] local file not found: {local_path}")
                return None

            # Upload directly from local file
            upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
            data = {'image_type': 'message'}
            headers = {'Authorization': f'Bearer {access_token}'}

            with open(local_path, "rb") as file:
                upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
                logger.info(f"[FeiShu] upload file, res={upload_response.content}")

            response_data = upload_response.json()
            if response_data.get("code") == 0:
                return response_data.get("data").get("image_key")
            else:
                logger.error(f"[FeiShu] upload failed: {response_data}")
                return None

        # Original logic for HTTP URLs
        response = requests.get(img_url)
        suffix = utils.get_path_suffix(img_url)
        temp_name = str(uuid.uuid4()) + "." + suffix
@@ -278,6 +384,232 @@ class FeiShuChanel(ChatChannel):
        os.remove(temp_name)
        return upload_response.json().get("data").get("image_key")

    def _get_video_duration(self, file_path: str) -> int:
        """
        Get the video duration in milliseconds

        Args:
            file_path: video file path

        Returns:
            video duration in milliseconds, or 0 if it cannot be determined
        """
        try:
            import subprocess

            # Use ffprobe to read the video duration
            cmd = [
                'ffprobe',
                '-v', 'error',
                '-show_entries', 'format=duration',
                '-of', 'default=noprint_wrappers=1:nokey=1',
                file_path
            ]

            result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                duration_seconds = float(result.stdout.strip())
                duration_ms = int(duration_seconds * 1000)
                logger.info(f"[FeiShu] Video duration: {duration_seconds:.2f}s ({duration_ms}ms)")
                return duration_ms
            else:
                logger.warning(f"[FeiShu] Failed to get video duration via ffprobe: {result.stderr}")
                return 0
        except FileNotFoundError:
            logger.warning("[FeiShu] ffprobe not found, video duration will be 0. Install ffmpeg to fix this.")
            return 0
        except Exception as e:
            logger.warning(f"[FeiShu] Failed to get video duration: {e}")
            return 0
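
    # --- Illustrative sketch (not part of the diff): the shell equivalent of
    # --- the ffprobe call above, assuming ffmpeg/ffprobe is installed:
    # ---   ffprobe -v error -show_entries format=duration \
    # ---           -of default=noprint_wrappers=1:nokey=1 demo.mp4
    # --- prints e.g. "12.480000"; the helper below mirrors the seconds-to-
    # --- milliseconds conversion applied by _get_video_duration.
    def _sketch_duration_ms(ffprobe_stdout):
        return int(float(ffprobe_stdout.strip()) * 1000)  # "12.480000" -> 12480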

    def _upload_video_url(self, video_url, access_token):
        """
        Upload video to Feishu and return video info (file_key and duration)
        Supports:
        - file:// URLs for local files
        - http(s):// URLs (download then upload)

        Returns:
            dict with 'file_key' and 'duration' (milliseconds), or None if failed
        """
        local_path = None
        temp_file = None

        try:
            # For file:// URLs (local files), upload directly
            if video_url.startswith("file://"):
                local_path = video_url[7:]  # Remove file:// prefix
                if not os.path.exists(local_path):
                    logger.error(f"[FeiShu] local video file not found: {local_path}")
                    return None
            else:
                # For HTTP URLs, download first
                logger.info(f"[FeiShu] Downloading video from URL: {video_url}")
                response = requests.get(video_url, timeout=(5, 60))
                if response.status_code != 200:
                    logger.error(f"[FeiShu] download video failed, status={response.status_code}")
                    return None

                # Save to temp file
                import uuid
                file_name = os.path.basename(video_url) or "video.mp4"
                temp_file = str(uuid.uuid4()) + "_" + file_name

                with open(temp_file, "wb") as file:
                    file.write(response.content)

                logger.info(f"[FeiShu] Video downloaded, size={len(response.content)} bytes")
                local_path = temp_file

            # Get video duration
            duration = self._get_video_duration(local_path)

            # Upload to Feishu
            file_name = os.path.basename(local_path)
            file_ext = os.path.splitext(file_name)[1].lower()
            file_type_map = {'.mp4': 'mp4'}
            file_type = file_type_map.get(file_ext, 'mp4')

            upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
            data = {'file_type': file_type, 'file_name': file_name}
            headers = {'Authorization': f'Bearer {access_token}'}

            with open(local_path, "rb") as file:
                upload_response = requests.post(
                    upload_url,
                    files={"file": file},
                    data=data,
                    headers=headers,
                    timeout=(5, 60)
                )
            logger.info(f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")

            response_data = upload_response.json()
            if response_data.get("code") == 0:
                file_key = response_data.get("data").get("file_key")
                return {
                    'file_key': file_key,
                    'duration': duration
                }
            else:
                logger.error(f"[FeiShu] upload video failed: {response_data}")
                return None

        except Exception as e:
            logger.error(f"[FeiShu] upload video exception: {e}")
            return None

        finally:
            # Clean up temp file
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    logger.warning(f"[FeiShu] Failed to remove temp file {temp_file}: {e}")

    def _upload_file_url(self, file_url, access_token):
        """
        Upload file to Feishu
        Supports both local files (file://) and HTTP URLs
        """
        logger.debug(f"[FeiShu] start process file, file_url={file_url}")

        # Check if it's a local file path (file:// protocol)
        if file_url.startswith("file://"):
            local_path = file_url[7:]  # Remove "file://" prefix
            logger.info(f"[FeiShu] uploading local file: {local_path}")

            if not os.path.exists(local_path):
                logger.error(f"[FeiShu] local file not found: {local_path}")
                return None

            # Get file info
            file_name = os.path.basename(local_path)
            file_ext = os.path.splitext(file_name)[1].lower()

            # Determine file type for Feishu API
            # Feishu supports: opus, mp4, pdf, doc, xls, ppt, stream (other types)
            file_type_map = {
                '.opus': 'opus',
                '.mp4': 'mp4',
                '.pdf': 'pdf',
                '.doc': 'doc', '.docx': 'doc',
                '.xls': 'xls', '.xlsx': 'xls',
                '.ppt': 'ppt', '.pptx': 'ppt',
            }
            file_type = file_type_map.get(file_ext, 'stream')  # Default to stream for other types

            # Upload file to Feishu
            upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
            data = {'file_type': file_type, 'file_name': file_name}
            headers = {'Authorization': f'Bearer {access_token}'}

            try:
                with open(local_path, "rb") as file:
                    upload_response = requests.post(
                        upload_url,
                        files={"file": file},
                        data=data,
                        headers=headers,
                        timeout=(5, 30)  # 5s connect, 30s read timeout
                    )
                logger.info(f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")

                response_data = upload_response.json()
                if response_data.get("code") == 0:
                    return response_data.get("data").get("file_key")
                else:
                    logger.error(f"[FeiShu] upload file failed: {response_data}")
                    return None
            except Exception as e:
                logger.error(f"[FeiShu] upload file exception: {e}")
                return None

        # For HTTP URLs, download first then upload
        try:
            response = requests.get(file_url, timeout=(5, 30))
            if response.status_code != 200:
                logger.error(f"[FeiShu] download file failed, status={response.status_code}")
                return None

            # Save to temp file
            import uuid
            file_name = os.path.basename(file_url)
            temp_name = str(uuid.uuid4()) + "_" + file_name

            with open(temp_name, "wb") as file:
                file.write(response.content)

            # Upload
            file_ext = os.path.splitext(file_name)[1].lower()
            file_type_map = {
                '.opus': 'opus', '.mp4': 'mp4', '.pdf': 'pdf',
                '.doc': 'doc', '.docx': 'doc',
                '.xls': 'xls', '.xlsx': 'xls',
                '.ppt': 'ppt', '.pptx': 'ppt',
            }
            file_type = file_type_map.get(file_ext, 'stream')

            upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
            data = {'file_type': file_type, 'file_name': file_name}
            headers = {'Authorization': f'Bearer {access_token}'}

            with open(temp_name, "rb") as file:
                upload_response = requests.post(upload_url, files={"file": file}, data=data, headers=headers)
                logger.info(f"[FeiShu] upload file, res={upload_response.content}")

            response_data = upload_response.json()
            os.remove(temp_name)  # Clean up temp file

            if response_data.get("code") == 0:
                return response_data.get("data").get("file_key")
            else:
                logger.error(f"[FeiShu] upload file failed: {response_data}")
                return None
        except Exception as e:
            logger.error(f"[FeiShu] upload file from URL exception: {e}")
            return None

    def _compose_context(self, ctype: ContextType, content, **kwargs):
        context = Context(ctype, content)
        context.kwargs = kwargs
@@ -286,13 +618,18 @@ class FeiShuChanel(ChatChannel):

        cmsg = context["msg"]

        # Set session_id based on chat type to ensure proper session isolation
        # Set session_id based on chat type
        if cmsg.is_group:
            # Group chat: combine user_id and group_id to create unique session per user per group
            # This ensures:
            # - Same user in different groups have separate conversation histories
            # - Same user in private chat and group chat have separate histories
            context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
            # Group chat: check if group_shared_session is enabled
            if conf().get("group_shared_session", True):
                # All users in the group share the same session context
                context["session_id"] = cmsg.other_user_id  # group_id
            else:
                # Each user has their own session within the group
                # This ensures:
                # - Same user in different groups have separate conversation histories
                # - Same user in private chat and group chat have separate histories
                context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
        else:
            # Private chat: use user_id only
            context["session_id"] = cmsg.from_user_id
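
# --- Illustrative sketch (not part of the diff): the session-id scheme the
# --- hunk above switches between; the ids are placeholders.
def _sketch_session_id(user_id, group_id, is_group, group_shared_session=True):
    if not is_group:
        return user_id                  # private chat: one history per user
    if group_shared_session:
        return group_id                 # the whole group shares one history
    return f"{user_id}:{group_id}"      # per-user history inside each group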
@@ -1,10 +1,12 @@
from bridge.context import ContextType
from channel.chat_message import ChatMessage
import json
import os
import requests
from common.log import logger
from common.tmp_dir import TmpDir
from common import utils
from config import conf


class FeishuMessage(ChatMessage):
@@ -22,6 +24,119 @@ class FeishuMessage(ChatMessage):
            self.ctype = ContextType.TEXT
            content = json.loads(msg.get('content'))
            self.content = content.get("text").strip()
        elif msg_type == "image":
            # Single image message: download and cache it, to be sent along with the user's next question
            self.ctype = ContextType.IMAGE
            content = json.loads(msg.get("content"))
            image_key = content.get("image_key")

            # Download the image into the workspace tmp directory
            workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
            tmp_dir = os.path.join(workspace_root, "tmp")
            os.makedirs(tmp_dir, exist_ok=True)
            image_path = os.path.join(tmp_dir, f"{image_key}.png")

            # Download the image
            url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.get('message_id')}/resources/{image_key}"
            headers = {"Authorization": "Bearer " + access_token}
            params = {"type": "image"}
            response = requests.get(url=url, headers=headers, params=params)

            if response.status_code == 200:
                with open(image_path, "wb") as f:
                    f.write(response.content)
                logger.info(f"[FeiShu] Downloaded single image, key={image_key}, path={image_path}")
                self.content = image_path
                self.image_path = image_path  # keep the image path
            else:
                logger.error(f"[FeiShu] Failed to download single image, key={image_key}, status={response.status_code}")
                self.content = f"[图片下载失败: {image_key}]"
                self.image_path = None
        elif msg_type == "post":
            # Rich-text message; may contain images, text and other elements
            content = json.loads(msg.get("content"))

            # Feishu post structure: content holds title and a content array directly,
            # not nested under a "post" field
            title = content.get("title", "")
            content_list = content.get("content", [])

            logger.info(f"[FeiShu] Post message - title: '{title}', content_list length: {len(content_list)}")

            # Collect all images and text
            image_keys = []
            text_parts = []

            if title:
                text_parts.append(title)

            for block in content_list:
                logger.debug(f"[FeiShu] Processing block: {block}")
                # Each block is itself a list of elements
                if not isinstance(block, list):
                    continue

                for element in block:
                    element_tag = element.get("tag")
                    logger.debug(f"[FeiShu] Element tag: {element_tag}, element: {element}")
                    if element_tag == "img":
                        # Image element found
                        image_key = element.get("image_key")
                        if image_key:
                            image_keys.append(image_key)
                    elif element_tag == "text":
                        # Text element
                        text_content = element.get("text", "")
                        if text_content:
                            text_parts.append(text_content)

            logger.info(f"[FeiShu] Parsed - images: {len(image_keys)}, text_parts: {text_parts}")

            # Post messages are uniformly handled as text messages
            self.ctype = ContextType.TEXT

            if image_keys:
                # If images are present, download them and reference the local paths in the text
                workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
                tmp_dir = os.path.join(workspace_root, "tmp")
                os.makedirs(tmp_dir, exist_ok=True)

                # Map image keys to local paths
                self.image_paths = {}
                for image_key in image_keys:
                    image_path = os.path.join(tmp_dir, f"{image_key}.png")
                    self.image_paths[image_key] = image_path

                def _download_images():
                    for image_key, image_path in self.image_paths.items():
                        url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{image_key}"
                        headers = {"Authorization": "Bearer " + access_token}
                        params = {"type": "image"}
                        response = requests.get(url=url, headers=headers, params=params)
                        if response.status_code == 200:
                            with open(image_path, "wb") as f:
                                f.write(response.content)
                            logger.info(f"[FeiShu] Image downloaded from post message, key={image_key}, path={image_path}")
                        else:
                            logger.error(f"[FeiShu] Failed to download image from post, key={image_key}, status={response.status_code}")

                # Download the images immediately rather than lazily,
                # because prepare() is never called for TEXT-type messages
                _download_images()

                # Build the message content: text + image paths
                content_parts = []
                if text_parts:
                    content_parts.append("\n".join(text_parts).strip())
                for image_key, image_path in self.image_paths.items():
                    content_parts.append(f"[图片: {image_path}]")

                self.content = "\n".join(content_parts)
                logger.info(f"[FeiShu] Received post message with {len(image_keys)} image(s) and text: {self.content}")
            else:
                # Text-only post message
                self.content = "\n".join(text_parts).strip() if text_parts else "[富文本消息]"
                logger.info(f"[FeiShu] Received post message (text only): {self.content}")
        elif msg_type == "file":
            self.ctype = ContextType.FILE
            content = json.loads(msg.get("content"))
channel/file_cache.py (new file)
@@ -0,0 +1,100 @@
"""
File cache manager
Caches file messages sent on their own (images, videos, documents, ...) so they can be
attached automatically when the user asks a question
"""
import time
import logging

logger = logging.getLogger(__name__)


class FileCache:
    """File cache manager; caches files per session_id with a 2-minute TTL"""

    def __init__(self, ttl=120):
        """
        Args:
            ttl: cache expiry time in seconds, 2 minutes by default
        """
        self.cache = {}
        self.ttl = ttl

    def add(self, session_id: str, file_path: str, file_type: str = "image"):
        """
        Add a file to the cache

        Args:
            session_id: session id
            file_path: local file path
            file_type: file type (image, video, file, ...)
        """
        if session_id not in self.cache:
            self.cache[session_id] = {
                'files': [],
                'timestamp': time.time()
            }

        # Add the file (deduplicated)
        file_info = {'path': file_path, 'type': file_type}
        if file_info not in self.cache[session_id]['files']:
            self.cache[session_id]['files'].append(file_info)
            logger.info(f"[FileCache] Added {file_type} to cache for session {session_id}: {file_path}")

    def get(self, session_id: str) -> list:
        """
        Get the cached file list

        Args:
            session_id: session id

        Returns:
            list of file info dicts [{'path': '...', 'type': 'image'}, ...];
            empty if nothing is cached or the entry has expired
        """
        if session_id not in self.cache:
            return []

        item = self.cache[session_id]

        # Check for expiry
        if time.time() - item['timestamp'] > self.ttl:
            logger.info(f"[FileCache] Cache expired for session {session_id}, clearing...")
            del self.cache[session_id]
            return []

        return item['files']

    def clear(self, session_id: str):
        """
        Clear the cache for the given session

        Args:
            session_id: session id
        """
        if session_id in self.cache:
            logger.info(f"[FileCache] Cleared cache for session {session_id}")
            del self.cache[session_id]

    def cleanup_expired(self):
        """Remove all expired cache entries"""
        current_time = time.time()
        expired_sessions = []

        for session_id, item in self.cache.items():
            if current_time - item['timestamp'] > self.ttl:
                expired_sessions.append(session_id)

        for session_id in expired_sessions:
            del self.cache[session_id]
            logger.debug(f"[FileCache] Cleaned up expired cache for session {session_id}")

        if expired_sessions:
            logger.info(f"[FileCache] Cleaned up {len(expired_sessions)} expired cache(s)")


# Global singleton
_file_cache = FileCache()


def get_file_cache() -> FileCache:
    """Get the global file cache instance"""
    return _file_cache
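
# --- Illustrative usage sketch (not part of the diff): the round-trip the
# --- channels perform against this cache. Session id and path are placeholders.
def _sketch_file_cache_roundtrip():
    cache = get_file_cache()
    cache.add("session_42", "/tmp/photo.png", file_type="image")
    files = cache.get("session_42")  # returns [] once the 120s TTL has elapsed
    for info in files:
        print(info['path'], info['type'])
    cache.clear("session_42")  # channels clear right after attaching the files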
@@ -200,12 +200,12 @@ class WebChannel(ChatChannel):
        logger.info("""[WebChannel] 当前channel为web,可修改 config.json 配置文件中的 channel_type 字段进行切换。全部可用类型为:
        1. web: 网页
        2. terminal: 终端
        3. wechatmp: 个人公众号
        4. wechatmp_service: 企业公众号
        3. feishu: 飞书
        4. dingtalk: 钉钉
        5. wechatcom_app: 企微自建应用
        6. dingtalk: 钉钉
        7. feishu: 飞书""")
        logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat (本地运行) 或 http://ip:{port}/chat (服务器运行)")
        6. wechatmp: 个人公众号
        7. wechatmp_service: 企业公众号""")
        logger.info(f"✅ Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat (本地运行) 或 http://ip:{port}/chat (服务器运行)")

        # Make sure the static file directory exists
        static_dir = os.path.join(os.path.dirname(__file__), 'static')
@@ -1,38 +1,23 @@
{
  "channel_type": "web",
  "model": "",
  "model": "claude-sonnet-4-5",
  "open_ai_api_key": "YOUR API KEY",
  "open_ai_api_base": "https://api.openai.com/v1",
  "claude_api_key": "YOUR API KEY",
  "text_to_image": "dall-e-2",
  "claude_api_base": "https://api.anthropic.com/v1",
  "gemini_api_key": "YOUR API KEY",
  "gemini_api_base": "https://generativelanguage.googleapis.com",
  "voice_to_text": "openai",
  "text_to_voice": "openai",
  "proxy": "",
  "hot_reload": false,
  "single_chat_prefix": [
    "bot",
    "@bot"
  ],
  "single_chat_reply_prefix": "[bot] ",
  "group_chat_prefix": [
    "@bot"
  ],
  "group_name_white_list": [
    "Agent测试群",
    "ChatGPT测试群2"
  ],
  "image_create_prefix": [
    "画"
  ],
  "voice_reply_voice": false,
  "speech_recognition": true,
  "group_speech_recognition": false,
  "voice_reply_voice": false,
  "conversation_max_tokens": 2500,
  "expires_in_seconds": 3600,
  "character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
  "temperature": 0.7,
  "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
  "proxy": "",
  "use_linkai": false,
  "linkai_api_key": "",
  "linkai_app_code": "",
  "agent": false
  "agent": true,
  "agent_max_context_tokens": 40000,
  "agent_max_context_turns": 30,
  "agent_max_steps": 20
}
13  config.py
@@ -15,6 +15,8 @@ available_setting = {
    "open_ai_api_key": "",  # openai api key
    # openai api base; when use_azure_chatgpt is true, set the corresponding api base
    "open_ai_api_base": "https://api.openai.com/v1",
    "claude_api_base": "https://api.anthropic.com/v1",  # claude api base
    "gemini_api_base": "https://generativelanguage.googleapis.com",  # gemini api base
    "proxy": "",  # proxy used for openai
    # chatgpt model; when use_azure_chatgpt is true, this is the Azure model deployment name
    "model": "gpt-3.5-turbo",  # options: gpt-4o, gpt-4o-mini, gpt-4-turbo, claude-3-sonnet, wenxin, moonshot, qwen-turbo, xunfei, glm-4, minimax, gemini, etc.; see common/const.py for the full list
@@ -35,6 +37,7 @@ available_setting = {
    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],  # group names with auto-reply enabled
    "group_name_keyword_white_list": [],  # group-name keywords with auto-reply enabled
    "group_chat_in_one_session": ["ChatGPT测试群"],  # groups whose members share conversation context
    "group_shared_session": True,  # whether a group chat shares one session across all members, default True; if False each user has an independent session within the group
    "nick_name_black_list": [],  # user-nickname blacklist
    "group_welcome_msg": "",  # fixed welcome message for new group members; a random style is used if unset
    "trigger_by_self": False,  # whether the bot may trigger itself
@@ -184,9 +187,11 @@ available_setting = {
    "Minimax_group_id": "",
    "Minimax_base_url": "",
    "web_port": 9899,
    "agent": False,  # whether to enable Agent mode
    "agent": True,  # whether to enable Agent mode
    "agent_workspace": "~/cow",  # agent workspace path, used to store skills, memory, etc.
    "bocha_api_key": ""
    "agent_max_context_tokens": 40000,  # maximum context tokens in Agent mode
    "agent_max_context_turns": 30,  # maximum context turns in Agent mode
    "agent_max_steps": 20,  # maximum decision steps per run in Agent mode
}
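As a quick illustration of how the new Agent keys are consumed, here is a minimal sketch using the project's conf() accessor; the defaults mirror the settings above:

```python
from config import conf

# Read the Agent-mode settings introduced in this change.
agent_enabled = conf().get("agent", True)
max_tokens = conf().get("agent_max_context_tokens", 40000)
max_turns = conf().get("agent_max_context_turns", 30)
max_steps = conf().get("agent_max_steps", 20)

if agent_enabled:
    print(f"Agent mode: {max_tokens} tokens, {max_turns} turns, {max_steps} steps max")
```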
@@ -203,13 +208,13 @@ class Config(dict):
    def __getitem__(self, key):
        # skip comment fields that start with an underscore
        if not key.startswith("_") and key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
            logger.warning("[Config] key '{}' not in available_setting, may not take effect".format(key))
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        # skip comment fields that start with an underscore
        if not key.startswith("_") and key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
            logger.warning("[Config] key '{}' not in available_setting, may not take effect".format(key))
        return super().__setitem__(key, value)

    def get(self, key, default=None):
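The practical effect of the change above: an unrecognized config key no longer aborts startup. A minimal sketch, assuming Config can be instantiated directly as the dict subclass shown:

```python
from config import Config

config = Config()

# Previously this raised an Exception and the process died;
# now it only logs "[Config] key ... may not take effect".
config["some_future_setting"] = 123
print(config["some_future_setting"])  # 123
```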
@@ -9,9 +9,9 @@ import openai.error
import broadscope_bailian
from broadscope_bailian import ChatQaMessage

from bot.bot import Bot
from bot.ali.ali_qwen_session import AliQwenSession
from bot.session_manager import SessionManager
from models.bot import Bot
from models.ali.ali_qwen_session import AliQwenSession
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger

"""
@@ -2,7 +2,7 @@

import requests

from bot.bot import Bot
from models.bot import Bot
from bridge.reply import Reply, ReplyType


@@ -3,13 +3,13 @@
import requests
import json
from common import const
from bot.bot import Bot
from bot.session_manager import SessionManager
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from models.baidu.baidu_wenxin_session import BaiduWenxinSession

BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger

"""
@@ -12,64 +12,64 @@ def create_bot(bot_type):
    """
    if bot_type == const.BAIDU:
        # Baidu Unit replaced with the Baidu Wenxin Qianfan chat API
        # from bot.baidu.baidu_unit_bot import BaiduUnitBot
        # from models.baidu.baidu_unit_bot import BaiduUnitBot
        # return BaiduUnitBot()
        from bot.baidu.baidu_wenxin import BaiduWenxinBot
        from models.baidu.baidu_wenxin import BaiduWenxinBot
        return BaiduWenxinBot()

    elif bot_type == const.CHATGPT:
        # ChatGPT web API
        from bot.chatgpt.chat_gpt_bot import ChatGPTBot
        from models.chatgpt.chat_gpt_bot import ChatGPTBot
        return ChatGPTBot()

    elif bot_type == const.OPEN_AI:
        # official OpenAI chat model API
        from bot.openai.open_ai_bot import OpenAIBot
        from models.openai.open_ai_bot import OpenAIBot
        return OpenAIBot()

    elif bot_type == const.CHATGPTONAZURE:
        # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
        from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
        from models.chatgpt.chat_gpt_bot import AzureChatGPTBot
        return AzureChatGPTBot()

    elif bot_type == const.XUNFEI:
        from bot.xunfei.xunfei_spark_bot import XunFeiBot
        from models.xunfei.xunfei_spark_bot import XunFeiBot
        return XunFeiBot()

    elif bot_type == const.LINKAI:
        from bot.linkai.link_ai_bot import LinkAIBot
        from models.linkai.link_ai_bot import LinkAIBot
        return LinkAIBot()

    elif bot_type == const.CLAUDEAI:
        from bot.claude.claude_ai_bot import ClaudeAIBot
        from models.claude.claude_ai_bot import ClaudeAIBot
        return ClaudeAIBot()
    elif bot_type == const.CLAUDEAPI:
        from bot.claudeapi.claude_api_bot import ClaudeAPIBot
        from models.claudeapi.claude_api_bot import ClaudeAPIBot
        return ClaudeAPIBot()
    elif bot_type == const.QWEN:
        from bot.ali.ali_qwen_bot import AliQwenBot
        from models.ali.ali_qwen_bot import AliQwenBot
        return AliQwenBot()
    elif bot_type == const.QWEN_DASHSCOPE:
        from bot.dashscope.dashscope_bot import DashscopeBot
        from models.dashscope.dashscope_bot import DashscopeBot
        return DashscopeBot()
    elif bot_type == const.GEMINI:
        from bot.gemini.google_gemini_bot import GoogleGeminiBot
        from models.gemini.google_gemini_bot import GoogleGeminiBot
        return GoogleGeminiBot()

    elif bot_type == const.ZHIPU_AI:
        from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
        from models.zhipuai.zhipuai_bot import ZHIPUAIBot
        return ZHIPUAIBot()

    elif bot_type == const.MOONSHOT:
        from bot.moonshot.moonshot_bot import MoonshotBot
        from models.moonshot.moonshot_bot import MoonshotBot
        return MoonshotBot()

    elif bot_type == const.MiniMax:
        from bot.minimax.minimax_bot import MinimaxBot
        from models.minimax.minimax_bot import MinimaxBot
        return MinimaxBot()

    elif bot_type == const.MODELSCOPE:
        from bot.modelscope.modelscope_bot import ModelScopeBot
        from models.modelscope.modelscope_bot import ModelScopeBot
        return ModelScopeBot()
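Since every implementation above moves from bot/ to models/, a caller sketch for the factory looks like this. The module path models.bot_factory is an assumption inferred from the rename pattern; the factory file's own location is not shown in this diff:

```python
from common import const
from models.bot_factory import create_bot  # assumed path, see note above

# Implementations are imported lazily inside create_bot, so unused
# providers add no import cost at startup.
bot = create_bot(const.CHATGPT)
reply = bot.reply("hello")
```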
@@ -7,17 +7,17 @@ import openai
import openai.error
import requests
from common import const
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from models.bot import Bot
from models.openai_compatible_bot import OpenAICompatibleBot
from models.chatgpt.chat_gpt_session import ChatGPTSession
from models.openai.open_ai_image import OpenAIImage
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.token_bucket import TokenBucket
from config import conf, load_config
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from models.baidu.baidu_wenxin_session import BaiduWenxinSession

# OpenAI chat model API (working)
class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger
from common import const
@@ -5,9 +5,9 @@ import time

import requests

from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bot.bot import Bot
from bot.session_manager import SessionManager
from models.baidu.baidu_wenxin_session import BaiduWenxinSession
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
@@ -16,7 +16,7 @@ from config import conf

# Optional OpenAI image support
try:
    from bot.openai.open_ai_image import OpenAIImage
    from models.openai.open_ai_image import OpenAIImage
    _openai_image_available = True
except Exception as e:
    logger.warning(f"OpenAI image support not available: {e}")
@@ -31,7 +31,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
    def __init__(self):
        super().__init__()
        self.api_key = conf().get("claude_api_key")
        self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
        self.api_base = conf().get("claude_api_base") or "https://api.anthropic.com/v1"
        self.proxy = conf().get("proxy", None)
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
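The api_base line above is a genuine bug fix: the Claude bot previously read open_ai_api_base, so pointing the OpenAI base at a proxy silently redirected Claude traffic as well. A small sketch of the corrected resolution, using the same conf() accessor:

```python
from config import conf

# Each provider now resolves its own endpoint independently;
# overriding open_ai_api_base no longer affects Claude requests.
claude_base = conf().get("claude_api_base") or "https://api.anthropic.com/v1"
print(f"Claude requests go to: {claude_base}")
```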
@@ -365,6 +365,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
        # Track tool use state
        tool_uses_map = {}  # {index: {id, name, input}}
        current_tool_use_index = -1
        stop_reason = None  # Track stop reason from Claude

        try:
            # Make streaming HTTP request
@@ -440,6 +441,12 @@ class ClaudeAPIBot(Bot, OpenAIImage):
                            tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "")

                    elif event_type == "message_delta":
                        # Extract stop_reason from delta
                        delta = event.get("delta", {})
                        if "stop_reason" in delta:
                            stop_reason = delta.get("stop_reason")
                            logger.info(f"[Claude] Stream stop_reason: {stop_reason}")

                        # Message complete - yield tool calls if any
                        if tool_uses_map:
                            for idx in sorted(tool_uses_map.keys()):
@@ -462,9 +469,13 @@ class ClaudeAPIBot(Bot, OpenAIImage):
                                            }
                                        }]
                                    },
                                    "finish_reason": None
                                    "finish_reason": stop_reason
                                }]
                            }

                    elif event_type == "message_stop":
                        # Final event - log completion
                        logger.debug(f"[Claude] Stream completed with stop_reason: {stop_reason}")

                except json.JSONDecodeError:
                    continue
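With stop_reason propagated into finish_reason, downstream consumers can tell why the stream ended. A hedged sketch of a consumer loop; the field layout follows the chunks yielded above, and the reason strings are Claude's documented stop values ("tool_use", "end_turn", "max_tokens"):

```python
def consume_stream(chunks):
    """Drain OpenAI-style chunks from the Claude bot and decide the next step."""
    tool_calls = []
    for chunk in chunks:
        choice = chunk["choices"][0]
        tool_calls.extend(choice.get("delta", {}).get("tool_calls", []) or [])
        if choice.get("finish_reason") == "tool_use":
            # Execute the requested tools, then continue the agent loop.
            return "run_tools", tool_calls
    return "done", tool_calls
```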
@@ -1,7 +1,7 @@
# encoding:utf-8

from bot.bot import Bot
from bot.session_manager import SessionManager
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger


@@ -9,15 +9,15 @@ Google gemini bot
import json
import time
import requests
from bot.bot import Bot
from models.bot import Bot
import google.generativeai as genai
from bot.session_manager import SessionManager
from models.session_manager import SessionManager
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from models.chatgpt.chat_gpt_session import ChatGPTSession
from models.baidu.baidu_wenxin_session import BaiduWenxinSession
from google.generativeai.types import HarmCategory, HarmBlockThreshold
@@ -33,14 +33,11 @@ class GoogleGeminiBot(Bot):
        if self.model == "gemini":
            self.model = "gemini-pro"

        # support a custom API base, reusing the open_ai_api_base setting
        self.api_base = conf().get("open_ai_api_base", "").strip()
        # support a custom API base
        self.api_base = conf().get("gemini_api_base", "").strip()
        if self.api_base:
            # strip the trailing slash
            self.api_base = self.api_base.rstrip('/')
            # if an OpenAI address is configured, fall back to the default Gemini address
            if "api.openai.com" in self.api_base or not self.api_base:
                self.api_base = "https://generativelanguage.googleapis.com"
            logger.info(f"[Gemini] Using custom API base: {self.api_base}")
        else:
            self.api_base = "https://generativelanguage.googleapis.com"
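To make the fallback rule concrete, here is a standalone mirror of the constructor logic above with a few sample inputs (hypothetical helper, for illustration only):

```python
GEMINI_DEFAULT = "https://generativelanguage.googleapis.com"

def resolve_gemini_base(configured: str) -> str:
    """Mirror of the patched constructor logic, for illustration."""
    base = (configured or "").strip().rstrip("/")
    if not base or "api.openai.com" in base:
        return GEMINI_DEFAULT
    return base

print(resolve_gemini_base(""))                           # default endpoint
print(resolve_gemini_base("https://api.openai.com/v1"))  # OpenAI base ignored -> default
print(resolve_gemini_base("https://my-proxy.example/"))  # https://my-proxy.example
```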
@@ -254,7 +251,6 @@ class GoogleGeminiBot(Bot):
            gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
                logger.debug(f"[Gemini] Added {len(tools)} tools to request")

        # Make REST API call
        base_url = f"{self.api_base}/v1beta"
@@ -267,8 +263,6 @@ class GoogleGeminiBot(Bot):
            "Content-Type": "application/json"
        }

        logger.debug(f"[Gemini] REST API call: {endpoint}")

        response = requests.post(
            endpoint,
            headers=headers,
@@ -339,8 +333,6 @@ class GoogleGeminiBot(Bot):
                logger.warning(f"[Gemini] Skipping tool without name: {tool}")
                continue

            logger.debug(f"[Gemini] Converting tool: {name}")

            function_declarations.append({
                "name": name,
                "description": description,
@@ -464,7 +456,6 @@ class GoogleGeminiBot(Bot):
            try:
                chunk_data = json.loads(line)
                chunk_count += 1
                logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")

                candidates = chunk_data.get("candidates", [])
                if not candidates:
@@ -489,7 +480,6 @@ class GoogleGeminiBot(Bot):
                for part in parts:
                    if "text" in part and part["text"]:
                        has_content = True
                        logger.debug(f"[Gemini] Streaming text: {part['text'][:50]}...")
                        yield {
                            "id": f"chatcmpl-{time.time()}",
                            "object": "chat.completion.chunk",
@@ -505,7 +495,7 @@ class GoogleGeminiBot(Bot):
                    # Collect function calls
                    if "functionCall" in part:
                        fc = part["functionCall"]
                        logger.debug(f"[Gemini] Function call detected: {fc.get('name')}")
                        logger.info(f"[Gemini] Function call: {fc.get('name')}")
                        all_tool_calls.append({
                            "index": len(all_tool_calls),  # Add index to differentiate multiple tool calls
                            "id": f"call_{int(time.time() * 1000000)}_{len(all_tool_calls)}",
@@ -522,7 +512,6 @@ class GoogleGeminiBot(Bot):

            # Send tool calls if any were collected
            if all_tool_calls and not has_sent_tool_calls:
                logger.debug(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
                yield {
                    "id": f"chatcmpl-{time.time()}",
                    "object": "chat.completion.chunk",
@@ -536,14 +525,6 @@ class GoogleGeminiBot(Bot):
                }
                has_sent_tool_calls = True

            # Log summary (only if there's something interesting)
            if not has_content and not all_tool_calls:
                logger.debug(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
            elif all_tool_calls:
                logger.debug(f"[Gemini] Stream complete: {len(all_tool_calls)} tool calls")
            else:
                logger.debug(f"[Gemini] Stream complete: text response")

            # log a detailed warning if an empty response came back
            if not has_content and not all_tool_calls:
                logger.warning(f"[Gemini] ⚠️ Empty response detected!")
@@ -6,10 +6,10 @@ import time
import requests
import json
import config
from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager
from models.bot import Bot
from models.openai_compatible_bot import OpenAICompatibleBot
from models.chatgpt.chat_gpt_session import ChatGPTSession
from models.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -4,14 +4,14 @@ import time

import openai
import openai.error
from bot.bot import Bot
from bot.minimax.minimax_session import MinimaxSession
from bot.session_manager import SessionManager
from models.bot import Bot
from models.minimax.minimax_session import MinimaxSession
from models.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf, load_config
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from models.chatgpt.chat_gpt_session import ChatGPTSession
import requests
from common import const

@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger

"""
@@ -4,8 +4,8 @@ import time
import json
import openai
import openai.error
from bot.bot import Bot
from bot.session_manager import SessionManager
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger


@@ -4,8 +4,8 @@ import time

import openai
import openai.error
from bot.bot import Bot
from bot.session_manager import SessionManager
from models.bot import Bot
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger


@@ -5,11 +5,11 @@ import time
import openai
import openai.error

from bot.bot import Bot
from bot.openai_compatible_bot import OpenAICompatibleBot
from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager
from models.bot import Bot
from models.openai_compatible_bot import OpenAICompatibleBot
from models.openai.open_ai_image import OpenAIImage
from models.openai.open_ai_session import OpenAISession
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -158,7 +158,7 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):

        # Build request parameters for ChatCompletion
        request_params = {
            "model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
            "model": kwargs.get("model", conf().get("model") or "gpt-4.1"),
            "messages": messages,
            "temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
            "top_p": kwargs.get("top_p", 1),
@@ -1,7 +1,7 @@
import time

import openai
from bot.openai.openai_compat import RateLimitError
from models.openai.openai_compat import RateLimitError

from common.log import logger
from common.token_bucket import TokenBucket
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger


@@ -1,9 +1,9 @@
# encoding:utf-8

import requests, json
from bot.bot import Bot
from bot.session_manager import SessionManager
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from models.bot import Bot
from models.session_manager import SessionManager
from models.chatgpt.chat_gpt_session import ChatGPTSession
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -1,4 +1,4 @@
from bot.session_manager import Session
from models.session_manager import Session
from common.log import logger


@@ -4,10 +4,10 @@ import time

import openai
import openai.error
from bot.bot import Bot
from bot.zhipuai.zhipu_ai_session import ZhipuAISession
from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
from bot.session_manager import SessionManager
from models.bot import Bot
from models.zhipuai.zhipu_ai_session import ZhipuAISession
from models.zhipuai.zhipu_ai_image import ZhipuAIImage
from models.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -32,7 +32,7 @@ class AgentPlugin(Plugin):
        self.config = self._load_config()
        self.tool_manager = ToolManager()
        self.tool_manager.load_tools(config_dict=self.config.get("tools"))
        logger.info("[agent] inited")
        logger.debug("[agent] inited")

    def _load_config(self) -> Dict:
        """Load configuration from config.yaml file."""

@@ -49,7 +49,7 @@ class Banwords(Plugin):
            if conf.get("reply_filter", True):
                self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
                self.reply_action = conf.get("reply_action", "ignore")
            logger.info("[Banwords] inited")
            logger.debug("[Banwords] inited")
        except Exception as e:
            logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
            raise e
@@ -1,30 +0,0 @@
## Plugin description

Smart dialogue via Baidu UNIT

- 1. Problem solved: instructions chatgpt cannot handle are delegated to Baidu UNIT, e.g. weather, date/time, arithmetic
- 2. Time questions: e.g. "what time is it", "what's today's date"
- 3. Weather questions: e.g. "how's the weather in Guangzhou tomorrow", "will it rain in Shenzhen this weekend"
- 4. Arithmetic questions: e.g. "23+45=?", "100-23=?", "what is 35 in binary?"

## Usage

### Getting an apikey

Create an application on the Baidu UNIT site and request a Baidu bot; pre-trained models can be imported into your own application,

see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087095fc10c8377aaf and apply on the https://console.bce.baidu.com/ai platform

### Config file

Copy `config.json.template` in this folder to `config.json`.

Fill in the API Key and Secret Key of the application obtained from the Baidu UNIT site

``` json
{
    "service_id": "s...", # "bot ID"
    "api_key": "",
    "secret_key": ""
}
```
@@ -1 +0,0 @@
from .bdunit import *
@@ -1,252 +0,0 @@
# encoding:utf-8
import json
import os
import uuid
from uuid import getnode as get_mac

import requests

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from plugins import *

"""Smart dialogue via Baidu UNIT.
If an intent is matched, return the intent's reply; otherwise hand off to the next plugin.
"""


@plugins.register(
    name="BDunit",
    desire_priority=0,
    hidden=True,
    desc="Baidu unit bot system",
    version="0.1",
    author="jackson",
)
class BDunit(Plugin):
    def __init__(self):
        super().__init__()
        try:
            conf = super().load_config()
            if not conf:
                raise Exception("config.json not found")
            self.service_id = conf["service_id"]
            self.api_key = conf["api_key"]
            self.secret_key = conf["secret_key"]
            self.access_token = self.get_token()
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[BDunit] inited")
        except Exception as e:
            logger.warn("[BDunit] init failed, ignore ")
            raise e

    def on_handle_context(self, e_context: EventContext):
        if e_context["context"].type != ContextType.TEXT:
            return

        content = e_context["context"].content
        logger.debug("[BDunit] on_handle_context. content: %s" % content)
        parsed = self.getUnit2(content)
        intent = self.getIntent(parsed)
        if intent:  # intent matched
            logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = self.getSay(parsed)
            e_context["reply"] = reply
            e_context.action = EventAction.BREAK_PASS  # end the event and skip the default context handling
        else:
            e_context.action = EventAction.CONTINUE  # let the event continue to the next plugin or the default logic

    def get_help_text(self, **kwargs):
        help_text = "This plugin answers real-time date/time, weather, arithmetic and similar questions; the available skills are determined by your Baidu UNIT bot\n"
        return help_text

    def get_token(self):
        """Get the access_token for Baidu UNIT
        #param api_key: UNIT api_key
        #param secret_key: UNIT secret_key
        Returns:
            string: access_token
        """
        url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
        payload = ""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        response = requests.request("POST", url, headers=headers, data=payload)

        # print(response.text)
        return response.json()["access_token"]

    def getUnit(self, query):
        """
        NLU parsing, version 3.0
        :param query: the user's instruction string
        :returns: UNIT parse result, or None on failure
        """

        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
        request = {
            "query": query,
            "user_id": str(get_mac())[:32],
            "terminal_id": "88888",
        }
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "3.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception:
            return None

    def getUnit2(self, query):
        """
        NLU parsing, version 2.0

        :param query: the user's instruction string
        :returns: UNIT parse result, or None on failure
        """
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
        request = {"query": query, "user_id": str(get_mac())[:32]}
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "2.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception:
            return None

    def getIntent(self, parsed):
        """
        Extract the intent

        :param parsed: UNIT parse result
        :returns: intent array
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            try:
                return parsed["result"]["response_list"][0]["schema"]["intent"]
            except Exception as e:
                logger.warning(e)
                return ""
        else:
            return ""

    def hasIntent(self, parsed, intent):
        """
        Check whether a given intent is present

        :param parsed: UNIT parse result
        :param intent: intent name
        :returns: True if present; False otherwise
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    return True
            return False
        else:
            return False

    def getSlots(self, parsed, intent=""):
        """
        Extract all slots for a given intent

        :param parsed: UNIT parse result
        :param intent: intent name
        :returns: slot list; filter slots by the name attribute,
                  then read the value from the normalized_word attribute
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return parsed["result"]["response_list"][0]["schema"]["slots"]
                except Exception as e:
                    logger.warning(e)
                    return []
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
                    return response["schema"]["slots"]
            return []
        else:
            return []

    def getSlotWords(self, parsed, intent, name):
        """
        Find the values that matched a given slot

        :param parsed: UNIT parse result
        :param intent: intent name
        :param name: slot name
        :returns: list of values matching the slot.
        """
        slots = self.getSlots(parsed, intent)
        words = []
        for slot in slots:
            if slot["name"] == name:
                words.append(slot["normalized_word"])
        return words

    def getSayByConfidence(self, parsed):
        """
        Extract the UNIT reply text with the highest confidence

        :param parsed: UNIT parse result
        :returns: UNIT reply text
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            answer = {}
            for response in response_list:
                if (
                    "schema" in response
                    and "intent_confidence" in response["schema"]
                    and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
                ):
                    answer = response
            return answer["action_list"][0]["say"]
        else:
            return ""

    def getSay(self, parsed, intent=""):
        """
        Extract the UNIT reply text

        :param parsed: UNIT parse result
        :param intent: intent name
        :returns: UNIT reply text
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return response_list[0]["action_list"][0]["say"]
                except Exception as e:
                    logger.warning(e)
                    return ""
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    try:
                        return response["action_list"][0]["say"]
                    except Exception as e:
                        logger.warning(e)
                        return ""
            return ""
        else:
            return ""
@@ -1,5 +0,0 @@
{
    "service_id": "s...",
    "api_key": "",
    "secret_key": ""
}
@@ -8,17 +8,6 @@
        "reply_filter": true,
        "reply_action": "ignore"
    },
    "tool": {
        "tools": [
            "url-get",
            "meteo-weather"
        ],
        "kwargs": {
            "top_k_results": 2,
            "no_default": false,
            "model_name": "gpt-3.5-turbo"
        }
    },
    "linkai": {
        "group_app_map": {
            "测试群1": "default",
@@ -53,7 +53,7 @@ class Dungeon(Plugin):
    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Dungeon] inited")
        logger.debug("[Dungeon] inited")
        # there is no session-expiry event yet, so use an expiring dict for now
        if conf().get("expires_in_seconds"):
            self.games = ExpiredDict(conf().get("expires_in_seconds"))

@@ -20,7 +20,7 @@ class Finish(Plugin):
    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Finish] inited")
        logger.debug("[Finish] inited")

    def on_handle_context(self, e_context: EventContext):
        if e_context["context"].type != ContextType.TEXT:

@@ -207,7 +207,7 @@ class Godcmd(Plugin):
        self.isrunning = True  # whether the bot is running

        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Godcmd] inited")
        logger.debug("[Godcmd] inited")

    def on_handle_context(self, e_context: EventContext):
        context_type = e_context["context"].type

@@ -35,7 +35,7 @@ class Hello(Plugin):
            self.group_welc_prompt = self.config.get("group_welc_prompt", self.group_welc_prompt)
            self.group_exit_prompt = self.config.get("group_exit_prompt", self.group_exit_prompt)
            self.patpat_prompt = self.config.get("patpat_prompt", self.patpat_prompt)
            logger.info("[Hello] inited")
            logger.debug("[Hello] inited")
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        except Exception as e:
            logger.error(f"[Hello]初始化异常:{e}")

@@ -37,9 +37,9 @@ class Keyword(Plugin):
            # load keywords
            self.keyword = conf["keyword"]

            logger.info("[keyword] {}".format(self.keyword))
            logger.debug("[keyword] {}".format(self.keyword))
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[keyword] inited.")
            logger.debug("[keyword] inited.")
        except Exception as e:
            logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
            raise e

@@ -32,7 +32,7 @@ class LinkAI(Plugin):
        self.sum_config = {}
        if self.config:
            self.sum_config = self.config.get("summary")
        logger.info(f"[LinkAI] inited, config={self.config}")
        logger.debug(f"[LinkAI] inited, config={self.config}")

    def on_handle_context(self, e_context: EventContext):
        """

@@ -18,14 +18,12 @@ class Plugin:
        if not plugin_conf:
            # no global config; fall back to the config file in the plugin directory
            plugin_config_path = os.path.join(self.path, "config.json")
            logger.debug(f"loading plugin config, plugin_config_path={plugin_config_path}, exist={os.path.exists(plugin_config_path)}")
            if os.path.exists(plugin_config_path):
                with open(plugin_config_path, "r", encoding="utf-8") as f:
                    plugin_conf = json.load(f)

                # write into the in-memory global config
                write_plugin_config({self.name: plugin_conf})
        logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
        return plugin_conf

    def save_config(self, config: dict):

@@ -38,7 +38,7 @@ class PluginManager:
            if self.current_plugin_path == None:
                raise Exception("Plugin path not set")
            self.plugins[name.upper()] = plugincls
            logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
            logger.debug("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))

        return wrapper

@@ -47,7 +47,7 @@ class PluginManager:
            json.dump(self.pconf, f, indent=4, ensure_ascii=False)

    def load_config(self):
        logger.info("Loading plugins config...")
        logger.debug("Loading plugins config...")

        modified = False
        if os.path.exists("./plugins/plugins.json"):

@@ -85,7 +85,7 @@ class PluginManager:
            logger.error(e)

    def scan_plugins(self):
        logger.info("Scaning plugins ...")
        logger.debug("Scanning plugins ...")
        plugins_dir = "./plugins"
        raws = [self.plugins[name] for name in self.plugins]
        for plugin_name in os.listdir(plugins_dir):

@@ -66,7 +66,7 @@ class Role(Plugin):
                raise Exception("no role found")
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            self.roleplays = {}
            logger.info("[Role] inited")
            logger.debug("[Role] inited")
        except Exception as e:
            if isinstance(e, FileNotFoundError):
                logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")

@@ -1 +0,0 @@
from .tool import *