FastAPI 依赖注入
概述
依赖注入（Dependency Injection）是FastAPI的核心特性之一，它提供了一种优雅的方式来管理应用程序的依赖关系。通过依赖注入，我们可以实现代码解耦、提高可测试性、增强可维护性。本章将深入探讨FastAPI的依赖注入系统。
🔧 依赖注入基础
简单依赖
python
import asyncio
import uuid
from datetime import datetime
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request
# Application instance; every route and middleware below registers on it.
app = FastAPI()
def get_current_time():
    """Return the current local time as a naive ``datetime`` (a minimal dependency)."""
    from datetime import datetime as _dt

    now = _dt.now()
    return now
def get_user_agent(user_agent: Optional[str] = Header(None)):
    """Return the request's ``User-Agent`` header, or ``"Unknown"`` when it is missing or empty."""
    if not user_agent:
        return "Unknown"
    return user_agent
@app.get("/info/")
async def get_info(
    current_time: datetime = Depends(get_current_time),
    user_agent: str = Depends(get_user_agent),
):
    """Echo the injected time and User-Agent, showing two function dependencies side by side."""
    payload = {}
    payload["current_time"] = current_time.isoformat()
    payload["user_agent"] = user_agent
    payload["message"] = "依赖注入示例"
    return payload
# Dependencies may also be coroutine functions; FastAPI awaits them before the endpoint runs.
async def get_database_info():
    """Simulate an asynchronous database lookup and return a fixed status dict."""
    simulated_latency = 0.1  # stands in for a real database round-trip
    await asyncio.sleep(simulated_latency)
    return dict(database="PostgreSQL", version="13.7", status="connected")
@app.get("/db-info/")
async def get_db_info(db_info: dict = Depends(get_database_info)):
    """Return the status dict produced by the async ``get_database_info`` dependency."""
    return db_info
带参数的依赖
python
def get_query_extractor(query_param: str = "q", default_value: str = ""):
    """Build a dependency that reads one raw query-string parameter.

    Args:
        query_param: name of the query parameter to read.
        default_value: value returned when the parameter is absent.

    Returns:
        A callable that takes the ``Request`` and returns the parameter's string value.

    NOTE: the value is read straight from ``request.query_params``, so FastAPI does
    not validate it or list it in the OpenAPI schema.
    """
    def extract(request: Request):
        params = request.query_params
        return params.get(query_param, default_value)

    return extract
def get_pagination_params(page: int = Query(1, ge=1), size: int = Query(10, ge=1, le=100)):
    """Shared pagination dependency.

    ``page`` is 1-based (``ge=1``) and ``size`` is limited to 1..100. ``offset`` is
    the number of rows to skip to reach the requested page.
    """
    skipped = (page - 1) * size
    return {"page": page, "size": size, "offset": skipped}
def get_sorting_params(
    sort_by: str = Query("created_at", description="排序字段"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="排序方向"),
):
    """Shared sorting dependency; ``sort_order`` is restricted to ``asc`` or ``desc``.

    NOTE(review): ``sort_by`` is not restricted; whitelist it before it reaches a query.
    """
    return dict(sort_by=sort_by, sort_order=sort_order)
@app.get("/posts/")
async def list_posts(
pagination: dict = Depends(get_pagination_params),
sorting: dict = Depends(get_sorting_params),
search_query: str = Depends(get_query_extractor("search", ""))
):
return {
"pagination": pagination,
"sorting": sorting,
"search_query": search_query,
"posts": [] # 模拟帖子列表
}🏗️ 类作为依赖
简单类依赖
python
class DatabaseConnection:
    """Fake connection object used to demonstrate class-based dependencies.

    ``connect``/``disconnect`` only toggle ``connected``. ``execute_query`` refuses
    to run unless connected, and echoes the query back.
    """

    def __init__(self):
        self.host, self.port, self.database = "localhost", 5432, "myapp"
        self.connected = False

    async def connect(self):
        """Pretend to open the connection; returns ``self`` so calls can be chained."""
        await asyncio.sleep(0.1)
        self.connected = True
        return self

    async def disconnect(self):
        """Pretend to close the connection."""
        self.connected = False

    async def execute_query(self, query: str):
        """Return a fake result string for *query*.

        Raises:
            Exception: if called while not connected.
        """
        if self.connected:
            await asyncio.sleep(0.05)
            return f"执行查询: {query}"
        raise Exception("数据库未连接")
