Skip to content

FastAPI 依赖注入

概述

依赖注入（Dependency Injection）是 FastAPI 的核心特性之一，它提供了一种优雅的方式来管理应用程序的依赖关系。通过依赖注入，我们可以实现代码解耦、提高可测试性、增强可维护性。本章将深入探讨 FastAPI 的依赖注入系统。

🔧 依赖注入基础

简单依赖

python
import asyncio
import uuid
from datetime import datetime
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request

# Application instance; every example route in this module is registered on it.
app = FastAPI()

def get_current_time():
    """Return the current local time as a naive ``datetime`` (a dependency with no inputs)."""
    import datetime as _dt

    return _dt.datetime.now()

def get_user_agent(user_agent: Optional[str] = Header(None)):
    """Return the request's User-Agent header, or "Unknown" when it is missing or empty."""
    if not user_agent:
        return "Unknown"
    return user_agent

@app.get("/info/")
async def get_info(
    current_time: datetime = Depends(get_current_time),
    user_agent: str = Depends(get_user_agent)
):
    """Combine two injected values (current time and User-Agent) into one response."""
    payload = {}
    payload["current_time"] = current_time.isoformat()
    payload["user_agent"] = user_agent
    payload["message"] = "依赖注入示例"
    return payload

# Dependencies can also be async functions.
async def get_database_info():
    """Simulate an async database lookup and return static connection metadata."""
    simulated_latency = 0.1  # pretend network/database delay
    await asyncio.sleep(simulated_latency)
    return dict(database="PostgreSQL", version="13.7", status="connected")

@app.get("/db-info/")
async def get_db_info(db_info: dict = Depends(get_database_info)):
    """Return the metadata produced by the async dependency, unchanged."""
    result = db_info
    return result

带参数的依赖

python
def get_query_extractor(query_param: str = "q", default_value: str = ""):
    """Build a dependency that reads one query-string parameter by name.

    Args:
        query_param: name of the query parameter to read.
        default_value: value returned when the parameter is absent.

    Returns:
        A callable taking the incoming ``Request`` and returning the value.
    """
    def _extract(request: Request):
        params = request.query_params
        return params.get(query_param, default_value)

    return _extract

def get_pagination_params(page: int = Query(1, ge=1), size: int = Query(10, ge=1, le=100)):
    """Validate the page/size query params and derive the row offset.

    ``page`` is 1-based, so page 1 starts at offset 0.
    """
    offset = size * (page - 1)
    return dict(page=page, size=size, offset=offset)

def get_sorting_params(
    sort_by: str = Query("created_at", description="排序字段"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="排序方向")
):
    """Collect the sort field and direction (the direction must be "asc" or "desc").

    NOTE(review): newer FastAPI versions deprecate ``regex=`` in favour of
    ``pattern=``; confirm against the pinned FastAPI version.
    """
    sorting = {"sort_by": sort_by}
    sorting["sort_order"] = sort_order
    return sorting

@app.get("/posts/")
async def list_posts(
    pagination: dict = Depends(get_pagination_params),
    sorting: dict = Depends(get_sorting_params),
    search_query: str = Depends(get_query_extractor("search", ""))
):
    """List posts using three composed dependencies: pagination, sorting and search."""
    posts = []  # placeholder: this example has no real data source
    return dict(pagination=pagination, sorting=sorting, search_query=search_query, posts=posts)

🏗️ 类作为依赖

简单类依赖

python
class DatabaseConnection:
    """Simulated async database connection used to demonstrate class dependencies.

    Nothing real is opened: ``connect`` and ``execute_query`` only sleep to
    mimic latency and flip/check the ``connected`` flag.
    """

    def __init__(self):
        self.host = "localhost"
        self.port = 5432
        self.database = "myapp"
        self.connected = False

    async def connect(self):
        """Pretend to connect; returns ``self`` so the call can be chained."""
        await asyncio.sleep(0.1)
        self.connected = True
        return self

    async def disconnect(self):
        """Mark the connection as closed."""
        self.connected = False

    async def execute_query(self, query: str):
        """Return a description of *query* after a short simulated delay.

        Raises:
            RuntimeError: if called before ``connect()``. RuntimeError is a
                subclass of Exception, so existing ``except Exception``
                handlers still catch it.
        """
        if not self.connected:
            raise RuntimeError("数据库未连接")
        await asyncio.sleep(0.05)
        return f"执行查询: {query}"

