Custom LLM Implementations

CrewAI now supports custom LLM implementations through the BaseLLM abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.

To create a custom LLM implementation, you need to:

  1. Inherit from the BaseLLM abstract base class
  2. Implement the required methods:
    • call(): The main method that invokes the LLM with messages
    • supports_function_calling(): Whether the LLM supports function calling
    • supports_stop_words(): Whether the LLM supports stop words
    • get_context_window_size(): The context window size of the LLM

Example: Basic Custom LLM

from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union

class CustomLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint, 
                headers=headers, 
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192

Error Handling Best Practices

When implementing a custom LLM, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

2. Implement Retry Logic for Transient Failures

For transient failures such as network issues or rate limiting, implement retry logic with exponential backoff:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time
    
    max_retries = 3
    retry_delay = 1  # seconds
    
    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")

3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint

4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),  # same hypothetical helper as above
            timeout=30
        )
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception as e:
        # Handle error
        raise

Example: JWT-Based Authentication

For services that use JWT-based authentication instead of an API key, you can implement a custom LLM like this:

from crewai import BaseLLM, Agent, Task
from typing import Any, Dict, List, Optional, Union

class JWTAuthLLM(BaseLLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            
            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")
                
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        return 8192

Troubleshooting

Here are some common issues you might encounter when implementing a custom LLM, and how to resolve them:

1. Authentication Failures

Symptom: 401 Unauthorized or 403 Forbidden errors

Solution:

  • Verify that your API key or JWT token is valid and has not expired (see the token-inspection sketch after this list)
  • Check that you are using the correct authentication header format
  • Ensure that your token has the necessary permissions
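
For the first point, when the token is a JWT you can check its exp claim locally before sending a doomed request. Below is a minimal, stdlib-only sketch; the helper name jwt_is_expired is illustrative, and note that it only inspects the claim and does not verify the signature:

import base64
import json
import time

def jwt_is_expired(jwt_token: str) -> bool:
    """Best-effort check of a JWT's `exp` claim; does NOT verify the signature."""
    try:
        # A JWT is three base64url segments: header.payload.signature
        payload_b64 = jwt_token.split(".")[1]
        # Re-pad the base64url segment so the stdlib decoder accepts it
        payload_b64 += "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64))
        if not isinstance(payload, dict):
            return True
        exp = payload.get("exp")
        return exp is not None and time.time() >= exp
    except (IndexError, ValueError):
        return True  # Malformed token: treat it as unusable

# Example: fail fast in call() before making the request
# if jwt_is_expired(self.jwt_token):
#     raise ValueError("Authentication failed: JWT token has expired")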

2. Timeout Issues

Symptom: Requests take too long or time out

Solution:

  • Implement timeout handling as shown in the examples above
  • Use retry logic with exponential backoff
  • Consider using a more reliable network connection

3. Response Parsing Errors

Symptom: KeyError, IndexError, or ValueError when processing responses

Solution:

  • Validate the response format before accessing nested fields (see the sketch after this list)
  • Implement proper error handling for malformed responses
  • Check the API documentation for the expected response format
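
As a sketch of the first point, here is a defensive parser for the OpenAI-style response shape used throughout this page; the helper name extract_content is illustrative, and the field names should be adapted to your provider's actual schema:

def extract_content(response_data: dict) -> str:
    """Validate the expected response shape before indexing into it."""
    choices = response_data.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        raise ValueError(f"Invalid response format: missing or empty 'choices': {response_data}")
    message = choices[0].get("message")
    if not isinstance(message, dict) or not isinstance(message.get("content"), str):
        raise ValueError(f"Invalid response format: missing 'message.content': {choices[0]}")
    return message["content"]

# In call(), replace the direct chained indexing with:
# return extract_content(response.json())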

4. Rate Limiting

Symptom: 429 Too Many Requests errors

Solution:

  • Implement rate limiting in your custom LLM
  • Add exponential backoff to your retries
  • Consider a token bucket algorithm for more precise rate control (see the sketch after this list)
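
The sliding-window limiter shown under Advanced Features below is often enough; a token bucket additionally allows short bursts while holding a steady average rate. A minimal thread-safe sketch, with illustrative names (TokenBucket, acquire):

import threading
import time

class TokenBucket:
    """Simple token bucket: `rate` tokens are added per second, up to `capacity`."""
    def __init__(self, rate: float, capacity: float):
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        self.last_refill = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill based on elapsed time, capped at capacity
                self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
                self.last_refill = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                wait = (1 - self.tokens) / self.rate
            time.sleep(wait)  # Sleep outside the lock so other threads can refill

# Example: average ~1 request/second with bursts of up to 10
# bucket = TokenBucket(rate=1.0, capacity=10)
# bucket.acquire()  # call before each API request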

Advanced Features

Logging

Adding logging to your custom LLM helps with debugging and monitoring:

import logging
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class LoggingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise

Rate Limiting

Implementing rate limiting helps you avoid overwhelming the LLM API:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class RateLimitedLLM(BaseLLM):
    def __init__(
        self, 
        api_key: str, 
        endpoint: str, 
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call
        return self._make_api_call(messages, tools)
        
    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]
        
        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)

Metrics Collection

Collecting metrics helps you monitor your LLM usage:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class MetricsCollectingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1
        
        try:
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception as e:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)
            
    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }

Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function-calling loop in your custom LLM:

import json
from typing import Any, Dict, List, Optional, Union

# This `call` is a method on your BaseLLM subclass (e.g. the JWTAuthLLM above)

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests
    
    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }
        
        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        
        data = {
            "messages": messages,
            "tools": tools
        }
        
        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()
        
        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]
            
            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])
                
                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)
                    
                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })
            
            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)
        
        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token", 
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)

Implementing Your Own Authentication Mechanism

The BaseLLM class lets you implement any authentication mechanism you need, not just JWT or API keys. You can use:

  • OAuth tokens
  • Client certificates
  • Custom headers
  • Session-based authentication
  • Any other authentication method your LLM provider requires

Simply implement the appropriate authentication logic in your custom LLM class; a sketch of the OAuth case follows.
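
For instance, here is a minimal sketch of the OAuth 2.0 client-credentials flow with automatic token refresh. The class name OAuthLLM, the attribute names, and the token-endpoint response shape (access_token, expires_in) are assumptions drawn from common OAuth conventions; adapt them to your provider:

import time
from typing import Any, Dict, List, Optional, Union

import requests
from crewai import BaseLLM

class OAuthLLM(BaseLLM):
    def __init__(self, client_id: str, client_secret: str, token_url: str, endpoint: str):
        super().__init__()
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.endpoint = endpoint
        self._access_token: Optional[str] = None
        self._token_expiry = 0.0

    def _get_access_token(self) -> str:
        """Fetch a client-credentials token, refreshing a minute before expiry."""
        if self._access_token is None or time.time() >= self._token_expiry - 60:
            response = requests.post(
                self.token_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
                timeout=30,
            )
            response.raise_for_status()
            token_data = response.json()
            self._access_token = token_data["access_token"]
            # expires_in is optional with some providers; default to one hour
            self._token_expiry = time.time() + token_data.get("expires_in", 3600)
        return self._access_token

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        headers = {
            "Authorization": f"Bearer {self._get_access_token()}",
            "Content-Type": "application/json",
        }
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = requests.post(
            self.endpoint,
            headers=headers,
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    # supports_function_calling(), supports_stop_words(), and
    # get_context_window_size() are implemented as in the examples above.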