FastAPI 中间件
概述
中间件是Web应用中处理横切关注点的重要机制,如日志记录、性能监控、安全检查、CORS处理等。FastAPI基于Starlette,提供了强大的中间件系统。本章将详细介绍如何使用和创建中间件。
🔧 中间件基础概念
中间件执行流程
mermaid
graph TB
A[客户端请求] --> B[中间件1 - 请求处理]
B --> C[中间件2 - 请求处理]
C --> D[路由处理函数]
D --> E[中间件2 - 响应处理]
E --> F[中间件1 - 响应处理]
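F --> G[客户端响应]

中间件采用"洋葱模型":请求由外向内依次经过每一层中间件,到达路由处理函数后,响应再由内向外按相反顺序返回。注意:通过 `app.add_middleware()`(以及 `@app.middleware("http")` 装饰器)注册时,**最后添加**的中间件位于最外层,最先处理请求、最后处理响应——图中的"中间件1"就是最后添加的那个。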
基础中间件示例
python
from fastapi import FastAPI, Request, Response
import time
import logging
# Logging setup: configure the root logger at INFO and get this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The application that every example below attaches middleware and routes to.
app = FastAPI()
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each HTTP request and add an ``X-Process-Time`` response header.

    Logs the method and URL when the request arrives. Once the downstream
    handler has produced a response, it logs the status code and the elapsed
    time in seconds, which is also returned in ``X-Process-Time``. If the
    handler raises, the failure and its duration are logged and the
    exception is re-raised.
    """
    # perf_counter is monotonic. time.time() can jump when the system clock
    # is adjusted, which would give negative or inflated durations.
    start_time = time.perf_counter()
    logger.info("请求开始: %s %s", request.method, request.url)

    try:
        # Hand the request to the next middleware / the route handler.
        response = await call_next(request)
    except Exception:
        logger.exception(
            "请求失败: %s %s - %.4fs",
            request.method, request.url, time.perf_counter() - start_time,
        )
        raise

    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.info(
        "请求完成: %s %s - %s - %.4fs",
        request.method, request.url, response.status_code, process_time,
    )
    return response
@app.get("/")
async def read_root():
    """Return a static greeting."""
    greeting = {"message": "Hello World"}
    return greeting
import asyncio


@app.get("/slow")
async def slow_endpoint():
    """Wait two seconds, to make the X-Process-Time header visibly large."""
    # asyncio.sleep does not block the event loop, so other requests keep
    # being served while this one waits.
    await asyncio.sleep(2)
    return {"message": "This was slow"}


## 🛡️ 内置中间件
CORS中间件
python
from fastapi.middleware.cors import CORSMiddleware
import os

# Reuse the application created in the basic example. Calling FastAPI() again
# here would replace it and discard every route and middleware registered so far.

# Configure CORS exactly once, choosing the policy by environment. Registering
# CORSMiddleware twice stacks two independent CORS layers, and the outer one
# answers preflight requests on its own, so the two policies would conflict.
if os.getenv("ENVIRONMENT") == "development":
    # Development: accept any origin. Credentials stay off because, with
    # allow_credentials=True, a "*" origin is answered by echoing the caller's
    # Origin, which lets any website send cookie-authenticated requests.
    # To use cookies in development, list the dev origins explicitly instead.
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    cors_options = dict(
        allow_origins=["http://localhost:3000", "https://myapp.com"],  # allowed origins
        allow_credentials=True,  # allow cookies / Authorization headers
        allow_methods=["GET", "POST", "PUT", "DELETE"],  # allowed methods
        allow_headers=["*"],  # allowed request headers
        expose_headers=["X-Custom-Header"],  # response headers readable by browser JS
        max_age=3600,  # seconds a browser may cache a preflight response
    )

app.add_middleware(CORSMiddleware, **cors_options)

### HTTPS重定向中间件
python
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
import os

# Production: redirect plain-HTTP requests to HTTPS.
# NOTE: the decision is based on the request scheme as the server sees it.
# Behind a TLS-terminating reverse proxy, make sure the server trusts the
# proxy's X-Forwarded-Proto header (e.g. uvicorn --proxy-headers). Otherwise
# every request looks like HTTP and is redirected in a loop.
if os.getenv("ENVIRONMENT") == "production":
    app.add_middleware(HTTPSRedirectMiddleware)

### 可信主机中间件
python
from fastapi.middleware.trustedhost import TrustedHostMiddleware
# Reject requests whose Host header is not in the allow-list, which guards
# against Host-header attacks. "*.example.com" matches subdomains only, which
# is why the bare "example.com" is listed separately.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com", "localhost", "127.0.0.1"],
)

### GZip压缩中间件
python
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware, minimum_size=1000)🔨 自定义中间件
请求ID中间件
python
import uuid
from contextvars import ContextVar
# Holds the ID of the request being handled ("" outside a request). A
# ContextVar is per execution context (e.g. per asyncio task), so requests
# handled in separate tasks do not see each other's ID.
request_id_contextvar: ContextVar[str] = ContextVar('request_id', default="")


class RequestIDMiddleware:
    """Pure-ASGI middleware that tags every HTTP request with a fresh UUID4.

    While the request is being handled, the ID is available through
    ``request_id_contextvar``. It is also returned to the client in an
    ``X-Request-ID`` response header. Non-HTTP scopes (websocket, lifespan)
    pass through untouched.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request_id = str(uuid.uuid4())
        # Keep the token so the previous value is restored afterwards. The ID
        # must not outlive the request in the context that runs this code.
        token = request_id_contextvar.set(request_id)

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append([b"x-request-id", request_id.encode()])
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            request_id_contextvar.reset(token)
def get_request_id() -> str:
    """Return the ID of the request being handled, or "" outside a request."""
    current_id = request_id_contextvar.get()
    return current_id


# Register the middleware so every HTTP request is assigned an ID.
app.add_middleware(RequestIDMiddleware)
@app.get("/test")
async def test_endpoint():
    """Return the current request's ID, as assigned by RequestIDMiddleware."""
    return {"request_id": get_request_id(), "message": "测试端点"}


### 性能监控中间件
python
import psutil
from typing import Dict, Any
class PerformanceMiddleware:
    """Pure-ASGI middleware that collects simple in-process request statistics.

    For every HTTP request it updates ``self.stats``:

    - the total request count;
    - the accumulated and average response time in seconds;
    - per-method counts;
    - per-status-code counts.

    It also adds the host's current CPU and memory usage, plus the number of
    requests recorded so far, as ``X-CPU-Usage`` / ``X-Memory-Usage`` /
    ``X-Request-Count`` response headers. Non-HTTP scopes pass through.

    NOTE(review): these headers reveal host load to every client. Consider
    enabling them only for internal/debug deployments.
    """

    def __init__(self, app):
        self.app = app
        self.stats = {
            "total_requests": 0,
            "total_time": 0,
            "avg_response_time": 0,
            "request_count_by_method": {},
            "request_count_by_status": {},
        }

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # perf_counter is monotonic; time.time() can jump when the clock is
        # adjusted and produce negative or inflated durations.
        start_time = time.perf_counter()
        request_method = scope["method"]
        status_code = None

        async def send_wrapper(message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
                headers = list(message.get("headers", []))
                # Without an interval, psutil.cpu_percent() returns immediately
                # with usage since the previous call (psutil documents that the
                # first such call returns a meaningless 0.0).
                cpu_percent = psutil.cpu_percent()
                memory_percent = psutil.virtual_memory().percent
                headers.extend([
                    [b"x-cpu-usage", f"{cpu_percent:.1f}%".encode()],
                    [b"x-memory-usage", f"{memory_percent:.1f}%".encode()],
                    # Requests recorded so far; this one is not counted yet.
                    [b"x-request-count", str(self.stats["total_requests"]).encode()],
                ])
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Record the request even if the downstream app raised, so
            # failing requests do not silently vanish from the statistics.
            self._record(request_method, status_code, time.perf_counter() - start_time)

    def _record(self, method, status_code, response_time):
        """Fold one finished request into ``self.stats``."""
        stats = self.stats
        stats["total_requests"] += 1
        stats["total_time"] += response_time
        stats["avg_response_time"] = stats["total_time"] / stats["total_requests"]

        by_method = stats["request_count_by_method"]
        by_method[method] = by_method.get(method, 0) + 1

        # status_code stays None when the app raised before starting a response.
        if status_code:
            by_status = stats["request_count_by_status"]
            by_status[status_code] = by_status.get(status_code, 0) + 1
# Starlette builds the middleware stack itself, lazily, by calling each
# registered factory with the next ASGI app. A zero-argument lambda cannot be
# called that way, and an instance built around the FastAPI app would wrap the
# whole application. Instead, register a factory and remember the instance it
# builds, so the /stats endpoint can read that instance's counters.
_performance_middlewares = []


def _create_performance_middleware(app):
    """Wrap the next ASGI app in a PerformanceMiddleware and record the instance."""
    middleware = PerformanceMiddleware(app)
    _performance_middlewares.append(middleware)
    return middleware


app.add_middleware(_create_performance_middleware)


@app.get("/stats")
async def get_performance_stats():
    """Return the statistics collected by the PerformanceMiddleware in use."""
    if not _performance_middlewares:
        return {}  # middleware stack has not been built yet
    return _performance_middlewares[-1].stats


### 限流中间件
python
import asyncio
import math
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
class RateLimitMiddleware:
    """Pure-ASGI sliding-window rate limiter keyed by client IP.

    Each client may make at most ``calls`` HTTP requests in any rolling
    window of ``period`` seconds. Requests over the limit are answered
    directly with ``429 Too Many Requests`` plus a ``Retry-After`` header.
    Allowed requests are forwarded, and ``X-RateLimit-Limit`` /
    ``-Remaining`` / ``-Reset`` headers (reset is a Unix timestamp) are added
    to their response. Non-HTTP scopes (websocket, lifespan) pass through.

    Counters live in this process's memory, so with several worker
    processes each worker enforces its own limit.

    Args:
        app: The next ASGI application in the chain.
        calls: Maximum number of requests per client per window.
        period: Window length in seconds.
        trust_forwarded_headers: Derive the client IP from
            ``X-Forwarded-For`` / ``X-Real-IP`` / ``CF-Connecting-IP`` when
            present. The client supplies these headers, so enable this only
            behind a trusted reverse proxy that overwrites them. Otherwise any
            client can dodge the limit by sending a fake IP with each request.
            Alternatively, servers such as uvicorn (``--proxy-headers``) can
            apply trusted proxy headers to ``scope["client"]`` themselves.
    """

    _FORWARDED_HEADERS = (b"x-forwarded-for", b"x-real-ip", b"cf-connecting-ip")

    def __init__(self, app, calls: int = 100, period: int = 60,
                 trust_forwarded_headers: bool = False):
        self.app = app
        self.calls = calls  # allowed requests per window
        self.period = period  # window length in seconds
        self.trust_forwarded_headers = trust_forwarded_headers
        # client IP -> monotonic timestamps of requests still inside the
        # window, oldest first. A deque lets expired entries pop off the left
        # in O(1).
        self.clients = defaultdict(deque)
        # Deadline for the next full prune, so clients that never return do
        # not keep their entries forever.
        self._next_sweep = time.monotonic() + period

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        client_ip = self._get_client_ip(scope)
        now = time.monotonic()  # monotonic: immune to wall-clock jumps
        self._sweep_idle_clients(now)
        self._cleanup_old_requests(client_ip, now)
        history = self.clients[client_ip]

        if len(history) >= self.calls:
            # A slot frees up when the oldest request in the window expires.
            oldest = history[0] if history else now
            await self._send_rate_limited(send, oldest + self.period - now)
            return

        history.append(now)
        limit_headers = self._limit_headers(
            remaining=self.calls - len(history),
            reset_in=history[0] + self.period - now,
        )

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                message["headers"] = list(message.get("headers", [])) + limit_headers
            await send(message)

        await self.app(scope, receive, send_wrapper)

    def _limit_headers(self, remaining, reset_in):
        """Build the X-RateLimit-* headers; ``reset_in`` is seconds from now."""
        reset_at = math.ceil(time.time() + reset_in)
        return [
            [b"x-ratelimit-limit", str(self.calls).encode()],
            [b"x-ratelimit-remaining", str(remaining).encode()],
            [b"x-ratelimit-reset", str(reset_at).encode()],
        ]

    async def _send_rate_limited(self, send, reset_in):
        """Answer with 429 Too Many Requests without calling the downstream app."""
        body = b'{"error": "Rate limit exceeded", "message": "Too many requests"}'
        retry_after = max(1, math.ceil(reset_in))
        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
                [b"retry-after", str(retry_after).encode()],
                *self._limit_headers(0, reset_in),
            ],
        })
        await send({"type": "http.response.body", "body": body})

    def _get_client_ip(self, scope):
        """Return the IP address used as this request's rate-limit key."""
        if self.trust_forwarded_headers:
            headers = dict(scope.get("headers", []))
            for header in self._FORWARDED_HEADERS:
                value = headers.get(header)
                if value:
                    # latin-1 decodes any byte sequence, so a malformed
                    # header cannot raise here.
                    ip = value.decode("latin-1").split(",")[0].strip()
                    if ip:
                        return ip
        # Address of the peer that connected directly to the server.
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _cleanup_old_requests(self, client_ip, current_time):
        """Drop a client's timestamps that left the window; forget idle clients."""
        history = self.clients.get(client_ip)
        if history is None:
            return
        cutoff = current_time - self.period
        while history and history[0] <= cutoff:
            history.popleft()
        if not history:
            del self.clients[client_ip]

    def _sweep_idle_clients(self, now):
        """At most once per period, prune every client's history."""
        if now < self._next_sweep:
            return
        self._next_sweep = now + self.period
        for client_ip in list(self.clients):
            self._cleanup_old_requests(client_ip, now)
# Apply rate limiting: at most 100 requests per client per minute.
# Pass the class and its options. Starlette instantiates it with the next
# ASGI app when it builds the middleware stack; a zero-argument lambda
# cannot be instantiated that way.
app.add_middleware(RateLimitMiddleware, calls=100, period=60)

### 认证中间件
python
import jwt
from fastapi import HTTPException, status
class AuthenticationMiddleware:
    """Pure-ASGI middleware that requires a valid HS256 JWT bearer token.

    HTTP requests to paths outside ``excluded_paths`` must carry an
    ``Authorization: Bearer <token>`` header whose token verifies against
    ``secret_key``. On success, the decoded payload is stored in
    ``scope["user"]``. Otherwise a 401 JSON response is sent and the
    downstream app is not called. Non-HTTP scopes pass through.

    Args:
        app: The next ASGI application.
        secret_key: Shared secret used to verify token signatures.
        excluded_paths: Paths that skip authentication. An entry matches the
            exact path and, except for ``"/"``, every path below it: ``"/docs"``
            covers ``"/docs/oauth2-redirect"`` but not ``"/docsx"``. Defaults
            to the root, the API docs and ``/login``.
    """

    def __init__(self, app, secret_key: str, excluded_paths: list = None):
        self.app = app
        self.secret_key = secret_key
        self.excluded_paths = excluded_paths or ["/", "/docs", "/redoc", "/openapi.json", "/login"]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        if self._is_excluded_path(scope["path"]):
            await self.app(scope, receive, send)
            return

        headers = dict(scope.get("headers", []))
        auth_header = headers.get(b"authorization")
        if not auth_header:
            await self._send_unauthorized(send, "Missing authorization header")
            return

        # Header values are raw bytes. latin-1 decodes any byte sequence, so a
        # malformed header produces a 401 instead of a UnicodeDecodeError (500).
        scheme, _, token = auth_header.decode("latin-1").partition(" ")
        # The auth scheme name is case-insensitive (RFC 7235).
        if scheme.lower() != "bearer" or not token.strip():
            await self._send_unauthorized(send, "Invalid authorization format")
            return

        try:
            payload = jwt.decode(token.strip(), self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            await self._send_unauthorized(send, "Token has expired")
            return
        except jwt.InvalidTokenError:
            await self._send_unauthorized(send, "Invalid token")
            return

        # Expose the verified claims to handlers via request.scope["user"].
        scope["user"] = payload
        await self.app(scope, receive, send)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if ``path`` is exempt from authentication.

        A plain ``startswith`` check would exempt everything once ``"/"`` is
        in the list, because every path starts with ``"/"``.
        """
        for excluded in self.excluded_paths:
            if path == excluded:
                return True
            prefix = excluded.rstrip("/")
            # Match whole path segments only, and never let "/" act as a prefix.
            if prefix and path.startswith(prefix + "/"):
                return True
        return False

    async def _send_unauthorized(self, send, message: str):
        """Send a 401 JSON response that carries ``message``."""
        import json

        # json.dumps escapes the message correctly, unlike string formatting.
        body = json.dumps({"error": "Unauthorized", "message": message}).encode()
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
                # RFC 6750: tell the client which authentication scheme is expected.
                [b"www-authenticate", b"Bearer"],
            ],
        })
        await send({"type": "http.response.body", "body": body})
import os

# Apply the authentication middleware. In real deployments, set the signing
# secret via the JWT_SECRET_KEY environment variable; the literal below is
# only a local fallback.
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key")
# Pass the class plus its options. Starlette instantiates it with the next ASGI
# app; a zero-argument lambda cannot be instantiated that way.
app.add_middleware(AuthenticationMiddleware, secret_key=SECRET_KEY)
# Login endpoint (in AuthenticationMiddleware's default excluded paths)
@app.post("/login")
async def login(username: str, password: str):
    """Issue a 24-hour HS256 JWT for the demo user ``admin`` / ``secret``.

    NOTE: as plain ``str`` parameters, ``username`` and ``password`` are read
    from the query string. Query strings end up in access logs and in
    URL-logging middleware such as ``log_requests``. Real code should accept
    credentials in the request body.

    Raises:
        HTTPException: 401 if the credentials do not match.
    """
    import secrets
    from datetime import timezone

    # compare_digest takes the same time wherever the inputs differ, so
    # response timing does not reveal how close a guess was. Both comparisons
    # always run; bytes avoid compare_digest's ASCII-only restriction on str.
    valid = secrets.compare_digest(username.encode(), b"admin") & \
        secrets.compare_digest(password.encode(), b"secret")
    if not valid:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    payload = {
        "user_id": 1,
        "username": username,
        # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time.
        "exp": datetime.now(timezone.utc) + timedelta(hours=24),
    }
    token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
    return {"access_token": token, "token_type": "bearer"}
# Protected endpoint
@app.get("/protected")
async def protected_endpoint(request: Request):
    """Return the JWT claims that AuthenticationMiddleware stored in the scope."""
    user = request.scope.get("user")
    return {"message": "This is protected", "user": user}


## 📊 错误处理中间件
全局异常处理中间件
python
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse
class ErrorHandlingMiddleware:
    """Pure-ASGI middleware that converts unhandled exceptions into JSON errors.

    Exceptions that escape the wrapped app are logged and answered with a
    body of the form ``{"error": {"type", "message", "path", "method",
    "timestamp"}}``. Status codes:

    - ``HTTPException`` keeps its own status and detail;
    - ``ValueError`` -> 400;
    - ``FileNotFoundError`` -> 404;
    - anything else -> 500.

    With ENVIRONMENT=development the traceback is included. If the response
    has already started, the exception is re-raised instead. Non-HTTP scopes
    pass through.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message):
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as exc:
            if response_started:
                # The status line and headers are already sent. ASGI does not
                # allow a second http.response.start, so leave it to the server.
                raise
            request = Request(scope, receive)
            await self._handle_error(exc, request, send)

    async def _handle_error(self, exc: Exception, request: Request, send):
        """Log ``exc`` and send the matching JSON error response for ``request``."""
        import os

        logger.error("Unhandled exception: %s", exc, exc_info=exc)

        # Map the exception type to a status code and a client-facing message.
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            detail = exc.detail
        elif isinstance(exc, ValueError):
            status_code = 400
            detail = "Invalid input data"
        elif isinstance(exc, FileNotFoundError):
            status_code = 404
            detail = "Resource not found"
        else:
            status_code = 500
            detail = "Internal server error"

        error_response = {
            "error": {
                "type": exc.__class__.__name__,
                "message": detail,
                "path": str(request.url.path),
                "method": request.method,
                "timestamp": datetime.now().isoformat(),
            }
        }

        # Tracebacks expose internals, so include them only in development.
        if os.getenv("ENVIRONMENT") == "development":
            error_response["error"]["traceback"] = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )

        response = JSONResponse(content=error_response, status_code=status_code)
        # The Request carries the ASGI scope and receive channel the response needs.
        await response(request.scope, request.receive, send)
# add_middleware places this layer outside everything registered before it,
# so exceptions from those middleware and from the routes reach it.
app.add_middleware(ErrorHandlingMiddleware)

## 🔧 中间件最佳实践
中间件工厂
python
def create_logging_middleware(log_level: str = "INFO"):
    """Return a middleware factory that logs each HTTP request and response.

    Args:
        log_level: Name of a ``logging`` level such as ``"DEBUG"`` or
            ``"info"`` (case-insensitive).

    Returns:
        A callable ``factory(app)`` suitable for ``app.add_middleware``. It
        wraps the next ASGI app in a middleware that logs the method and URL
        of every HTTP request, then the status and elapsed time of its
        response. Non-HTTP scopes pass through unlogged.

    Raises:
        ValueError: If ``log_level`` is not a known logging level name. The
            check runs here, at configuration time, rather than failing on
            every request.
    """
    # Resolve the level once. getattr(logging, "debug") would return the
    # logging.debug function, which logger.log rejects, so require an int.
    level = getattr(logging, str(log_level).upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")

    def logging_middleware(app):
        async def middleware(scope, receive, send):
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            # perf_counter is monotonic, so durations are immune to clock changes.
            start_time = time.perf_counter()
            request = Request(scope, receive)
            logger.log(level, "Request: %s %s", request.method, request.url)

            async def send_wrapper(message):
                if message["type"] == "http.response.start":
                    process_time = time.perf_counter() - start_time
                    logger.log(level, "Response: %s - %.4fs", message["status"], process_time)
                await send(message)

            await app(scope, receive, send_wrapper)

        return middleware

    return logging_middleware
# Build the middleware with the factory; add_middleware calls it with the next
# ASGI app. NOTE: with logging.basicConfig(level=logging.INFO) as configured
# earlier, these DEBUG records are filtered out unless the level is lowered.
app.add_middleware(create_logging_middleware("DEBUG"))

### 条件中间件
python
def conditional_middleware(condition_func):
    """Class decorator that enables an ASGI middleware only when a condition holds.

    The decorated name becomes a factory taking the next ASGI app.
    ``condition_func`` is evaluated each time that factory is called. When it
    returns a truthy value, the real middleware wraps the app. Otherwise a
    pass-through callable forwards every call to the app unchanged.
    """
    def decorator(middleware_class):
        def build(app):
            if not condition_func():
                # Condition not met: install a no-op layer.
                async def forward_unchanged(scope, receive, send):
                    await app(scope, receive, send)
                return forward_unchanged
            return middleware_class(app)
        return build
    return decorator
import os


@conditional_middleware(lambda: os.getenv("ENABLE_PROFILING") == "true")
class ProfilingMiddleware:
    """Profile HTTP requests with cProfile and dump each profile to a file.

    Active only when ENABLE_PROFILING=true; otherwise the decorator installs
    a pass-through. Each profiled request writes ``profile_<timestamp>.prof``
    to the system temp directory; inspect it with ``python -m pstats``.
    Profiles include any work that other requests interleave on the event
    loop while the profiler is enabled.
    """

    # Only one request is profiled at a time. Recent CPython versions (3.12+)
    # refuse to enable a second profiler while one is active, and concurrent
    # requests interleave on the event loop anyway. Other requests pass through.
    _active = False

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        # Only profile HTTP requests. A lifespan scope lasts for the whole
        # server lifetime and would keep the profiler running until shutdown.
        # type(self) is used because the decorator rebinds the class name.
        if scope["type"] != "http" or type(self)._active:
            await self.app(scope, receive, send)
            return

        import cProfile
        import tempfile

        profiler = cProfile.Profile()
        profiler.enable()
        type(self)._active = True
        try:
            await self.app(scope, receive, send)
        finally:
            # Always stop profiling, even if the request handler raised.
            profiler.disable()
            type(self)._active = False
            profiler.dump_stats(
                os.path.join(tempfile.gettempdir(), f"profile_{time.time()}.prof")
            )
# With ENABLE_PROFILING unset, the decorator's pass-through layer is installed.
app.add_middleware(ProfilingMiddleware)

### 中间件配置
python
try:
    # Pydantic v2 moved BaseSettings into the separate pydantic-settings package;
    # importing it from pydantic itself raises an ImportError there.
    from pydantic_settings import BaseSettings
except ImportError:
    # Pydantic v1 still ships BaseSettings in pydantic.
    from pydantic import BaseSettings


class MiddlewareSettings(BaseSettings):
    """Middleware switches and parameters, loaded from the environment and .env."""

    enable_cors: bool = True
    cors_origins: list = ["*"]
    enable_gzip: bool = True
    gzip_minimum_size: int = 1000  # bytes
    enable_rate_limiting: bool = False
    rate_limit_calls: int = 100  # requests per window
    rate_limit_period: int = 60  # window length in seconds

    class Config:
        env_file = ".env"


settings = MiddlewareSettings()
def setup_middleware(app: FastAPI, middleware_settings: "MiddlewareSettings | None" = None):
    """Register all middleware on ``app`` according to the settings.

    Args:
        app: The application to configure.
        middleware_settings: Settings to apply. Defaults to the module-level
            ``settings`` loaded from the environment.

    Note:
        ``add_middleware`` places each new middleware *outside* the ones added
        before it. ErrorHandlingMiddleware, registered last, is therefore the
        outermost of these layers, and CORSMiddleware is the innermost.
        Responses produced by outer layers (e.g. a 429 from the rate limiter)
        do not get CORS headers.
    """
    config = settings if middleware_settings is None else middleware_settings

    # CORS
    if config.enable_cors:
        # With allow_credentials=True, a "*" origin is answered by echoing the
        # caller's Origin, letting any website make cookie-authenticated
        # requests. Allow credentials only with an explicit origin list.
        allow_any_origin = "*" in config.cors_origins
        app.add_middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=not allow_any_origin,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # GZip compression
    if config.enable_gzip:
        app.add_middleware(
            GZipMiddleware,
            minimum_size=config.gzip_minimum_size,
        )

    # Rate limiting: pass the class plus options. Starlette instantiates it with
    # the next ASGI app; a zero-argument lambda cannot be instantiated that way.
    if config.enable_rate_limiting:
        app.add_middleware(
            RateLimitMiddleware,
            calls=config.rate_limit_calls,
            period=config.rate_limit_period,
        )

    # Middleware that is always enabled
    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)
# Configure all middleware
setup_middleware(app)

## 总结
本章详细介绍了FastAPI的中间件系统:
- ✅ 中间件基础:执行流程、基本概念
- ✅ 内置中间件:CORS、HTTPS重定向、GZip等
- ✅ 自定义中间件:请求ID、性能监控、限流、认证
- ✅ 错误处理:全局异常处理中间件
- ✅ 最佳实践:中间件工厂、条件中间件、配置管理
中间件是处理横切关注点的强大工具,合理使用可以大大提升应用的功能性和可维护性。
中间件设计建议
- 保持中间件职责单一
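- 考虑中间件的执行顺序:通过 `add_middleware()` 后添加的中间件位于外层,先处理请求、后处理响应;例如希望限流、异常处理等外层中间件返回的错误响应也带上 CORS 头,应最后添加 CORSMiddleware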
- 提供配置开关控制中间件启用
- 注意性能影响,避免阻塞操作
- 做好错误处理和日志记录
- 编写可测试的中间件代码
在下一章中,我们将学习FastAPI的依赖注入系统,了解如何管理应用的依赖关系。