class UserService:
    """Service-layer example whose database connection is injected by FastAPI.

    ``Depends()`` without an argument lets FastAPI build the annotated class
    (``DatabaseConnection``) itself. Each operation opens the connection,
    runs one query and always closes it again.
    """

    def __init__(self, db: DatabaseConnection = Depends()):
        self.db = db

    async def get_user_by_id(self, user_id: int):
        """Fetch a (mock) user record by id.

        The connection is closed even if the query raises.
        """
        await self.db.connect()
        try:
            # user_id is annotated int (the route validates it as an int path
            # parameter); real code should still use driver-level parameter binding.
            result = await self.db.execute_query(f"SELECT * FROM users WHERE id = {user_id}")
        finally:
            await self.db.disconnect()
        return {"user_id": user_id, "name": f"User {user_id}", "query_result": result}

    async def create_user(self, user_data: dict):
        """Create a (mock) user; the connection is closed even if the insert raises."""
        await self.db.connect()
        try:
            await self.db.execute_query("INSERT INTO users VALUES (...)")
        finally:
            await self.db.disconnect()
        return {"message": "用户创建成功", "user_data": user_data}

@app.get("/users/{user_id}")
async def get_user(user_id: int, user_service: UserService = Depends()):
    """Look up one user through the injected UserService."""
    user = await user_service.get_user_by_id(user_id)
    return user

@app.post("/users/")
async def create_user(user_data: dict, user_service: UserService = Depends()):
    """Create a user through the injected UserService."""
    created = await user_service.create_user(user_data)
    return created

具有配置的类依赖

python
from pydantic import BaseSettings
import httpx

class APISettings(BaseSettings):
    """Configuration for ExternalAPIClient.

    Declared as a pydantic ``BaseSettings`` model, so each field can be
    overridden from the environment; ``Config.env_file`` additionally names a
    ``.env`` file to read.

    NOTE(review): ``from pydantic import BaseSettings`` only works on
    pydantic v1; on v2 the class lives in the separate ``pydantic-settings``
    package. Confirm the pinned version.
    """
    external_api_url: str = "https://api.example.com"  # used as the httpx base_url
    api_key: str = "default-key"  # sent as "Authorization: Bearer <api_key>"
    timeout: int = 30  # forwarded to httpx.AsyncClient(timeout=...)
    max_retries: int = 3  # number of attempts made by ExternalAPIClient.get_data
    
    class Config:
        env_file = ".env"

