跳转到主要内容

概述

CrewAI 通过 BaseLLM 抽象基类支持自定义 LLM 实现。这允许您集成任何 LiteLLM 中没有内置支持的 LLM 提供商,或实现自定义身份验证机制。

快速入门

这是一个最精简的自定义 LLM 实现:
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union
import requests

class CustomLLM(BaseLLM):
    """Minimal custom LLM that posts chat messages to an HTTP endpoint.

    The request and response shapes used here (a ``messages`` list in, and
    ``choices[0].message.content`` out) are what this example expects from the
    remote API; adapt them to your provider.
    """

    def __init__(self, model: str, api_key: str, endpoint: str, temperature: Optional[float] = None):
        """Store the connection settings and initialise the base class.

        Args:
            model: Model identifier sent in every request.
            api_key: Token placed in the ``Authorization: Bearer`` header.
            endpoint: Full URL the requests are posted to.
            temperature: Optional sampling temperature; omitted from the
                request when left as ``None``.
        """
        # IMPORTANT: Call super().__init__() with required parameters
        super().__init__(model=model, temperature=temperature)

        self.api_key = api_key
        self.endpoint = endpoint

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Either a plain prompt string or a list of
                ``{"role": ..., "content": ...}`` dictionaries.
            tools: Optional tool definitions; forwarded only when
                ``supports_function_calling()`` returns True.
            callbacks: Accepted for interface compatibility; unused here.
            available_functions: Accepted for interface compatibility;
                unused here.

        Returns:
            The ``content`` field of the first choice in the response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            KeyError, IndexError: If the response JSON lacks the expected
                ``choices[0].message.content`` path.
        """
        # Convert string to message format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Prepare request
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }

        # Send a temperature only when one was configured, so an unset value
        # is not serialised as JSON null.
        if self.temperature is not None:
            payload["temperature"] = self.temperature

        # Add tools if provided and supported
        if tools and self.supports_function_calling():
            payload["tools"] = tools

        # Make API call
        response = requests.post(
            self.endpoint,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=30,
        )
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        """Override if your LLM supports function calling."""
        return True  # Change to False if your LLM doesn't support tools

    def get_context_window_size(self) -> int:
        """Return the context window size of your LLM."""
        return 8192  # Adjust based on your model's actual context window

使用您的自定义 LLM

from crewai import Agent, Task, Crew

# Assuming you have the CustomLLM class defined above
# Create your custom LLM (model name, key and endpoint below are placeholders)
custom_llm = CustomLLM(
    model="my-custom-model",
    api_key="your-api-key",
    endpoint="https://api.example.com/v1/chat/completions",
    temperature=0.7
)

# Use with an agent: the custom LLM is passed through the `llm` argument
agent = Agent(
    role="Research Assistant",
    goal="Find and analyze information",
    backstory="You are a research assistant.",
    llm=custom_llm
)

# Create and execute tasks
task = Task(
    description="Research the latest developments in AI",
    expected_output="A comprehensive summary",
    agent=agent
)

# Assemble a single-agent, single-task crew and start it;
# `result` holds whatever kickoff() returns.
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()

必需方法

构造函数:__init__()

关键:您必须在构造函数中使用所需参数调用 super().__init__(model, temperature):
def __init__(self, model: str, api_key: str, temperature: Optional[float] = None) -> None:
    """Initialise the custom LLM.

    Forwards ``model`` and ``temperature`` to the parent constructor and then
    stores ``api_key`` on the instance.
    """
    # REQUIRED: Call parent constructor with model and temperature
    super().__init__(model=model, temperature=temperature)
    
    # Your custom initialization
    self.api_key = api_key

抽象方法:call()

call() 方法是您 LLM 实现的核心。它必须:
  • 接受消息(字符串或包含 ‘role’ 和 ‘content’ 的字典列表)
  • 返回一个字符串响应
  • 如果支持,处理工具和函数调用
  • 针对错误抛出适当的异常

可选方法

def supports_function_calling(self) -> bool:
    """Report whether this LLM accepts tool/function definitions.

    Returns:
        Always ``True`` in this example.
    """
    return True  # Default is True

def supports_stop_words(self) -> bool:
    """Report whether this LLM honours stop sequences.

    Returns:
        Always ``True`` in this example.
    """
    return True  # Default is True

def get_context_window_size(self) -> int:
    """Report the size of the model's context window.

    Returns:
        The fixed value ``4096`` in this example.
    """
    return 4096  # Default is 4096

常见模式

错误处理

import requests

def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Send ``messages`` to the endpoint and return the reply text.

    Transport and parsing failures are translated into built-in exception
    types, with the original exception chained as the cause.

    Args:
        messages: A prompt string or a list of ``{"role", "content"}`` dicts.
        tools: Accepted for interface compatibility; not sent in this example.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Accepted for interface compatibility; unused here.

    Returns:
        The ``content`` of the first choice in the response.

    Raises:
        TimeoutError: If the HTTP request times out.
        RuntimeError: For any other ``requests`` failure, including error
            status codes raised by ``raise_for_status()``.
        ValueError: If the response lacks ``choices[0].message.content``.
    """
    # Convert string to message format if needed
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Build the request body here so the snippet is runnable on its own;
    # add further fields (temperature, tools, ...) as your API requires.
    payload = {
        "model": self.model,
        "messages": messages,
    }

    try:
        response = requests.post(
            self.endpoint,
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=payload,
            timeout=30
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    # Timeout is handled before the more general RequestException.
    except requests.Timeout as e:
        raise TimeoutError("LLM request timed out") from e
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}") from e
    # TypeError covers a null/non-container value along the lookup path.
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(f"Invalid response format: {str(e)}") from e

自定义身份验证

from crewai import BaseLLM
from typing import Optional