class UserService:
    """User operations built on an injected ``DatabaseConnection``.

    Each call opens the connection, runs one query and always closes it again,
    even when the query raises.
    """

    def __init__(self, db: DatabaseConnection = Depends()):
        self.db = db

    async def get_user_by_id(self, user_id: int):
        """Look up *user_id* and return a demo record that includes the raw query result.

        Raises:
            ValueError: if *user_id* cannot be converted to ``int``.
        """
        # Coerce to int before interpolating so the SQL text can only ever contain
        # a number. Real code should use parameterized queries instead.
        safe_id = int(user_id)
        await self.db.connect()
        try:
            result = await self.db.execute_query(f"SELECT * FROM users WHERE id = {safe_id}")
        finally:
            # Release the connection even if the query fails.
            await self.db.disconnect()
        return {"user_id": user_id, "name": f"User {user_id}", "query_result": result}

    async def create_user(self, user_data: dict):
        """Run the placeholder INSERT and echo *user_data* back."""
        await self.db.connect()
        try:
            await self.db.execute_query("INSERT INTO users VALUES (...)")
        finally:
            await self.db.disconnect()
        return {"message": "用户创建成功", "user_data": user_data}
@app.get("/users/{user_id}")
async def get_user(user_id: int, user_service: UserService = Depends()):
    """Fetch one user through a ``UserService`` that FastAPI builds from its ``__init__`` signature."""
    user = await user_service.get_user_by_id(user_id)
    return user
@app.post("/users/")
async def create_user(user_data: dict, user_service: UserService = Depends()):
    """Create a user from the JSON request body."""
    return await user_service.create_user(user_data)
具有配置的类依赖
python
from pydantic import BaseSettings
import httpx
class APISettings(BaseSettings):
    """External-API configuration; values can be overridden by the environment or ``.env``.

    NOTE(review): ``BaseSettings`` is imported from ``pydantic`` above, which only
    works on Pydantic v1. On v2 it lives in the separate ``pydantic-settings`` package.
    """
    external_api_url: str = "https://api.example.com"
    api_key: str = "default-key"  # placeholder; supply the real key via env/.env
    timeout: int = 30  # passed to httpx.AsyncClient as its timeout
    max_retries: int = 3  # number of attempts ExternalAPIClient.get_data makes
    class Config:
        env_file = ".env"  # additional settings source besides the process environment