class ExternalAPIClient:
    """Thin async wrapper around ``httpx.AsyncClient`` with retries on transport errors.

    Use it as an async context manager: ``__aenter__`` creates the HTTP client
    and ``__aexit__`` closes it. ``get_data`` refuses to run outside that scope.
    """

    def __init__(self, settings: APISettings = Depends()):
        self.settings = settings
        self.client = None

    async def __aenter__(self):
        """Open the HTTP client with the configured base URL, timeout and bearer token."""
        self.client = httpx.AsyncClient(
            base_url=self.settings.external_api_url,
            timeout=self.settings.timeout,
            headers={"Authorization": f"Bearer {self.settings.api_key}"}
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client if one was opened; exceptions are not suppressed."""
        if self.client:
            await self.client.aclose()
            # Drop the closed client so later get_data calls fail fast and clearly.
            self.client = None

    async def get_data(self, endpoint: str):
        """GET *endpoint* and return the decoded JSON body.

        Only ``httpx.RequestError`` (transport failures) is retried, with
        exponential backoff of 1s, 2s, 4s, ... between attempts. Errors from
        ``raise_for_status`` propagate immediately.

        Raises:
            RuntimeError: if called outside the ``async with`` block.
            httpx.RequestError: when every attempt fails.
        """
        if not self.client:
            raise RuntimeError("客户端未初始化")

        # Always make at least one attempt: with max_retries <= 0 the loop
        # used to be skipped entirely and None was returned silently.
        attempts = max(1, self.settings.max_retries)
        for attempt in range(attempts):
            try:
                response = await self.client.get(endpoint)
                response.raise_for_status()
                return response.json()
            except httpx.RequestError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # exponential backoff

async def get_api_client():
    """Yield-style dependency: provide an open ExternalAPIClient and close it afterwards.

    The ``async with`` guarantees the HTTP client is closed once the code
    after ``yield`` resumes, i.e. after the endpoint using it has finished.
    """
    client_cm = ExternalAPIClient(APISettings())
    async with client_cm as api_client:
        yield api_client

@app.get("/external-data/{endpoint}")
async def get_external_data(
    endpoint: str,
    api_client: ExternalAPIClient = Depends(get_api_client)
):
    """Proxy a GET to the external API; any failure is reported as HTTP 503."""
    try:
        data = await api_client.get_data(endpoint)
    except Exception as e:
        # Chain the original error so tracebacks/logs keep the real cause.
        raise HTTPException(status_code=503, detail=f"外部API调用失败: {str(e)}") from e
    return {"data": data, "source": "external_api"}

🔄 依赖的子依赖

多层依赖

python
class ConfigService:
    """Static application configuration with dict-style ``get`` access."""

    def __init__(self):
        settings = dict(
            app_name="MyApp",
            version="1.0.0",
            debug=True,
            database_url="postgresql://localhost/myapp",
        )
        self.config = settings

    def get(self, key: str, default=None):
        """Return the value for *key*, or *default* when the key is not configured."""
        try:
            return self.config[key]
        except KeyError:
            return default

class LoggingService:
    """Print-based logger that prefixes messages with the configured app name.

    It depends on ConfigService, showing a dependency that has its own dependency.
    """

    def __init__(self, config: ConfigService = Depends()):
        self.config = config
        self.app_name, self.debug = config.get("app_name"), config.get("debug")

    def log_info(self, message: str):
        """Print *message*, labelled DEBUG when debug is enabled and INFO otherwise."""
        if self.debug:
            level = "DEBUG"
        else:
            level = "INFO"
        print(f"[{level}] {self.app_name}: {message}")

    def log_error(self, message: str):
        """Print *message* with an ERROR label."""
        print("[ERROR] {}: {}".format(self.app_name, message))

class CacheService:
    """Dictionary-backed cache that logs every hit and miss.

    NOTE(review): FastAPI builds a fresh instance of a ``Depends()`` class per
    request, so entries presumably do not outlive the request that stored
    them. Confirm before relying on this as a cross-request cache.
    """

    def __init__(self, config: ConfigService = Depends(), logger: LoggingService = Depends()):
        self.config = config
        self.logger = logger
        self.cache = {}
        self.logger.log_info("缓存服务初始化")

    def get(self, key: str):
        """Return the cached value for *key*, or None when it is absent.

        Hit/miss is decided by key presence, so falsy cached values
        (0, "", [], {}) are reported as hits rather than misses.
        """
        if key in self.cache:
            self.logger.log_info(f"缓存命中: {key}")
        else:
            self.logger.log_info(f"缓存未命中: {key}")
        return self.cache.get(key)

    def set(self, key: str, value, ttl: int = 300):
        """Store *value* under *key*.

        ``ttl`` is accepted for API compatibility but is NOT enforced:
        entries never expire.
        """
        self.cache[key] = value
        self.logger.log_info(f"缓存设置: {key}")

class BusinessService:
    """Top of the example dependency chain: uses the config, logging and cache services."""

    def __init__(
        self,
        config: ConfigService = Depends(),
        logger: LoggingService = Depends(),
        cache: CacheService = Depends()
    ):
        self.config, self.logger, self.cache = config, logger, cache

    async def get_business_data(self, data_id: str):
        """Return the record for *data_id*, serving a truthy cached copy when present."""
        cache_key = f"business_data:{data_id}"
        hit = self.cache.get(cache_key)
        if hit:
            return hit

        self.logger.log_info(f"处理业务数据: {data_id}")
        await asyncio.sleep(0.1)  # stands in for real processing time

        record = {
            "id": data_id,
            "name": f"Business Data {data_id}",
            "processed_at": datetime.now().isoformat()
        }
        self.cache.set(cache_key, record)
        return record

@app.get("/business/{data_id}")
async def get_business_data(data_id: str, service: BusinessService = Depends()):
    """Delegate to BusinessService, which FastAPI builds along with its dependencies."""
    result = await service.get_business_data(data_id)
    return result

🔒 安全依赖

认证依赖

python
import jwt
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

# Bearer-token extractor; get_current_user uses it as Depends(security) to
# receive the Authorization credentials.
security = HTTPBearer()

class User:
    """Authenticated principal built from a decoded JWT payload."""

    def __init__(self, user_id: int, username: str, email: str, roles: list):
        self.user_id, self.username, self.email, self.roles = user_id, username, email, roles

    def has_role(self, role: str) -> bool:
        """Return True if *role* is one of the user's roles."""
        return role in self.roles

    def dict(self):
        """Return a plain-dict snapshot of the user, suitable for a JSON response."""
        field_names = ("user_id", "username", "email", "roles")
        return {name: getattr(self, name) for name in field_names}

import os

# Signing key for the HS256 JWTs verified in get_current_user. Read it from the
# environment so real deployments never ship the placeholder; the placeholder
# remains the fallback so the example still runs out of the box.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key")

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Decode the bearer JWT (HS256) and return the corresponding User.

    Raises:
        HTTPException: 401 if the token is expired, invalid, or has no ``user_id``.
    """
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
        user_id = payload.get("user_id")
        if user_id is None:
            raise HTTPException(status_code=401, detail="无效的令牌")
        return User(
            user_id,
            payload.get("username"),
            payload.get("email"),
            payload.get("roles", []),
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="令牌已过期")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="无效的令牌")

async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user only if they hold the "admin" role; otherwise raise 403."""
    is_admin = current_user.has_role("admin")
    if is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="需要管理员权限")