class CustomAuthLLM(BaseLLM):
    """Skeleton of a custom LLM that uses a non-Bearer ``Authorization`` scheme."""

    def __init__(self, model: str, auth_token: str, endpoint: str, temperature: Optional[float] = None):
        """Initialise the base class, then keep the token and endpoint."""
        super().__init__(model=model, temperature=temperature)
        self.endpoint = endpoint
        self.auth_token = auth_token

    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        """Prepare the request headers; the request itself is left to the reader."""
        authorization = f"Custom {self.auth_token}"  # Custom auth format
        headers = {}
        headers["Authorization"] = authorization
        headers["Content-Type"] = "application/json"
        # Rest of implementation...

停用词支持

CrewAI 会自动添加 "\nObservation:" 作为停用词来控制智能体行为。如果您的 LLM 支持停用词,请在 API 调用中包含它们:
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Assemble a request payload that forwards the configured stop sequences.

    The actual API call is elided in this example, so nothing is returned.
    """
    payload = dict(model=self.model, messages=messages)
    payload["stop"] = self.stop  # Include stop words in API call
    # Make API call...

def supports_stop_words(self) -> bool:
    """Report that stop sequences are supported."""
    return True  # Your LLM supports stop sequences
如果您的 LLM 不原生支持停用词,请在返回前手动截断响应:
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Call the API and cut the reply at the earliest stop word.

    Args:
        messages: Passed through to ``self._make_api_call``.
        tools: Passed through to ``self._make_api_call``.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Accepted for interface compatibility; unused here.

    Returns:
        The reply content, truncated just before the first position at which
        any entry of ``self.stop`` occurs. Content that is not a string
        (for example ``None``) is returned unchanged.
    """
    response = self._make_api_call(messages, tools)
    content = response["choices"][0]["message"]["content"]

    # Manually truncate at stop words
    if self.stop and isinstance(content, str):
        # Use the earliest occurrence across ALL stop words, not the first
        # stop word in list order; otherwise an earlier-occurring stop word
        # would be left in the output. Empty stop words are ignored.
        cut_positions = [
            position
            for position in (content.find(stop_word) for stop_word in self.stop if stop_word)
            if position != -1
        ]
        if cut_positions:
            content = content[:min(cut_positions)]

    return content

def supports_stop_words(self) -> bool:
    """Report that stop sequences are not supported natively."""
    return False  # Tell CrewAI we handle stop words manually

函数调用

如果您的 LLM 支持函数调用,请实现完整的流程:
import json

def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Call the API and dispatch any requested tool calls.

    Args:
        messages: A prompt string or a list of message dictionaries.
        tools: Tool definitions passed through to ``self._make_api_call``.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Mapping of function name to callable. Tool calls
            are only dispatched when this is non-empty.

    Returns:
        The result of ``self._handle_function_calls`` when the reply contains
        tool calls and ``available_functions`` is provided; otherwise the
        reply's ``content``.
    """
    # Convert string to message format
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Make API call
    response = self._make_api_call(messages, tools)
    message = response["choices"][0]["message"]

    # Check for function calls. The key may be present but null or empty,
    # which means "no tool calls" and must fall through to the plain content.
    tool_calls = message.get("tool_calls")
    if tool_calls and available_functions:
        return self._handle_function_calls(
            tool_calls, messages, tools, available_functions
        )

    return message["content"]

def _handle_function_calls(self, tool_calls, messages, tools, available_functions):
    """Handle function calling with proper message flow."""
    for tool_call in tool_calls:
        function_name = tool_call["function"]["name"]
        
        if function_name in available_functions:
            # Parse and execute function
            function_args = json.loads(tool_call["function"]["arguments"])
            function_result = available_functions[function_name](**function_args)
            
            # Add function call and result to message history
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [tool_call]
            })
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "name": function_name,
                "content": str(function_result)
            })
            
            # Call LLM again with updated context
            return self.call(messages, tools, None, available_functions)
    
    return "Function call failed"

故障排除

常见问题

构造函数错误
# ❌ Wrong - missing required parameters
def __init__(self, api_key: str):
    # Deliberately incorrect example: the parent constructor is called
    # without model or temperature.
    super().__init__()

# ✅ Correct
def __init__(self, model: str, api_key: str, temperature: Optional[float] = None):
    # Forward model and temperature to the parent constructor.
    super().__init__(model=model, temperature=temperature)
函数调用不工作
  • 确保 supports_function_calling() 返回 True
  • 检查您是否在响应中处理了 tool_calls
  • 验证 available_functions 参数是否使用正确
身份验证失败
  • 验证 API 密钥格式和权限
  • 检查身份验证头格式
  • 确保端点 URL 正确
响应解析错误
  • 在访问嵌套字段前验证响应结构
  • 处理内容可能为 None 的情况
  • 为格式错误的响应添加适当的错误处理

测试您的自定义 LLM

from crewai import Agent, Task, Crew

def test_custom_llm():
    """Smoke-test CustomLLM directly, then through a one-agent crew."""
    llm_under_test = CustomLLM(
        model="test-model", api_key="test-key", endpoint="https://api.test.com"
    )

    # Test basic call
    direct_reply = llm_under_test.call("Hello, world!")
    assert isinstance(direct_reply, str)
    assert len(direct_reply) > 0

    # Test with CrewAI agent
    greeter = Agent(
        role="Test Agent",
        goal="Test custom LLM",
        backstory="A test agent.",
        llm=llm_under_test,
    )
    greeting_task = Task(
        description="Say hello", expected_output="A greeting", agent=greeter
    )

    crew_output = Crew(agents=[greeter], tasks=[greeting_task]).kickoff()
    assert "hello" in crew_output.raw.lower()
本指南涵盖了在 CrewAI 中实现自定义 LLM 的基本要素。