class ExternalAPIClient:
    """Thin async wrapper around ``httpx.AsyncClient`` with retry and backoff.

    Use it as an async context manager. ``__aenter__`` opens the HTTP client,
    ``__aexit__`` closes it, and ``get_data`` may only be called in between.
    """

    def __init__(self, settings: APISettings = Depends()):
        self.settings = settings
        self.client = None

    async def __aenter__(self):
        """Open the underlying HTTP client configured from ``settings``."""
        self.client = httpx.AsyncClient(
            base_url=self.settings.external_api_url,
            timeout=self.settings.timeout,
            headers={"Authorization": f"Bearer {self.settings.api_key}"},
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client; exceptions raised in the ``async with`` body propagate."""
        if self.client:
            await self.client.aclose()
            # Reset so a later get_data() fails with the clear "not initialised"
            # error rather than an error from an already-closed httpx client.
            self.client = None

    async def get_data(self, endpoint: str):
        """GET *endpoint* and return the decoded JSON body.

        Transport failures (``httpx.RequestError``) are retried with exponential
        backoff (1s, 2s, 4s, ...). Error status codes raise
        ``httpx.HTTPStatusError`` immediately, because it is not a ``RequestError``.

        Raises:
            Exception: if called outside the ``async with`` block.
            httpx.RequestError: when every attempt fails at the transport level.
            httpx.HTTPStatusError: on a 4xx/5xx response.
        """
        if not self.client:
            raise Exception("客户端未初始化")
        # Always make at least one request. With max_retries <= 0 the loop would
        # otherwise never run and the method would silently return None.
        attempts = max(1, self.settings.max_retries)
        for attempt in range(attempts):
            try:
                response = await self.client.get(endpoint)
                response.raise_for_status()
                return response.json()
            except httpx.RequestError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # exponential backoff
async def get_api_client():
    """Yield an open ``ExternalAPIClient`` for one request.

    It is written as an async generator so that FastAPI runs the ``async with``
    exit, which closes the HTTP client, once the endpoint has finished.
    """
    async with ExternalAPIClient(APISettings()) as client:
        yield client
@app.get("/external-data/{endpoint}")
async def get_external_data(
endpoint: str,
api_client: ExternalAPIClient = Depends(get_api_client)
):
try:
data = await api_client.get_data(endpoint)
return {"data": data, "source": "external_api"}
except Exception as e:
raise HTTPException(status_code=503, detail=f"外部API调用失败: {str(e)}")🔄 依赖的子依赖
多层依赖
python
class ConfigService:
    """In-memory application settings exposed through a dict-like ``get``."""

    def __init__(self):
        self.config = dict(
            app_name="MyApp",
            version="1.0.0",
            debug=True,
            database_url="postgresql://localhost/myapp",
        )

    def get(self, key: str, default=None):
        """Return setting *key*, or *default* when it is not defined."""
        try:
            return self.config[key]
        except KeyError:
            return default
class LoggingService:
    """Minimal print-based logger configured from ``ConfigService``.

    Lines are prefixed with the configured ``app_name``. Info messages are tagged
    DEBUG instead of INFO when the ``debug`` setting is truthy.
    """

    def __init__(self, config: ConfigService = Depends()):
        self.config = config
        self.app_name, self.debug = config.get("app_name"), config.get("debug")

    def log_info(self, message: str):
        """Print an informational line."""
        level = "INFO"
        if self.debug:
            level = "DEBUG"
        print(f"[{level}] {self.app_name}: {message}")

    def log_error(self, message: str):
        """Print an error line."""
        print(f"[ERROR] {self.app_name}: {message}")
class CacheService:
    """Tiny in-memory key/value cache that logs hits, misses and writes.

    NOTE: ``ttl`` is accepted for API compatibility but is not enforced; entries
    live as long as this instance does.
    """

    _MISSING = object()  # sentinel that distinguishes "absent" from a stored falsy value

    def __init__(self, config: ConfigService = Depends(), logger: LoggingService = Depends()):
        self.config = config
        self.logger = logger
        self.cache = {}
        self.logger.log_info("缓存服务初始化")

    def get(self, key: str):
        """Return the cached value for *key*, or ``None`` on a miss.

        Stored falsy values (``0``, ``""``, ``[]`` ...) count as hits. Checking the
        value's truthiness would misreport them as misses.
        """
        value = self.cache.get(key, self._MISSING)
        if value is self._MISSING:
            self.logger.log_info(f"缓存未命中: {key}")
            return None
        self.logger.log_info(f"缓存命中: {key}")
        return value

    def set(self, key: str, value, ttl: int = 300):
        """Store *value* under *key* (``ttl`` is currently ignored)."""
        self.cache[key] = value
        self.logger.log_info(f"缓存设置: {key}")
class BusinessService:
    """Business logic that combines the config, logging and cache sub-dependencies.

    FastAPI resolves each dependency once per request and shares the instance, so
    ``ConfigService`` and ``LoggingService`` here are the same objects that
    ``CacheService`` receives.

    NOTE(review): ``CacheService`` is itself built per request, so cached data does
    not survive between requests. A process-wide cache needs a singleton provider.
    """

    def __init__(
        self,
        config: ConfigService = Depends(),
        logger: LoggingService = Depends(),
        cache: CacheService = Depends(),
    ):
        self.config, self.logger, self.cache = config, logger, cache

    async def get_business_data(self, data_id: str):
        """Return the record for *data_id*, serving it from the cache when present."""
        cache_key = f"business_data:{data_id}"
        hit = self.cache.get(cache_key)
        if hit:
            return hit
        self.logger.log_info(f"处理业务数据: {data_id}")
        await asyncio.sleep(0.1)  # simulated processing time
        record = {
            "id": data_id,
            "name": f"Business Data {data_id}",
            "processed_at": datetime.now().isoformat(),
        }
        self.cache.set(cache_key, record)
        return record
@app.get("/business/{data_id}")
async def get_business_data(data_id: str, service: BusinessService = Depends()):
return await service.get_business_data(data_id)🔒 安全依赖
认证依赖
python
import jwt
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Bearer-token extractor. Used as a dependency, it provides the Authorization
# header's credentials as an HTTPAuthorizationCredentials object.
security = HTTPBearer()
class User:
    """Authenticated principal decoded from a JWT payload."""

    def __init__(self, user_id: int, username: str, email: str, roles: list):
        self.user_id, self.username, self.email, self.roles = user_id, username, email, roles

    def has_role(self, role: str) -> bool:
        """Return whether *role* is among this user's roles."""
        return role in self.roles

    def dict(self):
        """Return the user's fields as a plain dict."""
        return {field: getattr(self, field) for field in ("user_id", "username", "email", "roles")}
import os

# HS256 signing key. It is read from the environment so real deployments never ship
# the placeholder; the fallback exists only for local demos.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key")
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Decode the bearer JWT and return the corresponding ``User``.

    Raises:
        HTTPException: 401 if the token is expired, invalid, or lacks ``user_id``.
    """
    # RFC 6750: a 401 for bearer auth should tell the client which scheme to use.
    auth_headers = {"WWW-Authenticate": "Bearer"}
    token = credentials.credentials
    try:
        # Pin the accepted algorithm so the token cannot choose its own.
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="令牌已过期", headers=auth_headers) from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="无效的令牌", headers=auth_headers) from exc

    user_id = payload.get("user_id")
    if user_id is None:
        raise HTTPException(status_code=401, detail="无效的令牌", headers=auth_headers)
    return User(user_id, payload.get("username"), payload.get("email"), payload.get("roles", []))
async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Pass the authenticated user through only if they hold the ``admin`` role.

    Raises:
        HTTPException: 403 when the role is missing.
    """
    if current_user.has_role("admin"):
        return current_user
    raise HTTPException(status_code=403, detail="需要管理员权限")
def require_roles(*required_roles):
    """Build a dependency that admits users holding at least one of *required_roles*.

    The returned coroutine function depends on ``get_current_user`` and raises 403
    otherwise. When no roles are given, every request is rejected.
    """
    async def role_checker(current_user: User = Depends(get_current_user)):
        for role in required_roles:
            if current_user.has_role(role):
                return current_user
        raise HTTPException(
            status_code=403,
            detail=f"需要以下角色之一: {', '.join(required_roles)}",
        )

    return role_checker
# Using the authentication dependencies
@app.get("/profile/")
async def get_profile(current_user: User = Depends(get_current_user)):
    """Return the caller's own profile."""
    profile = current_user.dict()
    return {"profile": profile}
@app.get("/admin/users/")
async def list_all_users(admin_user: User = Depends(get_admin_user)):
    """Admin-only user listing (placeholder data)."""
    return dict(users=[], admin=admin_user.username)
@app.post("/admin/posts/")
async def create_post(
    post_data: dict,
    editor: User = Depends(require_roles("admin", "editor")),
):
    """Create a post; allowed for users with the ``admin`` or ``editor`` role."""
    return {"message": "帖子创建成功", "created_by": editor.username}
权限依赖
python
from enum import Enum
class Permission(str, Enum):
    """Fine-grained permissions checked by ``PermissionChecker``.

    Mixing in ``str`` makes each member compare equal to its string value.
    """
    READ_POSTS = "read_posts"
    WRITE_POSTS = "write_posts"
    DELETE_POSTS = "delete_posts"
    MANAGE_USERS = "manage_users"
    # Not granted by any role in PermissionChecker's mapping; users with the
    # "system_admin" role bypass permission checks instead.
    SYSTEM_ADMIN = "system_admin"
class PermissionChecker:
    """Callable dependency that requires a single ``Permission``.

    An instance is passed to ``Depends``, and FastAPI then invokes ``__call__`` with
    the authenticated user. Users with the ``system_admin`` role bypass the check.
    """

    def __init__(self, required_permission: Permission):
        self.required_permission = required_permission

    async def __call__(self, current_user: User = Depends(get_current_user)):
        """Return *current_user* if permitted; otherwise raise HTTP 403."""
        is_system_admin = current_user.has_role("system_admin")
        if not is_system_admin and self.required_permission not in self._get_user_permissions(current_user):
            raise HTTPException(
                status_code=403,
                detail=f"需要权限: {self.required_permission.value}",
            )
        return current_user

    def _get_user_permissions(self, user: User) -> list[Permission]:
        """Return the union of permissions granted by *user*'s roles (unknown roles grant none)."""
        role_permissions = {
            "admin": [
                Permission.READ_POSTS,
                Permission.WRITE_POSTS,
                Permission.DELETE_POSTS,
                Permission.MANAGE_USERS,
            ],
            "editor": [Permission.READ_POSTS, Permission.WRITE_POSTS],
            "user": [Permission.READ_POSTS],
        }
        granted = {perm for role in user.roles for perm in role_permissions.get(role, [])}
        return list(granted)
# Using the permission dependencies.
# NOTE(review): the earlier pagination example also registers GET /posts/ under the
# name list_posts. If both snippets share one app, the route registered first is
# the one that matches.
@app.get("/posts/")
async def list_posts(user: User = Depends(PermissionChecker(Permission.READ_POSTS))):
    """List posts for any user holding ``READ_POSTS``."""
    username = user.username
    return {"posts": [], "user": username}
@app.post("/posts/")
async def create_post(
    post_data: dict,
    user: User = Depends(PermissionChecker(Permission.WRITE_POSTS)),
):
    """Create a post; requires ``WRITE_POSTS``."""
    author = user.username
    return {"message": "帖子创建成功", "author": author}
@app.delete("/posts/{post_id}")
async def delete_post(
post_id: int,
user: User = Depends(PermissionChecker(Permission.DELETE_POSTS))
):
return {"message": f"帖子 {post_id} 删除成功", "deleted_by": user.username}🎛️ 依赖提供器和作用域
单例依赖
python
class SingletonService:
    """Process-wide shared state: every ``SingletonService()`` call returns one instance.

    ``__new__`` hands back the cached instance, and ``__init__`` only populates
    ``data``/``counter`` the first time, so later constructions keep the state.

    NOTE(review): ``get_singleton_service`` is a sync dependency, which FastAPI runs
    in a threadpool. The very first construction is not guarded by a lock.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self):
        if self._initialized:
            return  # already set up by an earlier construction
        self.data = {}
        self.counter = 0
        self._initialized = True

    def increment(self):
        """Add one to the shared counter and return the new value."""
        self.counter = self.counter + 1
        return self.counter

    def set_data(self, key: str, value):
        """Store *value* under *key*."""
        self.data.update({key: value})

    def get_data(self, key: str):
        """Return the value stored under *key*, or ``None``."""
        return self.data.get(key)
def get_singleton_service():
    """Provide the shared ``SingletonService`` instance to endpoints."""
    service = SingletonService()
    return service
@app.get("/singleton/increment/")
async def increment_counter(service: SingletonService = Depends(get_singleton_service)):
    """Bump the shared counter and return its new value."""
    return {"count": service.increment()}
@app.post("/singleton/data/{key}")
async def set_singleton_data(
    key: str,
    value: str,
    service: SingletonService = Depends(get_singleton_service),
):
    """Store *value* (a query parameter) under the path *key* in the shared service."""
    service.set_data(key, value)
    message = f"设置 {key} = {value}"
    return {"message": message}
@app.get("/singleton/data/{key}")
async def get_singleton_data(
    key: str,
    service: SingletonService = Depends(get_singleton_service),
):
    """Read *key* from the shared service; ``value`` is ``None`` if it was never set."""
    value = service.get_data(key)
    return {"key": key, "value": value}
作用域依赖
python
from contextvars import ContextVar
# Request-scoped storage: the id of the request currently being handled
# ('' when nothing has set it in the current context).
request_id_var: ContextVar[str] = ContextVar('request_id', default='')
class RequestScopedService:
    """Per-request scratch space stamped with the current request id and creation time."""

    def __init__(self):
        # Capture the id from the context variable at construction time.
        self.request_id = request_id_var.get()
        self.created_at = datetime.now()
        self.request_data = {}

    def add_data(self, key: str, value):
        """Record *value* under *key* for this request."""
        self.request_data.update({key: value})

    def get_summary(self):
        """Return the request id, creation time (ISO 8601), entry count and stored data."""
        data = self.request_data
        return {
            "request_id": self.request_id,
            "created_at": self.created_at.isoformat(),
            "data_count": len(data),
            "data": data,
        }
# Request-level dependency cache (request id -> service). The add_request_id
# middleware removes a request's entry when the request finishes.
request_services = {}
def get_request_scoped_service(request: Request):
    """Return the ``RequestScopedService`` for the current request, creating it on first use.

    Requests without ``request.state.request_id`` (e.g. when the id middleware is not
    installed) get a fresh, uncached instance. Caching them under one shared
    'unknown' key would leak data between unrelated requests, and that entry would
    never be cleaned up.
    """
    request_id = getattr(request.state, "request_id", None)
    if request_id is None:
        request_id_var.set("unknown")
        return RequestScopedService()
    service = request_services.get(request_id)
    if service is None:
        request_id_var.set(request_id)
        service = request_services[request_id] = RequestScopedService()
    return service
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Tag each request with a UUID4 id and drop its request-scoped service afterwards."""
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    request_id_var.set(request_id)
    try:
        return await call_next(request)
    finally:
        # Clean up even when the endpoint raised; otherwise the entry would leak.
        request_services.pop(request_id, None)
@app.post("/request-scoped/data/")
async def add_request_data(
    key: str,
    value: str,
    service: RequestScopedService = Depends(get_request_scoped_service),
):
    """Add ``key = value`` to this request's scoped service."""
    service.add_data(key, value)
    return dict(message=f"添加数据: {key} = {value}")
@app.get("/request-scoped/summary/")
async def get_request_summary(
service: RequestScopedService = Depends(get_request_scoped_service)
):
return service.get_summary()🧪 依赖测试
依赖覆盖
python
from fastapi.testclient import TestClient
# Test doubles
class MockDatabaseConnection:
    """Stand-in for ``DatabaseConnection`` used through dependency overrides in tests.

    It implements every method ``UserService`` calls (``connect``,
    ``execute_query``, ``disconnect``). Without connect/disconnect the service
    fails with AttributeError before reaching the query.
    """

    def __init__(self):
        self.connected = True

    async def connect(self):
        """Mark as connected and return ``self``, like the real class; no I/O."""
        self.connected = True
        return self

    async def disconnect(self):
        """Mark as disconnected."""
        self.connected = False

    async def execute_query(self, query: str):
        """Return a canned result that echoes *query*."""
        return f"Mock result for: {query}"
def get_mock_db():
    """Factory used as the dependency override for ``DatabaseConnection``."""
    mock = MockDatabaseConnection()
    return mock
# Tests
def test_with_dependency_override():
    """With ``DatabaseConnection`` overridden, ``/users/1`` should hit the mock database."""
    app.dependency_overrides[DatabaseConnection] = get_mock_db
    try:
        with TestClient(app) as client:
            response = client.get("/users/1")
            assert response.status_code == 200
            assert "Mock result" in response.json()["query_result"]
    finally:
        # Undo the override even if an assertion fails, so later tests see the real
        # dependency. Only remove our key rather than wiping other tests' overrides.
        app.dependency_overrides.pop(DatabaseConnection, None)
# Unit-test a service directly, without FastAPI in the loop
def test_user_service():
    """Exercise ``UserService.get_user_by_id`` against the mock connection."""
    import asyncio

    mock_db = MockDatabaseConnection()
    user_service = UserService(db=mock_db)
    # A plain ``def`` cannot use ``await``; drive the coroutine with an event loop.
    result = asyncio.run(user_service.get_user_by_id(1))
    assert result["user_id"] == 1
    assert "Mock result" in result["query_result"]
总结
本章详细介绍了FastAPI的依赖注入系统：
- ✅ 基础依赖：函数依赖、参数依赖、异步依赖
- ✅ 类依赖：服务类、配置类、多层依赖
- ✅ 安全依赖：认证、授权、权限检查
- ✅ 作用域管理：单例、请求作用域、依赖缓存
- ✅ 测试支持：依赖覆盖、模拟依赖
依赖注入是FastAPI的强大特性，它使代码更加模块化、可测试和可维护。
依赖注入最佳实践
- 保持依赖的职责单一
- 合理设计依赖的生命周期
- 使用类型提示增强代码可读性
- 善用依赖覆盖进行测试
- 避免循环依赖
- 考虑依赖的性能影响
在下一章中，我们将学习FastAPI的异常处理机制，了解如何优雅地处理各种错误情况。