def require_roles(*required_roles):
    """Build a dependency that admits users holding at least one of *required_roles*.

    The returned coroutine function raises HTTP 403 otherwise; when no roles
    are given, every user is rejected.
    """
    async def role_checker(current_user: User = Depends(get_current_user)):
        for role in required_roles:
            if current_user.has_role(role):
                return current_user
        raise HTTPException(
            status_code=403,
            detail=f"需要以下角色之一: {', '.join(required_roles)}"
        )
    return role_checker

# Using the authentication dependencies
@app.get("/profile/")
async def get_profile(current_user: User = Depends(get_current_user)):
    """Return the caller's own profile (any authenticated user)."""
    profile = current_user.dict()
    return {"profile": profile}

@app.get("/admin/users/")
async def list_all_users(admin_user: User = Depends(get_admin_user)):
    """Admin-only listing; the user list is a placeholder."""
    return dict(users=[], admin=admin_user.username)

@app.post("/admin/posts/")
async def create_post(
    post_data: dict,
    editor: User = Depends(require_roles("admin", "editor"))
):
    """Create a post; requires the "admin" or "editor" role."""
    author = editor.username
    return {"message": "帖子创建成功", "created_by": author}

权限依赖

python
from enum import Enum

class Permission(str, Enum):
    """Fine-grained permissions enforced by PermissionChecker.

    Mixing in ``str`` makes each member compare equal to its string value.
    """
    READ_POSTS = "read_posts"
    WRITE_POSTS = "write_posts"
    DELETE_POSTS = "delete_posts"
    MANAGE_USERS = "manage_users"
    # NOTE(review): no role in PermissionChecker's mapping grants this; users with
    # the "system_admin" role bypass permission checks there instead.
    SYSTEM_ADMIN = "system_admin"

class PermissionChecker:
    """Callable dependency that enforces one Permission for the current user.

    Usage: ``Depends(PermissionChecker(Permission.X))``. Users with the
    "system_admin" role always pass; everyone else needs a role that grants
    the permission, otherwise HTTP 403 is raised.
    """

    def __init__(self, required_permission: Permission):
        self.required_permission = required_permission

    async def __call__(self, current_user: User = Depends(get_current_user)):
        allowed = (
            current_user.has_role("system_admin")
            or self.required_permission in self._get_user_permissions(current_user)
        )
        if not allowed:
            raise HTTPException(
                status_code=403,
                detail=f"需要权限: {self.required_permission.value}"
            )
        return current_user

    def _get_user_permissions(self, user: User) -> list[Permission]:
        """Return the de-duplicated permissions granted by the user's roles (unknown roles grant none)."""
        grants_by_role = {
            "admin": [Permission.READ_POSTS, Permission.WRITE_POSTS, Permission.DELETE_POSTS, Permission.MANAGE_USERS],
            "editor": [Permission.READ_POSTS, Permission.WRITE_POSTS],
            "user": [Permission.READ_POSTS]
        }
        granted = {perm for role in user.roles for perm in grants_by_role.get(role, [])}
        return list(granted)

