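Tool call hooks provide fine-grained control over tool execution during agent operations. They let you intercept tool calls, modify inputs, transform outputs, enforce safety checks, and add comprehensive logging or monitoring.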

Overview

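Tool hooks run at two key points:
  • Before a tool call: modify inputs, validate parameters, or block execution
  • After a tool call: transform results, sanitize output, or log execution details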
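The two points above can be modeled with a small, framework-free sketch. `SketchContext` and `run_tool_with_hooks` are hypothetical and are not CrewAI's internals; they only encode the return-value contract described on this page (a before hook returning `False` blocks execution while `True` or `None` allows it; an after hook returning a string replaces the result while `None` keeps it). Returning `None` for a blocked call, and letting each after hook see the previous hook's replacement, are assumptions of the sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SketchContext:
    """Hypothetical stand-in for ToolCallHookContext (only the fields used here)."""
    tool_name: str
    tool_input: dict[str, Any]
    tool_result: Optional[str] = None


def run_tool_with_hooks(
    tool: Callable[..., str],
    context: SketchContext,
    before_hooks: list[Callable[[SketchContext], Optional[bool]]],
    after_hooks: list[Callable[[SketchContext], Optional[str]]],
) -> Optional[str]:
    # Before phase: hooks run in order; the first one returning False blocks the call.
    for hook in before_hooks:
        if hook(context) is False:
            return None  # Blocked: the tool is never invoked
    context.tool_result = tool(**context.tool_input)
    # After phase: a returned string replaces the result; None keeps the current one.
    for hook in after_hooks:
        replacement = hook(context)
        if replacement is not None:
            context.tool_result = replacement
    return context.tool_result


def shout(text: str) -> str:
    return text.upper()


def require_text(context: SketchContext) -> Optional[bool]:
    return False if not context.tool_input.get("text") else None


def add_bang(context: SketchContext) -> Optional[str]:
    return context.tool_result + "!"


print(run_tool_with_hooks(shout, SketchContext("shout", {"text": "hi"}), [require_text], [add_bang]))  # HI!
```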

Hook Types

Before Tool Call Hooks

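These hooks run before every tool execution and can:
  • Inspect and modify tool inputs
  • Block tool execution based on conditions
  • Enforce approval gates for dangerous operations
  • Validate parameters
  • Log tool calls
Signature: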
def before_hook(context: ToolCallHookContext) -> bool | None:
    # Return False to block execution
    # Return True or None to allow execution
    ...

After Tool Call Hooks

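These hooks run after every tool execution and can:
  • Modify or sanitize tool results
  • Add metadata or formatting
  • Log execution results
  • Perform result validation
  • Transform output formats
Signature: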
def after_hook(context: ToolCallHookContext) -> str | None:
    # Return modified result string
    # Return None to keep original result
    ...

Tool Hook Context

The ToolCallHookContext object provides full access to the state of the tool execution:
class ToolCallHookContext:
    tool_name: str                    # Name of the tool being called
    tool_input: dict[str, Any]        # Mutable tool input parameters
    tool: CrewStructuredTool          # Tool instance reference
    agent: Agent | BaseAgent | None   # Agent executing the tool
    task: Task | None                 # Current task
    crew: Crew | None                 # Crew instance
    tool_result: str | None           # Tool result (after hooks only)
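Because a hook only reads attributes from the context, you can unit-test it without running a crew by passing any object that exposes the same attribute names. The sketch below uses `types.SimpleNamespace` as a stand-in; `make_context` is a hypothetical test helper, not a CrewAI testing API, and it assumes your hook touches only the attributes listed above.

```python
from types import SimpleNamespace
from typing import Any, Optional


def make_context(tool_name: str, tool_input: Optional[dict[str, Any]] = None,
                 tool_result: Optional[str] = None, agent_role: Optional[str] = None):
    """Build a lightweight stand-in with the same attribute names as ToolCallHookContext."""
    agent = SimpleNamespace(role=agent_role) if agent_role else None
    return SimpleNamespace(
        tool_name=tool_name,
        tool_input=tool_input if tool_input is not None else {},
        tool=None,
        agent=agent,
        task=None,
        crew=None,
        tool_result=tool_result,
    )


# Hook logic kept in a plain function so it can be tested without decorators or a crew.
def block_dangerous_tools(context) -> Optional[bool]:
    if context.tool_name in {"delete_database", "drop_table", "rm_rf"}:
        return False
    return None


assert block_dangerous_tools(make_context("drop_table")) is False
assert block_dangerous_tools(make_context("web_search", {"query": "crewai"})) is None
```

Registering the tested function separately (for example with `register_before_tool_call_hook`) keeps the test independent of whatever the decorator returns.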

Modifying Tool Input

Important: always modify the tool input in place.
# ✅ Correct - modify in-place
def sanitize_input(context: ToolCallHookContext) -> None:
    context.tool_input['query'] = context.tool_input['query'].lower()

# ❌ Wrong - replaces dict reference
def wrong_approach(context: ToolCallHookContext) -> None:
    context.tool_input = {'query': 'new query'}
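The rule follows from how Python references work. Assuming the framework keeps its own reference to the input dict and passes that same object to the tool (which is what the rule above implies), mutating the dict changes what the tool receives, while assigning a new dict only rebinds the attribute on the context object. A stand-in context makes the difference visible:

```python
from types import SimpleNamespace

# Stand-in for the framework's own reference to the input dict.
framework_input = {"query": "Hello World"}
context = SimpleNamespace(tool_input=framework_input)


def sanitize_input(context) -> None:
    # In place: mutates the dict object the framework also holds.
    context.tool_input["query"] = context.tool_input["query"].lower()


def wrong_approach(context) -> None:
    # Rebinding: the context now points at a new dict; framework_input is unchanged.
    context.tool_input = {"query": "new query"}


sanitize_input(context)
print(framework_input)  # {'query': 'hello world'}

wrong_approach(context)
print(framework_input)  # {'query': 'hello world'} (the new dict never reaches the tool)
```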

Registration Methods

1. Global Hook Registration

Register hooks that apply to all tool calls across all crews:
from crewai.hooks import register_before_tool_call_hook, register_after_tool_call_hook

def log_tool_call(context):
    print(f"Tool: {context.tool_name}")
    print(f"Input: {context.tool_input}")
    return None  # Allow execution

register_before_tool_call_hook(log_tool_call)

2. Decorator-Based Registration

Use decorators for a more concise syntax:
from crewai.hooks import before_tool_call, after_tool_call

@before_tool_call
def block_dangerous_tools(context):
    dangerous_tools = ['delete_database', 'drop_table', 'rm_rf']
    if context.tool_name in dangerous_tools:
        print(f"⛔ Blocked dangerous tool: {context.tool_name}")
        return False  # Block execution
    return None

@after_tool_call
def sanitize_results(context):
    if context.tool_result and "password" in context.tool_result.lower():
        return context.tool_result.replace("password", "[REDACTED]")
    return None

3. Crew-Scoped Hooks

Register hooks for a specific crew instance:
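from crewai import Crew, Process
from crewai.project import CrewBase, crew
from crewai.hooks import before_tool_call_crew, after_tool_call_crew

@CrewBase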
class MyProjCrew:
    @before_tool_call_crew
    def validate_tool_inputs(self, context):
        # Only applies to this crew
        if context.tool_name == "web_search":
            if not context.tool_input.get('query'):
                print("❌ Invalid search query")
                return False
        return None

    @after_tool_call_crew
    def log_tool_results(self, context):
        # Crew-specific tool logging
        print(f"✅ {context.tool_name} completed")
        return None

    @crew
    def crew(self) -> Crew:
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True
        )

Common Use Cases

1. Safety Guardrails

@before_tool_call
def safety_check(context: ToolCallHookContext) -> bool | None:
    # Block tools that could cause harm
    destructive_tools = [
        'delete_file',
        'drop_table',
        'remove_user',
        'system_shutdown'
    ]

    if context.tool_name in destructive_tools:
        print(f"🛑 Blocked destructive tool: {context.tool_name}")
        return False

    # Warn on sensitive operations
    sensitive_tools = ['send_email', 'post_to_social_media', 'charge_payment']
    if context.tool_name in sensitive_tools:
        print(f"⚠️  Executing sensitive tool: {context.tool_name}")

    return None
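A denylist like the one above only blocks the names you thought of. If your toolset is fixed, an allowlist fails closed instead: any tool not explicitly listed, including one added later or misspelled, is blocked. The tool names below are hypothetical; a function like this would be registered with `@before_tool_call` in the same way.

```python
from types import SimpleNamespace
from typing import Optional

# Hypothetical allowlist: only these tools may run.
ALLOWED_TOOLS = frozenset({"web_search", "read_file", "calculator"})


def allowlist_check(context) -> Optional[bool]:
    # Fail closed: any tool not explicitly listed is blocked, including new or misspelled ones.
    if context.tool_name not in ALLOWED_TOOLS:
        print(f"🛑 Tool not on allowlist: {context.tool_name}")
        return False
    return None


print(allowlist_check(SimpleNamespace(tool_name="system_shutdown")))  # prints the warning, then False
print(allowlist_check(SimpleNamespace(tool_name="web_search")))       # None
```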

2. Human Approval Gates

@before_tool_call
def require_approval_for_actions(context: ToolCallHookContext) -> bool | None:
    approval_required = [
        'send_email',
        'make_purchase',
        'delete_file',
        'post_message'
    ]

    if context.tool_name in approval_required:
        response = context.request_human_input(
            prompt=f"Approve {context.tool_name}?",
            default_message=f"Input: {context.tool_input}\nType 'yes' to approve:"
        )

        if response.lower() != 'yes':
            print(f"❌ Tool execution denied: {context.tool_name}")
            return False

    return None
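Separating the approval policy from the way the question is asked makes the gate easy to test and swap. In the sketch below, `approval_gate` and its `ask` parameter are hypothetical: inside a registered hook, `ask` could wrap `context.request_human_input(...)` as in the example above, and in a local script it could simply be the built-in `input`. The reply is stripped and lower-cased before comparison, and an empty reply counts as a denial.

```python
from types import SimpleNamespace
from typing import Callable, Optional

APPROVAL_REQUIRED = frozenset({"send_email", "make_purchase", "delete_file", "post_message"})
APPROVE_REPLIES = frozenset({"y", "yes"})


def approval_gate(context, ask: Callable[[str], str]) -> Optional[bool]:
    """Block unless a human approves; `ask` takes a prompt and returns the raw reply."""
    if context.tool_name not in APPROVAL_REQUIRED:
        return None  # No approval needed; ask() is never called
    reply = ask(f"Approve {context.tool_name} with input {context.tool_input}? ")
    # Normalize whitespace and case so replies like " Yes\n" count as approval.
    if (reply or "").strip().lower() not in APPROVE_REPLIES:
        print(f"❌ Tool execution denied: {context.tool_name}")
        return False
    return None


ctx = SimpleNamespace(tool_name="send_email", tool_input={"to": "team@example.com"})
print(approval_gate(ctx, ask=lambda prompt: " Yes\n"))  # None (approved)
print(approval_gate(ctx, ask=lambda prompt: "no"))      # prints the denial, then False
```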

3. Input Validation and Sanitization

@before_tool_call
def validate_and_sanitize_inputs(context: ToolCallHookContext) -> bool | None:
    # Validate search queries
    if context.tool_name == 'web_search':
        query = context.tool_input.get('query', '')
        if len(query) < 3:
            print("❌ Search query too short")
            return False

        # Sanitize query
        context.tool_input['query'] = query.strip().lower()

    # Validate file paths
    if context.tool_name == 'read_file':
        path = context.tool_input.get('path', '')
        if '..' in path or path.startswith('/'):
            print("❌ Invalid file path")
            return False

    return None
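The `'..' in path` check above rejects harmless names such as `notes..txt` and does not account for symlinks. If every read must stay inside one directory, a sturdier check resolves the path and tests containment. `BASE_DIR`, `is_safe_path`, and `validate_read_file` are hypothetical names; `Path.is_relative_to` requires Python 3.9 or later.

```python
from pathlib import Path
from typing import Optional

# Hypothetical sandbox root; all file reads must stay inside it.
BASE_DIR = Path("/srv/agent-workspace").resolve()


def is_safe_path(raw_path: str, base: Path = BASE_DIR) -> bool:
    """True if raw_path, interpreted relative to base, resolves to a location inside base."""
    # Joining with an absolute path discards base, so '/etc/passwd' stays absolute
    # and fails the containment check; resolve() collapses '..' and follows existing symlinks.
    candidate = (base / raw_path).resolve()
    return candidate.is_relative_to(base)


def validate_read_file(context) -> Optional[bool]:
    if context.tool_name == "read_file" and not is_safe_path(context.tool_input.get("path", "")):
        print("❌ Path escapes the workspace")
        return False
    return None
```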

4. Result Sanitization

import re

@after_tool_call
def sanitize_sensitive_data(context: ToolCallHookContext) -> str | None:
    if not context.tool_result:
        return None

    result = context.tool_result

    # Remove API keys
    result = re.sub(
        r'(api[_-]?key|token)["\']?\s*[:=]\s*["\']?[\w-]+',
        r'\1: [REDACTED]',
        result,
        flags=re.IGNORECASE
    )

    # Remove email addresses
    result = re.sub(
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        '[EMAIL-REDACTED]',
        result
    )

    # Remove credit card numbers
    result = re.sub(
        r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b',
        '[CARD-REDACTED]',
        result
    )

    return result
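Because this hook runs after every tool call, the patterns can be compiled once at import time. The sketch below factors the same three redactions into a plain `redact` function (a hypothetical name) that is easy to unit-test. Like any regex-based filter, it is best-effort: it will miss secrets that do not match these shapes.

```python
import re

# Compiled once at import time rather than on every tool call.
_API_KEY = re.compile(r'(api[_-]?key|token)["\']?\s*[:=]\s*["\']?[\w-]+', re.IGNORECASE)
_EMAIL = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
_CARD = re.compile(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b')


def redact(text: str) -> str:
    text = _API_KEY.sub(r'\1: [REDACTED]', text)
    text = _EMAIL.sub('[EMAIL-REDACTED]', text)
    return _CARD.sub('[CARD-REDACTED]', text)


print(redact("api_key=abc-123 contact bob@example.com card 4111 1111 1111 1111"))
# api_key: [REDACTED] contact [EMAIL-REDACTED] card [CARD-REDACTED]
```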

5. Tool Usage Analytics

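import time
from collections import defaultdict

tool_stats = defaultdict(lambda: {'count': 0, 'total_time': 0.0, 'failures': 0})
# Start times live outside tool_input so no extra key is ever passed to the tool.
# Keyed by tool name, which assumes calls to the same tool do not overlap.
_start_times: dict[str, float] = {}

@before_tool_call
def start_timer(context: ToolCallHookContext) -> None:
    _start_times[context.tool_name] = time.perf_counter()
    return None

@after_tool_call
def track_tool_usage(context: ToolCallHookContext) -> None:
    start_time = _start_times.pop(context.tool_name, None)
    duration = time.perf_counter() - start_time if start_time is not None else 0.0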

    tool_stats[context.tool_name]['count'] += 1
    tool_stats[context.tool_name]['total_time'] += duration

    if not context.tool_result or 'error' in context.tool_result.lower():
        tool_stats[context.tool_name]['failures'] += 1

    print(f"""
    📊 Tool Stats for {context.tool_name}:
    - Executions: {tool_stats[context.tool_name]['count']}
    - Avg Time: {tool_stats[context.tool_name]['total_time'] / tool_stats[context.tool_name]['count']:.2f}s
    - Failures: {tool_stats[context.tool_name]['failures']}
    """)

    return None

6. Rate Limiting

from collections import defaultdict
from datetime import datetime, timedelta

tool_call_history = defaultdict(list)

@before_tool_call
def rate_limit_tools(context: ToolCallHookContext) -> bool | None:
    tool_name = context.tool_name
    now = datetime.now()

    # Clean old entries (older than 1 minute)
    tool_call_history[tool_name] = [
        call_time for call_time in tool_call_history[tool_name]
        if now - call_time < timedelta(minutes=1)
    ]

    # Check rate limit (max 10 calls per minute)
    if len(tool_call_history[tool_name]) >= 10:
        print(f"🚫 Rate limit exceeded for {tool_name}")
        return False

    # Record this call
    tool_call_history[tool_name].append(now)
    return None
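The list comprehension above rebuilds each tool's history on every call, and `datetime.now()` can jump if the system clock is adjusted. A variant built on `collections.deque` and `time.monotonic` evicts expired timestamps from the front in amortized constant time. `SlidingWindowLimiter` is a hypothetical name, and the injectable `clock` exists only so the limiter can be tested deterministically.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Optional


class SlidingWindowLimiter:
    """Allow at most max_calls per tool name within any window_seconds span."""

    def __init__(self, max_calls: int = 10, window_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.max_calls = max_calls
        self.window = window_seconds
        self.clock = clock  # Injectable so tests can control time
        self._calls: dict[str, deque] = defaultdict(deque)

    def allow(self, tool_name: str) -> bool:
        now = self.clock()
        calls = self._calls[tool_name]
        # Timestamps are appended in order, so expired ones are always at the front.
        while calls and now - calls[0] >= self.window:
            calls.popleft()
        if len(calls) >= self.max_calls:
            return False  # Rejected calls are not recorded, matching the example above
        calls.append(now)
        return True


limiter = SlidingWindowLimiter(max_calls=10, window_seconds=60)


def rate_limit_tools(context) -> Optional[bool]:
    # Register with @before_tool_call; returns False to block when over the limit.
    return None if limiter.allow(context.tool_name) else False
```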

7. Caching Tool Results

import hashlib
import json

tool_cache = {}

def cache_key(tool_name: str, tool_input: dict) -> str:
    """Generate cache key from tool name and input."""
    input_str = json.dumps(tool_input, sort_keys=True, default=str)  # default=str tolerates non-JSON values
    return hashlib.md5(f"{tool_name}:{input_str}".encode()).hexdigest()

@before_tool_call
def check_cache(context: ToolCallHookContext) -> bool | None:
    key = cache_key(context.tool_name, context.tool_input)
    if key in tool_cache:
        print(f"💾 Cache hit for {context.tool_name}")
        # A before hook can only allow (None/True) or block (False);
        # it cannot return this cached result in place of running the tool.
    return None

@after_tool_call
def cache_result(context: ToolCallHookContext) -> None:
    if context.tool_result:
        key = cache_key(context.tool_name, context.tool_input)
        tool_cache[key] = context.tool_result
        print(f"💾 Cached result for {context.tool_name}")
    return None

8. Debug Logging

@before_tool_call
def debug_tool_call(context: ToolCallHookContext) -> None:
    print(f"""
    🔍 Tool Call Debug:
    - Tool: {context.tool_name}
    - Agent: {context.agent.role if context.agent else 'Unknown'}
    - Task: {context.task.description[:50] if context.task else 'Unknown'}...
    - Input: {context.tool_input}
    """)
    return None

@after_tool_call
def debug_tool_result(context: ToolCallHookContext) -> None:
    if context.tool_result:
        result_preview = context.tool_result[:200]
        print(f"✅ Result Preview: {result_preview}...")
    else:
        print("⚠️  No result returned")
    return None

Hook Management

Unregistering Hooks

from crewai.hooks import (
    register_before_tool_call_hook,
    unregister_before_tool_call_hook,
    unregister_after_tool_call_hook
)

# Unregister specific hook
def my_hook(context):
    ...

register_before_tool_call_hook(my_hook)
# Later...
success = unregister_before_tool_call_hook(my_hook)
print(f"Unregistered: {success}")

Clearing Hooks

from crewai.hooks import (
    clear_before_tool_call_hooks,
    clear_after_tool_call_hooks,
    clear_all_tool_call_hooks
)

# Clear specific hook type
count = clear_before_tool_call_hooks()
print(f"Cleared {count} before hooks")

# Clear all tool hooks
before_count, after_count = clear_all_tool_call_hooks()
print(f"Cleared {before_count} before and {after_count} after hooks")

Listing Registered Hooks

from crewai.hooks import (
    get_before_tool_call_hooks,
    get_after_tool_call_hooks
)

# Get current hooks
before_hooks = get_before_tool_call_hooks()
after_hooks = get_after_tool_call_hooks()

print(f"Registered: {len(before_hooks)} before, {len(after_hooks)} after")

Advanced Patterns

Conditional Hook Execution

@before_tool_call
def conditional_blocking(context: ToolCallHookContext) -> bool | None:
    # Only block for specific agents
    if context.agent and context.agent.role == "junior_agent":
        if context.tool_name in ['delete_file', 'send_email']:
            print(f"❌ Junior agents cannot use {context.tool_name}")
            return False

    # Only block during specific tasks
    if context.task and "sensitive" in context.task.description.lower():
        if context.tool_name == 'web_search':
            print("❌ Web search blocked for sensitive tasks")
            return False

    return None
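As role-specific rules multiply, nested `if` statements become hard to audit. Moving the rules into a data table keeps the hook itself constant. The roles and tool names below are illustrative, and matching on `context.agent.role` assumes roles are stable strings, as in the example above.

```python
from types import SimpleNamespace
from typing import Optional

# Hypothetical policy table: agent role -> tools that role may not call.
FORBIDDEN_BY_ROLE: dict[str, frozenset[str]] = {
    "junior_agent": frozenset({"delete_file", "send_email"}),
    "researcher": frozenset({"send_email"}),
}


def role_policy(context) -> Optional[bool]:
    role = context.agent.role if context.agent else None
    # Unknown roles (and calls with no agent) get an empty set, so nothing is blocked.
    # Invert this default if you would rather fail closed for unknown roles.
    if context.tool_name in FORBIDDEN_BY_ROLE.get(role, frozenset()):
        print(f"❌ {role} cannot use {context.tool_name}")
        return False
    return None


junior = SimpleNamespace(role="junior_agent")
print(role_policy(SimpleNamespace(agent=junior, tool_name="send_email")))  # prints the message, then False
```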

Context-Aware Input Modification

@before_tool_call
def enhance_tool_inputs(context: ToolCallHookContext) -> None:
    # Add context based on agent role
    if context.agent and context.agent.role == "researcher":
        if context.tool_name == 'web_search':
            # Add domain restrictions for researchers
            context.tool_input['domains'] = ['edu', 'gov', 'org']

    # Add context based on task
    if context.task and "urgent" in context.task.description.lower():
        if context.tool_name == 'send_email':
            context.tool_input['priority'] = 'high'

    return None

Tool Chain Monitoring

import time

tool_call_chain = []

@before_tool_call
def track_tool_chain(context: ToolCallHookContext) -> None:
    tool_call_chain.append({
        'tool': context.tool_name,
        'timestamp': time.time(),
        'agent': context.agent.role if context.agent else 'Unknown'
    })

    # Detect potential infinite loops
    recent_calls = tool_call_chain[-5:]
    if len(recent_calls) == 5 and all(c['tool'] == context.tool_name for c in recent_calls):
        print(f"⚠️  Warning: {context.tool_name} called 5 times in a row")

    return None
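The `tool_call_chain` list above keeps every call for the life of the process. If you only need the repeated-call check, a `collections.deque` with `maxlen` holds just the last few tool names and discards older ones automatically. `REPEAT_THRESHOLD` and `record_and_check` are hypothetical names.

```python
from collections import deque
from typing import Optional

REPEAT_THRESHOLD = 5  # hypothetical setting
# Holds only the most recent tool names; older entries are dropped automatically.
recent_tools: deque = deque(maxlen=REPEAT_THRESHOLD)


def record_and_check(tool_name: str) -> bool:
    """Record a call; True if the last REPEAT_THRESHOLD calls were all to tool_name."""
    recent_tools.append(tool_name)
    return len(recent_tools) == REPEAT_THRESHOLD and set(recent_tools) == {tool_name}


def detect_repeats(context) -> Optional[bool]:
    if record_and_check(context.tool_name):
        print(f"⚠️  Warning: {context.tool_name} called {REPEAT_THRESHOLD} times in a row")
    return None  # Warn only; never block
```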

Best Practices

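  1. Keep hooks focused: each hook should have a single responsibility
  2. Avoid heavy computation: hooks run on every tool call
  3. Handle errors gracefully: use try-except so a failing hook does not break execution
  4. Use type hints: annotate with ToolCallHookContext for better IDE support
  5. Document blocking conditions: make clear when and why a tool is blocked
  6. Test hooks independently: unit-test hooks before using them in production
  7. Clear hooks in tests: call clear_all_tool_call_hooks() between test runs
  8. Modify in place: always modify context.tool_input in place; never replace it
  9. Log important decisions: especially when blocking tool execution
  10. Consider performance: cache expensive validations where possible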

Error Handling

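def validate_input(tool_input: dict) -> bool:
    """Placeholder: replace with your own validation logic."""
    return True

@before_tool_call
def safe_validation(context: ToolCallHookContext) -> bool | None:
    try:
        if not validate_input(context.tool_input):
            return False
    except Exception as e:
        print(f"⚠️ Hook error: {e}")
        # Decide whether to fail open (allow) or fail closed (block) on error
        return None  # Fail open: allow execution despite the error
    return None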
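Rather than repeating the same try/except in every hook, the fail-open versus fail-closed decision can be made once in a wrapper. `guarded` is a hypothetical helper, not part of CrewAI. For security checks, failing closed is usually the safer default.

```python
import functools
import logging
from types import SimpleNamespace
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def guarded(fail_closed: bool = False) -> Callable:
    """Hypothetical wrapper: on an exception, block the call (fail_closed) or allow it."""
    def decorator(hook: Callable) -> Callable:
        @functools.wraps(hook)
        def wrapper(context) -> Optional[bool]:
            try:
                return hook(context)
            except Exception:
                logger.exception("Tool hook %s failed", hook.__name__)
                return False if fail_closed else None
        return wrapper
    return decorator


@guarded(fail_closed=True)
def require_path(context) -> Optional[bool]:
    # Raises KeyError when 'path' is missing; the wrapper turns that into a block.
    return None if context.tool_input["path"] else False


print(require_path(SimpleNamespace(tool_input={})))  # False (error, failed closed)
```

To register such a hook, put `@before_tool_call` above `@guarded(...)`: decorators apply bottom-up, so the function that gets registered is the already-guarded wrapper.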

Type Safety

from crewai.hooks import (
    ToolCallHookContext,
    BeforeToolCallHookType,
    AfterToolCallHookType,
    register_before_tool_call_hook,
    register_after_tool_call_hook,
)

# Explicit type annotations
def my_before_hook(context: ToolCallHookContext) -> bool | None:
    return None

def my_after_hook(context: ToolCallHookContext) -> str | None:
    return None

# Type-safe registration
register_before_tool_call_hook(my_before_hook)
register_after_tool_call_hook(my_after_hook)

Integrating with Existing Tools

Wrapping Existing Validation

def existing_validator(tool_name: str, inputs: dict) -> bool:
    """Your existing validation function."""
    # Your validation logic
    return True

@before_tool_call
def integrate_validator(context: ToolCallHookContext) -> bool | None:
    if not existing_validator(context.tool_name, context.tool_input):
        print(f"❌ Validation failed for {context.tool_name}")
        return False
    return None

Logging to External Systems

import logging

logger = logging.getLogger(__name__)

@before_tool_call
def log_to_external_system(context: ToolCallHookContext) -> None:
    logger.info(f"Tool call: {context.tool_name}", extra={
        'tool_name': context.tool_name,
        'tool_input': context.tool_input,
        'agent': context.agent.role if context.agent else None
    })
    return None

Troubleshooting

Hook Not Executing

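  • Verify that the hook is registered before the crew runs
  • Check whether an earlier hook returned False (this blocks execution and any subsequent hooks)
  • Make sure the hook signature matches the expected type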

Input Modifications Not Taking Effect

  • Modify in place: context.tool_input['key'] = value
  • Do not replace the dict: context.tool_input = {}

Result Modifications Not Taking Effect

  • Return the modified string from the after hook
  • Returning None keeps the original result
  • Make sure the tool actually returned a result

Tool Blocked Unexpectedly

  • Check every before hook for blocking conditions
  • Verify the order in which hooks execute
  • Add debug logging to identify which hook is blocking
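To pinpoint which hook blocks a call, you can evaluate each before hook in isolation against a stand-in context. `find_blocking_hooks` is a hypothetical debugging helper. You might feed it the list returned by `get_before_tool_call_hooks()`, assuming that list contains plain callables; be aware that hooks with side effects (such as a rate limiter or timer) will record an extra call when run this way.

```python
from types import SimpleNamespace
from typing import Callable


def find_blocking_hooks(hooks: list[Callable], context) -> list[str]:
    """Run each before hook on the same context; return names of hooks that return False."""
    blockers = []
    for hook in hooks:
        if hook(context) is False:
            blockers.append(getattr(hook, "__name__", repr(hook)))
    return blockers


def block_drop(context):
    return False if context.tool_name == "drop_table" else None


def log_only(context):
    return None


ctx = SimpleNamespace(tool_name="drop_table", tool_input={})
print(find_blocking_hooks([log_only, block_drop], ctx))  # ['block_drop']
```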

Conclusion

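Tool call hooks give you powerful control over, and visibility into, tool execution in CrewAI. Use them to implement safety guardrails, approval gates, input validation, result sanitization, logging, and analytics. Combined with proper error handling and type safety, hooks enable secure, production-ready agent systems with comprehensive observability.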