Skip to content

FastAPI 中间件

概述

中间件是Web应用中处理横切关注点的重要机制,如日志记录、性能监控、安全检查、CORS处理等。FastAPI基于Starlette,提供了强大的中间件系统。本章将详细介绍如何使用和创建中间件。

🔧 中间件基础概念

中间件执行流程

```mermaid
graph TB
    A[客户端请求] --> B[中间件1 - 请求处理]
    B --> C[中间件2 - 请求处理]
    C --> D[路由处理函数]
    D --> E[中间件2 - 响应处理]
    E --> F[中间件1 - 响应处理]
    F --> G[客户端响应]
```

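中间件采用"洋葱模型":请求由外向内依次经过各层中间件,到达路由处理函数后,响应再由内向外逐层返回。需要注意的是,在FastAPI/Starlette中,后添加(通过 `app.add_middleware` 或 `@app.middleware("http")`)的中间件位于更外层,因此对请求最先执行、对响应最后执行——也就是说,请求处理顺序与添加顺序相反。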

基础中间件示例

python
import functools
import json
import logging
import math
import os
import secrets
import time
from collections import deque
from datetime import timezone

from fastapi import FastAPI, Request, Response

# Logging setup: root handler at INFO plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI()

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each HTTP request and attach its processing time to the response.

    Logs the method and path on entry and calls the rest of the stack via
    ``call_next``. It then sets the ``X-Process-Time`` response header (elapsed
    seconds as a float string) and logs the status code and duration. If the
    downstream stack raises, the failure is logged and the exception re-raised.

    Args:
        request: The incoming request.
        call_next: Coroutine supplied by FastAPI that runs the downstream
            middleware/route and returns its response.

    Returns:
        The downstream response with ``X-Process-Time`` added.
    """
    # perf_counter is monotonic; time.time() can jump (NTP adjustments) and
    # yield wrong or even negative durations.
    start = time.perf_counter()
    method = request.method
    # Log only the path: query strings can carry secrets (tokens, or the
    # username/password the /login example takes as query parameters).
    path = request.url.path

    logger.info("请求开始: %s %s", method, path)
    try:
        response = await call_next(request)
    except Exception:
        logger.exception("请求失败: %s %s - %.4fs", method, path, time.perf_counter() - start)
        raise

    process_time = time.perf_counter() - start
    response.headers["X-Process-Time"] = str(process_time)
    logger.info("请求完成: %s %s - %s - %.4fs", method, path, response.status_code, process_time)

    return response

@app.get("/")
async def read_root():
    """Return a fixed greeting payload."""
    greeting = "Hello World"
    return {"message": greeting}

@app.get("/slow")
async def slow_endpoint():
    """Simulate a slow operation (about 2 seconds) for the timing middleware.

    ``asyncio.sleep`` (not ``time.sleep``) is used so the event loop keeps
    serving other requests while this one waits.

    Returns:
        ``{"message": "This was slow"}``.
    """
    # This snippet never imported asyncio, so the call below raised NameError.
    import asyncio

    await asyncio.sleep(2)  # non-blocking stand-in for slow I/O
    return {"message": "This was slow"}

🛡️ 内置中间件

CORS中间件

python
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Register CORSMiddleware exactly once, choosing the configuration by
# environment. Adding it a second time (as before, in development) stacks two
# CORS layers whose decisions and headers can conflict.
if os.getenv("ENVIRONMENT") == "development":
    # Development: accept any origin. The FastAPI CORS docs state that "*"
    # must not be combined with allow_credentials=True, so credentials stay off.
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    cors_options = dict(
        allow_origins=["http://localhost:3000", "https://myapp.com"],  # allowed origins
        allow_credentials=True,  # allow cookies / Authorization on cross-origin calls
        allow_methods=["GET", "POST", "PUT", "DELETE"],  # allowed methods
        allow_headers=["*"],  # allowed request headers
        expose_headers=["X-Custom-Header"],  # response headers readable by the browser
        max_age=3600,  # seconds browsers may cache a preflight result
    )

app.add_middleware(CORSMiddleware, **cors_options)

HTTPS重定向中间件

python
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

# Production only: redirect plain-HTTP requests to HTTPS.
# NOTE(review): behind a TLS-terminating proxy this middleware sees the
# proxy's scheme; confirm proxy headers are honoured to avoid redirect loops.
is_production = os.getenv("ENVIRONMENT") == "production"
if is_production:
    app.add_middleware(HTTPSRedirectMiddleware)

可信主机中间件

python
from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Only accept requests whose Host header matches one of these entries;
# an entry starting with "*." is meant to cover subdomains.
_trusted_hosts = ["example.com", "*.example.com", "localhost", "127.0.0.1"]
app.add_middleware(TrustedHostMiddleware, allowed_hosts=_trusted_hosts)

GZip压缩中间件

python
from fastapi.middleware.gzip import GZipMiddleware

# Compress response bodies of at least this many bytes; smaller ones are sent as-is.
_GZIP_MIN_BYTES = 1000
app.add_middleware(GZipMiddleware, minimum_size=_GZIP_MIN_BYTES)

🔨 自定义中间件

请求ID中间件

python
import uuid
from contextvars import ContextVar

# 使用ContextVar存储请求ID
request_id_contextvar: ContextVar[str] = ContextVar('request_id', default="")


class RequestIDMiddleware:
    """Pure-ASGI middleware that gives every HTTP request a fresh UUID4 id.

    The id is stored in ``request_id_contextvar`` while the downstream app
    handles the request, so code running inside the request can read it. It
    is also returned to the client in an ``X-Request-ID`` response header.
    Afterwards the context variable is restored to its previous value.
    Non-HTTP scopes (e.g. ``lifespan``, ``websocket``) pass through untouched.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request_id = str(uuid.uuid4())
        id_header = [b"x-request-id", request_id.encode()]

        async def send_wrapper(message):
            # Only the response-start message carries headers; body chunks pass through.
            if message["type"] == "http.response.start":
                message["headers"] = [*message.get("headers", []), id_header]
            await send(message)

        # Keep the token and reset in `finally`. Otherwise the id stays set in
        # the calling context after the request and can be seen by later code.
        token = request_id_contextvar.set(request_id)
        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            request_id_contextvar.reset(token)

# Register the request-id middleware on the app.
app.add_middleware(RequestIDMiddleware)


def get_request_id() -> str:
    """Return the request id held by the context variable ("" when nothing is set in this context)."""
    current_id = request_id_contextvar.get()
    return current_id

@app.get("/test")
async def test_endpoint():
    """Echo the current request id together with a fixed message."""
    body = {"request_id": get_request_id()}
    body["message"] = "测试端点"
    return body

性能监控中间件

python
import psutil
from typing import Dict, Any

class PerformanceMiddleware:
    """Pure-ASGI middleware that keeps in-process request statistics.

    For every HTTP request it measures the time spent in the downstream app
    and updates ``self.stats``:

    - ``total_requests``, ``total_time`` and ``avg_response_time`` (seconds)
    - ``request_count_by_method``: HTTP method -> count
    - ``request_count_by_status``: status code -> count (requests that fail
      before a response starts are not counted here)

    It also appends ``X-CPU-Usage``, ``X-Memory-Usage`` and ``X-Request-Count``
    headers to each response. These reveal host load to every client, so keep
    them out of public deployments. Stats belong to this instance only, so
    each worker process has its own.
    """

    def __init__(self, app):
        self.app = app
        self.stats = {
            "total_requests": 0,
            "total_time": 0,
            "avg_response_time": 0,
            "request_count_by_method": {},
            "request_count_by_status": {},
        }

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Monotonic clock: wall-clock time can jump and distort durations.
        start = time.perf_counter()
        method = scope["method"]
        status_code = None

        async def send_wrapper(message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
                # cpu_percent() without an interval does not block; it reports
                # usage since its previous call.
                extra = [
                    [b"x-cpu-usage", f"{psutil.cpu_percent():.1f}%".encode()],
                    [b"x-memory-usage", f"{psutil.virtual_memory().percent:.1f}%".encode()],
                    # Requests completed before this one (the counter is bumped afterwards).
                    [b"x-request-count", str(self.stats["total_requests"]).encode()],
                ]
                message["headers"] = [*message.get("headers", []), *extra]
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Record even when the downstream app raises, so failing requests
            # are not silently missing from the totals.
            self._record(method, status_code, time.perf_counter() - start)

    def _record(self, method, status_code, elapsed):
        """Fold one finished request into ``self.stats``."""
        stats = self.stats
        stats["total_requests"] += 1
        stats["total_time"] += elapsed
        stats["avg_response_time"] = stats["total_time"] / stats["total_requests"]

        by_method = stats["request_count_by_method"]
        by_method[method] = by_method.get(method, 0) + 1

        if status_code is not None:
            by_status = stats["request_count_by_status"]
            by_status[status_code] = by_status.get(status_code, 0) + 1

# Starlette builds the middleware stack itself, calling each registered
# factory with the next ASGI app as its first argument. A zero-argument
# lambda fails there with TypeError. A PerformanceMiddleware built around the
# FastAPI `app` would also route every request back into the whole app.
# Register a factory instead, and keep the instance it creates so /stats can
# read its live counters.
performance_middleware = None


def _create_performance_middleware(app):
    global performance_middleware
    performance_middleware = PerformanceMiddleware(app)
    return performance_middleware


app.add_middleware(_create_performance_middleware)


@app.get("/stats")
async def get_performance_stats():
    """Return the live PerformanceMiddleware counters ({} if the stack has not been built yet)."""
    if performance_middleware is None:
        return {}
    return performance_middleware.stats

限流中间件

python
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta

class RateLimitMiddleware:
    """Pure-ASGI sliding-window rate limiter keyed by client IP.

    A client may make at most ``calls`` HTTP requests within any rolling
    ``period``-second window. Further requests get a direct
    ``429 Too Many Requests`` (with ``Retry-After``) and never reach the
    wrapped app. Every response carries ``X-RateLimit-Limit``,
    ``X-RateLimit-Remaining`` and ``X-RateLimit-Reset``; the last is the Unix
    time, in seconds, at which the oldest counted request leaves the window.

    State lives in this instance's memory, so limits apply per worker process.

    Args:
        app: Downstream ASGI application.
        calls: Maximum requests per client per window.
        period: Window length in seconds.
        trust_proxy_headers: If true (the default, matching earlier
            behaviour), the client IP is read from ``X-Forwarded-For``,
            ``X-Real-IP`` or ``CF-Connecting-IP`` when present. Clients control
            these headers unless a trusted reverse proxy overwrites them. A
            directly exposed server should pass ``False``; otherwise a client
            can evade the limit by sending a new fake IP each time.
    """

    def __init__(self, app, calls: int = 100, period: int = 60,
                 trust_proxy_headers: bool = True):
        self.app = app
        self.calls = calls  # allowed requests per window
        self.period = period  # window length (seconds)
        self.trust_proxy_headers = trust_proxy_headers
        # client IP -> monotonic timestamps of its requests in the window, oldest first
        self.clients = defaultdict(deque)
        self.cleanup_task = None  # unused; kept so existing attribute access keeps working
        # Monotonic time of the next sweep over all clients for expired entries.
        self._next_sweep = time.monotonic() + period

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        client_ip = self._get_client_ip(scope)
        # Monotonic clock for window arithmetic. Wall-clock time, especially the
        # naive local datetime used before, can jump with NTP or DST changes
        # and lock clients out or let them through.
        now = time.monotonic()
        self._sweep_idle_clients(now)
        self._cleanup_old_requests(client_ip, now)

        history = self.clients.get(client_ip)
        used = len(history) if history else 0

        if used >= self.calls:
            # A slot frees up when the oldest counted request expires.
            wait = history[0] + self.period - now if history else self.period
            await self._send_rate_limited(send, wait)
            return

        history = self.clients[client_ip]
        history.append(now)
        remaining = self.calls - len(history)
        rate_headers = [
            [b"x-ratelimit-limit", str(self.calls).encode()],
            [b"x-ratelimit-remaining", str(remaining).encode()],
            [b"x-ratelimit-reset", self._reset_header(history[0] + self.period - now)],
        ]

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                message["headers"] = [*message.get("headers", []), *rate_headers]
            await send(message)

        await self.app(scope, receive, send_wrapper)

    def _get_client_ip(self, scope):
        """Return the IP used as the rate-limit key ("unknown" if none is available)."""
        if self.trust_proxy_headers:
            headers = dict(scope.get("headers", []))
            for name in (b"x-forwarded-for", b"x-real-ip", b"cf-connecting-ip"):
                value = headers.get(name)
                if value:
                    # ASGI header values are raw bytes; latin-1 decodes any byte
                    # sequence, so a malformed header cannot crash the request.
                    ip = value.decode("latin-1").split(",")[0].strip()
                    if ip:
                        return ip

        # Fall back to the address of the directly connected peer.
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _cleanup_old_requests(self, client_ip, current_time):
        """Drop ``client_ip``'s timestamps that fell out of the window; forget the client if none remain."""
        history = self.clients.get(client_ip)
        if history is None:
            return
        cutoff = current_time - self.period
        # Timestamps are appended in non-decreasing order, so expired ones are at the left.
        while history and history[0] <= cutoff:
            history.popleft()
        if not history:
            del self.clients[client_ip]

    def _sweep_idle_clients(self, now):
        """At most once per period, drop every client whose requests have all expired.

        Without this, entries for clients that never come back would
        accumulate for the lifetime of the process.
        """
        if now < self._next_sweep:
            return
        self._next_sweep = now + self.period
        cutoff = now - self.period
        stale = [ip for ip, history in self.clients.items() if not history or history[-1] <= cutoff]
        for ip in stale:
            del self.clients[ip]

    def _reset_header(self, seconds_until_reset):
        """Encode the Unix time ``seconds_until_reset`` from now, rounded up, as a header value."""
        return str(math.ceil(time.time() + seconds_until_reset)).encode()

    async def _send_rate_limited(self, send, wait):
        """Answer with a 429 JSON response without calling the wrapped app."""
        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": [
                [b"content-type", b"application/json"],
                [b"retry-after", str(max(1, math.ceil(wait))).encode()],
                [b"x-ratelimit-limit", str(self.calls).encode()],
                [b"x-ratelimit-remaining", b"0"],
                [b"x-ratelimit-reset", self._reset_header(wait)],
            ],
        })
        await send({
            "type": "http.response.body",
            "body": b'{"error": "Rate limit exceeded", "message": "Too many requests"}',
        })

# Apply the rate-limit middleware: at most 100 requests per client per minute.
# Pass the class and options; Starlette instantiates it with the next ASGI app.
# A zero-argument lambda raises TypeError when the stack is built. Wrapping the
# FastAPI `app` itself would also route requests back into the whole app.
app.add_middleware(RateLimitMiddleware, calls=100, period=60)

认证中间件

python
import jwt
from fastapi import HTTPException, status

class AuthenticationMiddleware:
    """Pure-ASGI middleware requiring ``Authorization: Bearer <JWT>`` on HTTP requests.

    Requests to excluded paths, and non-HTTP scopes, pass through untouched.
    Any other HTTP request must carry a bearer token that verifies as HS256
    with ``secret_key``; the decoded claims are stored in ``scope["user"]``
    for downstream handlers. Failures get a 401 JSON response with a
    ``WWW-Authenticate: Bearer`` header and never reach the wrapped app.

    Args:
        app: Downstream ASGI application.
        secret_key: HMAC key used to verify HS256 tokens.
        excluded_paths: Paths that skip authentication. Each entry matches that
            exact path and anything beneath it (``"/docs"`` also covers
            ``"/docs/oauth2-redirect"``), except ``"/"``, which matches only
            the root path.
    """

    def __init__(self, app, secret_key: str, excluded_paths: list = None):
        self.app = app
        self.secret_key = secret_key
        self.excluded_paths = excluded_paths or ["/", "/docs", "/redoc", "/openapi.json", "/login"]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or self._is_excluded_path(scope["path"]):
            await self.app(scope, receive, send)
            return

        auth_header = dict(scope.get("headers", [])).get(b"authorization")
        if not auth_header:
            await self._send_unauthorized(send, "Missing authorization header")
            return

        # ASGI header values are raw bytes; latin-1 decodes any byte sequence,
        # so a malformed header yields a 401 rather than an unhandled error.
        scheme, _, token = auth_header.decode("latin-1").partition(" ")
        token = token.strip()
        # Auth scheme names are case-insensitive (RFC 7235), so accept "bearer" too.
        if scheme.lower() != "bearer" or not token:
            await self._send_unauthorized(send, "Invalid authorization format")
            return

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            await self._send_unauthorized(send, "Token has expired")
            return
        except jwt.InvalidTokenError:
            await self._send_unauthorized(send, "Invalid token")
            return

        # Expose the verified claims to downstream handlers via request.scope["user"].
        scope["user"] = payload
        await self.app(scope, receive, send)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if ``path`` equals an excluded path or lies beneath one.

        A bare ``startswith`` check is not enough. With ``"/"`` in the list it
        matches every path and disables authentication entirely, and
        ``"/login"`` would also match ``"/loginx"``.
        """
        for excluded in self.excluded_paths:
            if path == excluded:
                return True
            base = excluded.rstrip("/")
            # `base` is empty for "/", which must only match the root itself.
            if base and path.startswith(base + "/"):
                return True
        return False

    async def _send_unauthorized(self, send, message: str):
        """Send a complete 401 JSON response carrying ``message``."""
        # json.dumps escapes quotes/backslashes that would break a hand-built body.
        body = json.dumps({"error": "Unauthorized", "message": message}).encode()
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                [b"content-type", b"application/json"],
                # Tell the client which authentication scheme is expected (RFC 6750).
                [b"www-authenticate", b"Bearer"],
            ],
        })
        await send({"type": "http.response.body", "body": body})

# Apply the authentication middleware.
# The signing key is read from the environment so deployments need not keep it
# in source; the literal fallback is only for local experiments.
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")
# Pass the class and options; Starlette instantiates it with the next ASGI app.
# A zero-argument lambda raises TypeError when the stack is built. Wrapping the
# FastAPI `app` itself would also route requests back into the whole app.
app.add_middleware(AuthenticationMiddleware, secret_key=SECRET_KEY)

# Login endpoint ("/login" is in the auth middleware's default excluded paths).
@app.post("/login")
async def login(username: str, password: str):
    """Issue a 24-hour HS256 JWT for the hard-coded demo account.

    FastAPI treats plain ``str`` parameters like these as query parameters, so
    the credentials end up in URLs and possibly access logs. A real login
    should accept them in a form or JSON body instead.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    # compare_digest avoids revealing, through response timing, how much of a
    # guess was correct. Both comparisons run before combining the results.
    user_ok = secrets.compare_digest(username.encode(), b"admin")
    password_ok = secrets.compare_digest(password.encode(), b"secret")
    if not (user_ok and password_ok):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    payload = {
        "user_id": 1,
        "username": username,
        # Timezone-aware expiry; datetime.utcnow() is deprecated since Python 3.12.
        "exp": datetime.now(timezone.utc) + timedelta(hours=24),
    }
    token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
    return {"access_token": token, "token_type": "bearer"}

# Protected endpoint: not in the auth middleware's excluded paths.
@app.get("/protected")
async def protected_endpoint(request: Request):
    """Return a fixed message plus whatever ``scope["user"]`` holds (None when unset)."""
    return dict(message="This is protected", user=request.scope.get("user"))

📊 错误处理中间件

全局异常处理中间件

python
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse

class ErrorHandlingMiddleware:
    """Pure-ASGI middleware that turns unhandled exceptions into JSON error responses.

    Exceptions escaping the downstream app are logged and mapped to a status
    code:

    - ``HTTPException``: its own status
    - ``ValueError``: 400
    - ``FileNotFoundError``: 404
    - anything else: 500

    The body is ``{"error": {type, message, path, method, timestamp}}``, plus a
    ``traceback`` field when ``ENVIRONMENT=development``.

    If the app had already started its response when it failed, a second
    response cannot be sent on the same connection, so the exception is
    re-raised for the server to handle.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message):
            # Track whether the status line has been sent; after that an error
            # response can no longer replace it.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as exc:
            if response_started:
                raise
            await self._handle_error(exc, Request(scope, receive), send)

    async def _handle_error(self, exc: Exception, request: Request, send):
        """Log ``exc`` and answer the request with a JSON error response."""
        logger.error("Unhandled exception: %s", exc, exc_info=exc)

        # Choose the response status from the exception type.
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            detail = exc.detail
        elif isinstance(exc, ValueError):
            status_code = 400
            detail = "Invalid input data"
        elif isinstance(exc, FileNotFoundError):
            status_code = 404
            detail = "Resource not found"
        else:
            status_code = 500
            detail = "Internal server error"

        error_response = {
            "error": {
                "type": exc.__class__.__name__,
                "message": detail,
                "path": str(request.url.path),
                "method": request.method,
                "timestamp": datetime.now().isoformat(),
            }
        }

        # Include the traceback only in development; it exposes source paths and code.
        if os.getenv("ENVIRONMENT") == "development":
            error_response["error"]["traceback"] = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )

        response = JSONResponse(content=error_response, status_code=status_code)
        # This method has no `scope`/`receive` locals. The Request carries the
        # ones __call__ received, which the ASGI response needs in order to run.
        await response(request.scope, request.receive, send)


app.add_middleware(ErrorHandlingMiddleware)

🔧 中间件最佳实践

中间件工厂

python
def create_logging_middleware(log_level: str = "INFO"):
    """Build a pure-ASGI logging middleware factory bound to ``log_level``.

    The returned callable takes the downstream ASGI ``app`` and returns a
    middleware. For HTTP requests, it logs ``Request: <METHOD> <URL>`` on
    entry and ``Response: <status> - <seconds>s`` when the response starts.
    Non-HTTP scopes pass straight through.

    Args:
        log_level: A ``logging`` level name such as ``"DEBUG"``
            (case-insensitive), or a numeric level.

    Raises:
        ValueError: If ``log_level`` is not a known level. This is checked
            once, here, instead of failing on every request.
    """
    if isinstance(log_level, int):
        level = log_level
    else:
        level = getattr(logging, str(log_level).upper(), None)
        if not isinstance(level, int):
            raise ValueError(f"Unknown log level: {log_level!r}")

    def logging_middleware(app):
        async def middleware(scope, receive, send):
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            # Monotonic clock for durations.
            start_time = time.perf_counter()
            request = Request(scope, receive)
            logger.log(level, "Request: %s %s", request.method, request.url)

            async def send_wrapper(message):
                if message["type"] == "http.response.start":
                    process_time = time.perf_counter() - start_time
                    logger.log(level, "Response: %s - %.4fs", message["status"], process_time)
                await send(message)

            await app(scope, receive, send_wrapper)

        return middleware

    return logging_middleware

# Build the DEBUG-level logging middleware with the factory and register it.
# NOTE(review): basicConfig in this file sets the root logger to INFO, so
# these DEBUG records are probably filtered out; confirm the intended level.
debug_logging_middleware = create_logging_middleware("DEBUG")
app.add_middleware(debug_logging_middleware)

条件中间件

python
def conditional_middleware(condition_func):
    """Class decorator that enables an ASGI middleware only when ``condition_func()`` is true.

    The decorated name becomes a factory ``wrapper(app, *args, **kwargs)``.
    ``condition_func`` is evaluated each time the factory is called, i.e. when
    the middleware stack is built, not at import time. If it returns true, the
    real middleware is constructed with all given arguments. Otherwise ``app``
    is returned unchanged, so a disabled middleware adds no layer at all.

    Args:
        condition_func: Zero-argument callable whose truthiness decides.

    Returns:
        A decorator that turns a middleware class into that factory.
    """
    def decorator(middleware_class):
        # Keep the class name/docstring on the factory (useful in reprs and logs).
        # updated=() avoids copying the class __dict__ into the function.
        @functools.wraps(middleware_class, updated=())
        def wrapper(app, *args, **kwargs):
            if condition_func():
                # Forward options so middleware with extra parameters still works.
                return middleware_class(app, *args, **kwargs)
            # Disabled: hand back the downstream app itself. A pass-through
            # wrapper would only add an extra call per request.
            return app
        return wrapper
    return decorator

@conditional_middleware(lambda: os.getenv("ENABLE_PROFILING") == "true")
class ProfilingMiddleware:
    """Profile HTTP requests with cProfile and dump each profile to a ``.prof`` file.

    Active only when ``ENABLE_PROFILING`` is ``"true"`` at the time the
    decorator's factory runs. Files go to the system temp directory as
    ``profile_<timestamp>_<suffix>.prof``.

    Only one request is profiled at a time. cProfile instruments the whole
    thread, so concurrent requests on the event loop would mix into one
    profile, and recent Python versions (3.12+) reject enabling a second
    profiler while one is active. Requests that arrive while a profile is
    running pass through unprofiled. So do non-HTTP scopes, notably
    ``lifespan``, which lasts as long as the app does.
    """

    def __init__(self, app):
        self.app = app
        self._profiling = False  # True while a request is being profiled

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or self._profiling:
            await self.app(scope, receive, send)
            return

        import cProfile
        import tempfile
        import uuid

        self._profiling = True
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            await self.app(scope, receive, send)
        finally:
            # Always stop the profiler, even if the request raised.
            profiler.disable()
            self._profiling = False
            # Random suffix: time.time() alone can collide for back-to-back requests.
            filename = f"profile_{time.time()}_{uuid.uuid4().hex[:8]}.prof"
            profiler.dump_stats(os.path.join(tempfile.gettempdir(), filename))


app.add_middleware(ProfilingMiddleware)

中间件配置

python
from pydantic import BaseSettings

class MiddlewareSettings(BaseSettings):
    """Feature switches and parameters read by ``setup_middleware``.

    Values start from the class defaults and, as a pydantic ``BaseSettings``,
    can be overridden by environment variables and the ``.env`` file named in
    ``Config``.

    NOTE(review): ``from pydantic import BaseSettings`` only works on
    Pydantic v1; in v2 the class lives in the separate ``pydantic-settings``
    package. Confirm which version the project pins.
    """
    enable_cors: bool = True  # register CORSMiddleware
    cors_origins: list = ["*"]  # allowed origins; "*" means any origin
    enable_gzip: bool = True  # register GZipMiddleware
    gzip_minimum_size: int = 1000  # bytes; smaller responses are not compressed
    enable_rate_limiting: bool = False  # register RateLimitMiddleware (off by default)
    rate_limit_calls: int = 100  # requests allowed per client per window
    rate_limit_period: int = 60  # window length in seconds
    
    class Config:
        # Also load values from a .env file.
        env_file = ".env"

# Settings are loaded once, at import time.
settings = MiddlewareSettings()

def setup_middleware(app: FastAPI):
    """Register the configurable middleware stack on ``app`` according to ``settings``.

    Starlette wraps each newly added middleware around those added before it,
    so the last one registered here (ErrorHandlingMiddleware) is the outermost
    layer. It sees every request first and every response last.

    Registered from innermost to outermost:

    1. CORS (if ``settings.enable_cors``)
    2. GZip (if ``settings.enable_gzip``)
    3. Rate limiting (if ``settings.enable_rate_limiting``)
    4. Request id
    5. Error handling
    """
    # CORS
    if settings.enable_cors:
        # The FastAPI CORS docs disallow combining "*" origins with credentials;
        # doing so would let any website make cookie-authenticated calls.
        # Allow credentials only when origins are listed explicitly.
        allow_credentials = "*" not in settings.cors_origins
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins,
            allow_credentials=allow_credentials,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # GZip compression
    if settings.enable_gzip:
        app.add_middleware(
            GZipMiddleware,
            minimum_size=settings.gzip_minimum_size,
        )

    # Rate limiting: pass the class and options and let Starlette instantiate
    # it with the next ASGI app. A zero-argument lambda raises TypeError when
    # the stack is built, and wrapping `app` itself would route requests back
    # into the whole app.
    if settings.enable_rate_limiting:
        app.add_middleware(
            RateLimitMiddleware,
            calls=settings.rate_limit_calls,
            period=settings.rate_limit_period,
        )

    # Always-on middleware; added last so they sit outermost.
    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)


# Install the middleware on the application.
setup_middleware(app)

总结

本章详细介绍了FastAPI的中间件系统:

  • 中间件基础:执行流程、基本概念
  • 内置中间件:CORS、HTTPS重定向、GZip等
  • 自定义中间件:请求ID、性能监控、限流、认证
  • 错误处理:全局异常处理中间件
  • 最佳实践:中间件工厂、条件中间件、配置管理

中间件是处理横切关注点的强大工具,合理使用可以大大提升应用的功能性和可维护性。

中间件设计建议

  • 保持中间件职责单一
  • 考虑中间件的执行顺序
  • 提供配置开关控制中间件启用
  • 注意性能影响,避免阻塞操作
  • 做好错误处理和日志记录
  • 编写可测试的中间件代码

在下一章中,我们将学习FastAPI的依赖注入系统,了解如何管理应用的依赖关系。

本站内容仅供学习和研究使用。