# Using the permission dependencies.
# NOTE(review): GET "/posts/" is also registered by the earlier list_posts
# example; if both live in the same app, the route registered first handles
# the request. The names list_posts/create_post also rebind earlier functions.
@app.get("/posts/")
async def list_posts(user: User = Depends(PermissionChecker(Permission.READ_POSTS))):
    """List posts; requires READ_POSTS."""
    return dict(posts=[], user=user.username)

@app.post("/posts/")
async def create_post(
    post_data: dict,
    user: User = Depends(PermissionChecker(Permission.WRITE_POSTS))
):
    """Create a post; requires WRITE_POSTS."""
    author = user.username
    return {"message": "帖子创建成功", "author": author}

@app.delete("/posts/{post_id}")
async def delete_post(
    post_id: int,
    user: User = Depends(PermissionChecker(Permission.DELETE_POSTS))
):
    """Delete a post; requires DELETE_POSTS."""
    message = f"帖子 {post_id} 删除成功"
    return {"message": message, "deleted_by": user.username}

🎛️ 依赖提供器和作用域

单例依赖

python
class SingletonService:
    """Shared state: every ``SingletonService()`` call returns the same object.

    ``__new__`` caches the first instance on the class. ``__init__`` still runs
    on every call, so it guards its setup with ``_initialized`` to avoid
    wiping existing state. Creation is not protected by a lock.
    """
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self.data = {}
        self.counter = 0
        self._initialized = True

    def increment(self):
        """Add one to the shared counter and return the new value."""
        self.counter = self.counter + 1
        return self.counter

    def set_data(self, key: str, value):
        """Store *value* under *key*."""
        self.data.update({key: value})

    def get_data(self, key: str):
        """Return the value stored under *key*, or None if it was never set."""
        return self.data.get(key)

def get_singleton_service():
    """Provider returning the shared SingletonService instance."""
    instance = SingletonService()
    return instance

@app.get("/singleton/increment/")
async def increment_counter(service: SingletonService = Depends(get_singleton_service)):
    """Bump the shared counter; the value persists across requests."""
    return {"count": service.increment()}

@app.post("/singleton/data/{key}")
async def set_singleton_data(
    key: str,
    value: str,
    service: SingletonService = Depends(get_singleton_service)
):
    """Store *value* under *key* in the shared singleton."""
    service.set_data(key, value)
    confirmation = f"设置 {key} = {value}"
    return {"message": confirmation}

@app.get("/singleton/data/{key}")
async def get_singleton_data(
    key: str,
    service: SingletonService = Depends(get_singleton_service)
):
    """Read *key* from the shared singleton (None if it was never set)."""
    return {"key": key, "value": service.get_data(key)}

作用域依赖

python
from contextvars import ContextVar

# Request-scoped storage: holds the current request's id, which the
# add_request_id middleware sets. Defaults to '' when nothing has been set.
request_id_var: ContextVar[str] = ContextVar('request_id', default='')

class RequestScopedService:
    """Per-request scratchpad that records the request id and its creation time."""

    def __init__(self):
        # The request id is captured once, at construction, from the context variable.
        self.request_id = request_id_var.get()
        self.created_at = datetime.now()
        self.request_data = {}

    def add_data(self, key: str, value):
        """Record *value* under *key* for this request."""
        self.request_data.update({key: value})

    def get_summary(self):
        """Return the request id, creation time (ISO 8601) and the collected data."""
        data = self.request_data
        return {
            "request_id": self.request_id,
            "created_at": self.created_at.isoformat(),
            "data_count": len(data),
            "data": data
        }

# Request-level dependency cache: maps request id -> RequestScopedService.
# The add_request_id middleware deletes a request's entry after call_next returns.
request_services = {}

def get_request_scoped_service(request: Request):
    """Return this request's RequestScopedService, creating it on first use.

    The key is ``request.state.request_id`` (set by middleware).
    NOTE(review): without that middleware, every request falls back to the
    shared key 'unknown' and therefore shares a single service instance.
    """
    request_id = getattr(request.state, 'request_id', 'unknown')
    try:
        return request_services[request_id]
    except KeyError:
        pass
    request_id_var.set(request_id)
    service = RequestScopedService()
    request_services[request_id] = service
    return service

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Tag each request with a fresh UUID and drop its scoped service afterwards.

    The id is stored on ``request.state`` and in ``request_id_var``. Cleanup
    runs in ``finally`` so a ``request_services`` entry is removed even when
    the downstream handler raises; previously such entries leaked.
    """
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    request_id_var.set(request_id)

    try:
        return await call_next(request)
    finally:
        # pop() with a default: the entry only exists if a route used the service.
        request_services.pop(request_id, None)

@app.post("/request-scoped/data/")
async def add_request_data(
    key: str,
    value: str,
    service: RequestScopedService = Depends(get_request_scoped_service)
):
    """Attach a key/value pair to this request's scoped service."""
    service.add_data(key, value)
    return dict(message=f"添加数据: {key} = {value}")

@app.get("/request-scoped/summary/")
async def get_request_summary(
    service: RequestScopedService = Depends(get_request_scoped_service)
):
    """Return the summary of this request's scoped service."""
    summary = service.get_summary()
    return summary

🧪 依赖测试

依赖覆盖

python
from fastapi.testclient import TestClient

# Test doubles
class MockDatabaseConnection:
    """Stand-in for DatabaseConnection that never touches a database.

    It implements every async method UserService calls (connect,
    execute_query, disconnect), so it can be swapped in via
    ``app.dependency_overrides`` or passed to UserService directly.
    """

    def __init__(self):
        self.connected = True

    async def connect(self):
        """Mark the mock as connected; returns self like DatabaseConnection.connect."""
        self.connected = True
        return self

    async def disconnect(self):
        """Mark the mock as disconnected."""
        self.connected = False

    async def execute_query(self, query: str):
        """Echo *query* back in a recognisable mock result string."""
        return f"Mock result for: {query}"


def get_mock_db():
    """Provider used to override the DatabaseConnection dependency in tests."""
    return MockDatabaseConnection()


# Tests
def test_with_dependency_override():
    """GET /users/1 should run against the mock connection, not the real one."""
    app.dependency_overrides[DatabaseConnection] = get_mock_db
    try:
        with TestClient(app) as client:
            response = client.get("/users/1")
            assert response.status_code == 200
            assert "Mock result" in response.json()["query_result"]
    finally:
        # Always remove the override, even when an assertion fails, so it does
        # not leak into other tests.
        app.dependency_overrides.clear()


def test_user_service():
    """Exercise UserService directly with an injected mock connection."""
    import asyncio

    mock_db = MockDatabaseConnection()
    user_service = UserService(db=mock_db)

    # `await` is a SyntaxError inside a plain def; drive the coroutine explicitly.
    result = asyncio.run(user_service.get_user_by_id(1))
    assert result["user_id"] == 1
    assert "Mock result" in result["query_result"]

总结

本章详细介绍了 FastAPI 的依赖注入系统：

  • 基础依赖：函数依赖、参数依赖、异步依赖
  • 类依赖：服务类、配置类、多层依赖
  • 安全依赖：认证、授权、权限检查
  • 作用域管理：单例、请求作用域、依赖缓存
  • 测试支持：依赖覆盖、模拟依赖

依赖注入是 FastAPI 的强大特性，它使代码更加模块化、可测试和可维护。

依赖注入最佳实践

  • 保持依赖的职责单一
  • 合理设计依赖的生命周期
  • 使用类型提示增强代码可读性
  • 善用依赖覆盖进行测试
  • 避免循环依赖
  • 考虑依赖的性能影响

在下一章中，我们将学习 FastAPI 的异常处理机制，了解如何优雅地处理各种错误情况。

本站内容仅供学习和研究使用。