
PyTorch Model Deployment

Deployment Overview

Model deployment is the key step of putting a trained deep learning model into a production environment. PyTorch offers several deployment options, ranging from simple scripted deployment to high-performance, production-grade serving.

python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.jit import script, trace
import onnx
import tensorrt as trt

Model Serialization and Saving

1. Standard Model Saving

python
# Save the full model (not recommended for production)
torch.save(model, 'model_complete.pth')

# Save only the state dict (recommended)
torch.save(model.state_dict(), 'model_weights.pth')

# Save a training checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'config': model_config
}
torch.save(checkpoint, 'checkpoint.pth')

# Load a model from saved weights
def load_model(model_class, weights_path, config):
    model = model_class(config)
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()
    return model
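
A checkpoint saved this way can also be reloaded to resume training, not just for inference. A minimal sketch, assuming the same keys as the checkpoint dict above:

python
import torch

def resume_from_checkpoint(model, optimizer, path='checkpoint.pth'):
    """Restore model/optimizer state and return the epoch to resume from."""
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1   # continue after the saved epoch
    model.train()                           # back to training mode
    return start_epoch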

2. TorchScript Deployment

python
class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super(ImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Method 1: scripting
model = ImageClassifier()
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

# Method 2: tracing
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model_traced.pt')

# Load a TorchScript model
loaded_model = torch.jit.load('model_scripted.pt')
loaded_model.eval()

# Inference
with torch.no_grad():
    output = loaded_model(example_input)
    print(f"Output shape: {output.shape}")

3. ONNX Export

python
import onnx
import onnxruntime as ort

def export_to_onnx(model, example_input, onnx_path):
    """Export a model to ONNX format."""
    model.eval()
    
    torch.onnx.export(
        model,                          # model to export
        example_input,                  # example input
        onnx_path,                      # output path
        export_params=True,             # store trained parameters
        opset_version=11,               # ONNX opset version
        do_constant_folding=True,       # constant-folding optimization
        input_names=['input'],          # input tensor names
        output_names=['output'],        # output tensor names
        dynamic_axes={                  # dynamic axes
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    
    # Validate the exported ONNX model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved to: {onnx_path}")

# Inference with ONNX Runtime
def onnx_inference(onnx_path, input_data):
    """Run inference with ONNX Runtime."""
    session = ort.InferenceSession(onnx_path)
    
    # Look up input/output tensor names
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    
    # Run inference
    result = session.run([output_name], {input_name: input_data.numpy()})
    return result[0]

# Export and test
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(model, example_input, 'model.onnx')

# ONNX inference test
onnx_output = onnx_inference('model.onnx', example_input)
print(f"ONNX output shape: {onnx_output.shape}")

Web Service Deployment

1. Flask Deployment

python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64

app = Flask(__name__)

# Globals shared across requests
model = None
transform = None
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']  # must match the model's output classes

def load_model():
    """Load the model and preprocessing pipeline."""
    global model, transform
    
    # Load the TorchScript model
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    
    # Define preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

def preprocess_image(image_bytes):
    """Preprocess an uploaded image."""
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor

@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint."""
    try:
        # Get the uploaded image
        if 'image' not in request.files:
            return jsonify({'error': 'No image uploaded'}), 400
        
        image_file = request.files['image']
        image_bytes = image_file.read()
        
        # Preprocess
        input_tensor = preprocess_image(image_bytes)
        
        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
            
            # Top-5 predictions
            top5_prob, top5_idx = torch.topk(probabilities, 5)
            
            results = []
            for i in range(5):
                results.append({
                    'class': class_names[top5_idx[i].item()],
                    'probability': top5_prob[i].item()
                })
        
        return jsonify({
            'success': True,
            'predictions': results
        })
    
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint."""
    return jsonify({'status': 'healthy'})

if __name__ == '__main__':
    load_model()
    app.run(host='0.0.0.0', port=5000, debug=False)
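
A quick way to exercise the Flask service is a small client script. A minimal sketch using the `requests` library, assuming the server is running locally on port 5000 and that `test.jpg` is a placeholder image file:

python
import requests

# Hypothetical test image and local server address
url = 'http://localhost:5000/predict'
with open('test.jpg', 'rb') as f:
    response = requests.post(url, files={'image': f})

print(response.status_code)
print(response.json())   # {'success': True, 'predictions': [...]}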

2. FastAPI Deployment

python
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
from typing import List
import uvicorn

app = FastAPI(title="PyTorch Model API", version="1.0.0")

# Globals shared across requests
model = None
transform = None
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']  # must match the model's output classes

@app.on_event("startup")
async def load_model():
    """Load the model on startup."""
    global model, transform
    
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

class PredictionResponse:
    def __init__(self, class_name: str, probability: float):
        self.class_name = class_name
        self.probability = probability

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Prediction endpoint."""
    try:
        # Validate the content type
        if not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        
        # Read and preprocess the image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        input_tensor = transform(image).unsqueeze(0)
        
        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
            
            # Top-5 predictions
            top5_prob, top5_idx = torch.topk(probabilities, 5)
            
            predictions = []
            for i in range(5):
                predictions.append({
                    "class_name": class_names[top5_idx[i].item()],
                    "probability": float(top5_prob[i].item())
                })
        
        return {"predictions": predictions}
    
    except HTTPException:
        # Re-raise validation errors unchanged instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
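
FastAPI endpoints can also be exercised without starting a server. A minimal sketch using `fastapi.testclient`, assuming the application above lives in a module named `main.py`, that `model_scripted.pt` is available, and that `test.jpg` is a placeholder image:

python
from fastapi.testclient import TestClient
from main import app   # hypothetical module name for the FastAPI code above

# Using the client as a context manager runs the startup event,
# so the model is loaded before any request is made.
with TestClient(app) as client:
    assert client.get("/health").json() == {"status": "healthy"}

    with open("test.jpg", "rb") as f:   # hypothetical sample image
        response = client.post(
            "/predict",
            files={"file": ("test.jpg", f, "image/jpeg")},
        )
    print(response.status_code, response.json())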

Containerized Deployment

1. Docker Deployment

dockerfile
# Dockerfile
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Expose the service port
EXPOSE 8000

# Start command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

yaml
# docker-compose.yml
version: '3.8'

services:
  pytorch-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
    environment:
      - MODEL_PATH=/app/models/model_scripted.pt
    restart: unless-stopped
    
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - pytorch-api
    restart: unless-stopped

2. Kubernetes Deployment

yaml
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: pytorch-model-deployment
spec:
  replicas: 3
  selector:
    matchLabels:
      app: pytorch-model
  template:
    metadata:
      labels:
        app: pytorch-model
    spec:
      containers:
      - name: pytorch-api
        image: pytorch-model:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "512Mi"
            cpu: "500m"
          limits:
            memory: "1Gi"
            cpu: "1000m"
        env:
        - name: MODEL_PATH
          value: "/app/models/model_scripted.pt"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5

---
apiVersion: v1
kind: Service
metadata:
  name: pytorch-model-service
spec:
  selector:
    app: pytorch-model
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: LoadBalancer

High-Performance Inference

1. Batched Inference Optimization

python
import asyncio
from concurrent.futures import Future

class BatchInferenceOptimizer:
    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.batch_queue = []
    
    async def predict(self, input_data):
        """Queue a request and await its result from a batched forward pass."""
        future = Future()
        self.batch_queue.append((input_data, future))
        
        # Run inference once the batch is full, otherwise after a timeout
        if len(self.batch_queue) >= self.max_batch_size:
            await self._process_batch()
        else:
            # Schedule a timeout so small batches are not starved
            asyncio.create_task(self._timeout_handler())
        
        return await asyncio.wrap_future(future)
    
    async def _process_batch(self):
        """Run the model on all queued requests."""
        if not self.batch_queue:
            return
        
        # Collect the queued inputs and their futures
        batch_data = []
        futures = []
        
        for data, future in self.batch_queue:
            batch_data.append(data)
            futures.append(future)
        
        self.batch_queue.clear()
        
        # Batched inference
        try:
            batch_input = torch.stack(batch_data)
            with torch.no_grad():
                batch_output = self.model(batch_input)
            
            # Hand each caller its own slice of the output
            for i, future in enumerate(futures):
                future.set_result(batch_output[i])
        
        except Exception as e:
            # Propagate the error to every waiting caller
            for future in futures:
                future.set_exception(e)
    
    async def _timeout_handler(self):
        """Flush the queue after the timeout expires."""
        await asyncio.sleep(self.timeout)
        if self.batch_queue:
            await self._process_batch()
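
Wired into an async framework, each request simply awaits `predict()` and the optimizer groups concurrent requests into one forward pass. A rough sketch on top of the FastAPI service above; the route name, batch settings, and startup wiring are assumptions for illustration:

python
# Hypothetical wiring of the batch optimizer into the FastAPI service above
batch_optimizer = None

@app.on_event("startup")
async def init_batcher():
    global batch_optimizer
    batch_optimizer = BatchInferenceOptimizer(model, max_batch_size=16, timeout=0.05)

@app.post("/predict_batched")
async def predict_batched(file: UploadFile = File(...)):
    image = Image.open(io.BytesIO(await file.read())).convert('RGB')
    input_tensor = transform(image)      # no unsqueeze: the optimizer stacks inputs
    output = await batch_optimizer.predict(input_tensor)
    probabilities = torch.nn.functional.softmax(output, dim=0)
    return {"top_class": int(probabilities.argmax())}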

2. TensorRT Optimization

python
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# Note: this class targets the TensorRT 8.x API; newer releases replaced
# build_engine/max_workspace_size and the binding-based buffer handling.
class TensorRTInference:
    def __init__(self, onnx_path, trt_path=None):
        self.onnx_path = onnx_path
        self.trt_path = trt_path or onnx_path.replace('.onnx', '.trt')
        
        # Build the TensorRT engine
        self.engine = self._build_engine()
        self.context = self.engine.create_execution_context()
        
        # Allocate GPU memory
        self._allocate_buffers()
    
    def _build_engine(self):
        """Build a TensorRT engine from the ONNX model."""
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)
        
        # Parse the ONNX model
        with open(self.onnx_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        
        # Configure the builder
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30  # 1 GB
        config.set_flag(trt.BuilderFlag.FP16)  # enable FP16
        
        # Build the engine
        engine = builder.build_engine(network, config)
        
        # Serialize the engine to disk
        with open(self.trt_path, 'wb') as f:
            f.write(engine.serialize())
        
        return engine
    
    def _allocate_buffers(self):
        """Allocate host and device buffers for every binding."""
        # Assumes static input shapes; dynamic axes would require an optimization profile
        self.inputs = []
        self.outputs = []
        self.bindings = []
        
        for binding in self.engine:
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            
            # Allocate page-locked host memory and matching device memory
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            
            self.bindings.append(int(device_mem))
            
            if self.engine.binding_is_input(binding):
                self.inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self.outputs.append({'host': host_mem, 'device': device_mem})
    
    def infer(self, input_data):
        """Run inference with TensorRT."""
        # Copy the input to the GPU
        np.copyto(self.inputs[0]['host'], input_data.ravel())
        cuda.memcpy_htod(self.inputs[0]['device'], self.inputs[0]['host'])
        
        # Execute inference
        self.context.execute_v2(bindings=self.bindings)
        
        # Copy the output back to the CPU
        cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device'])
        
        return self.outputs[0]['host']
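
Usage then mirrors the ONNX Runtime helper: build the engine once and feed it NumPy arrays. A minimal sketch, assuming the `model.onnx` exported earlier and a fixed 1x3x224x224 input shape:

python
# Build the engine from the ONNX file exported earlier (slow the first time)
trt_model = TensorRTInference('model.onnx')

# Inputs are plain NumPy arrays matching the export shape
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = trt_model.infer(input_data)
print("TensorRT output (flattened host buffer):", output.shape)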

Edge Device Deployment

1. Mobile Deployment (PyTorch Mobile)

python
# Optimize a model for mobile deployment
def prepare_mobile_model(model, example_input):
    """Optimize a model for mobile deployment."""
    model.eval()
    
    # Trace the model
    traced_model = torch.jit.trace(model, example_input)
    
    # Apply mobile-specific optimizations
    from torch.utils.mobile_optimizer import optimize_for_mobile
    optimized_model = optimize_for_mobile(traced_model)
    
    # Save for the lite interpreter
    optimized_model._save_for_lite_interpreter("model_mobile.ptl")
    
    return optimized_model

# Usage example
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
mobile_model = prepare_mobile_model(model, example_input)

2. Quantized Deployment

python
def quantize_model_for_deployment(model, calibration_loader):
    """Post-training static quantization for deployment."""
    # Set the quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    
    # Prepare the model (insert observers)
    model_prepared = torch.quantization.prepare(model, inplace=False)
    
    # Calibrate with representative data
    model_prepared.eval()
    with torch.no_grad():
        for data, _ in calibration_loader:
            model_prepared(data)
    
    # Convert to a quantized model
    model_quantized = torch.quantization.convert(model_prepared, inplace=False)
    
    return model_quantized

# Dynamic quantization (the simpler option)
def dynamic_quantize_model(model):
    """Dynamically quantize a model (weights quantized ahead of time)."""
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model
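
Quantization pays off mainly in model size and CPU latency, so it is worth measuring. A minimal sketch comparing on-disk state-dict size before and after dynamic quantization; the file names are placeholders:

python
import os

def state_dict_size_mb(model, path):
    """Save the state dict and return its size in MB."""
    torch.save(model.state_dict(), path)
    return os.path.getsize(path) / 1e6

model = ImageClassifier()
quantized = dynamic_quantize_model(model)

print(f"fp32 model: {state_dict_size_mb(model, 'fp32.pth'):.2f} MB")
print(f"int8 model: {state_dict_size_mb(quantized, 'int8.pth'):.2f} MB")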

Monitoring and Logging

1. Performance Monitoring

python
import time
import psutil
import logging
from functools import wraps

class PerformanceMonitor:
    def __init__(self):
        self.metrics = {
            'request_count': 0,
            'total_inference_time': 0,
            'average_inference_time': 0,
            'memory_usage': 0,
            'cpu_usage': 0
        }
    
    def log_inference_time(self, func):
        """Decorator: record inference latency."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            
            inference_time = end_time - start_time
            self.metrics['request_count'] += 1
            self.metrics['total_inference_time'] += inference_time
            self.metrics['average_inference_time'] = (
                self.metrics['total_inference_time'] / self.metrics['request_count']
            )
            
            logging.info(f"Inference time: {inference_time:.4f}s")
            return result
        return wrapper
    
    def update_system_metrics(self):
        """Update system-level metrics."""
        self.metrics['memory_usage'] = psutil.virtual_memory().percent
        self.metrics['cpu_usage'] = psutil.cpu_percent()
    
    def get_metrics(self):
        """Return all collected metrics."""
        self.update_system_metrics()
        return self.metrics

# Usage example
monitor = PerformanceMonitor()

@monitor.log_inference_time
def model_inference(input_data):
    with torch.no_grad():
        return model(input_data)
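
Exposing the collected metrics over HTTP makes them easy to scrape or inspect. A minimal sketch adding a `/metrics` route to the FastAPI app defined earlier; the route name and wiring are assumptions, not part of the original service:

python
# Hypothetical metrics endpoint for the FastAPI app defined earlier
@app.get("/metrics")
async def metrics():
    return monitor.get_metrics()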

2. Logging Configuration

python
import logging
import json
from datetime import datetime

class ModelLogger:
    def __init__(self, log_file='model_service.log'):
        self.logger = logging.getLogger('ModelService')
        self.logger.setLevel(logging.INFO)
        
        # File handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        
        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        
        # Formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)
        
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)
    
    def log_prediction(self, input_info, prediction, confidence, inference_time):
        """Log a completed prediction."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'input_info': input_info,
            'prediction': prediction,
            'confidence': confidence,
            'inference_time': inference_time
        }
        
        self.logger.info(f"Prediction completed: {json.dumps(log_data, ensure_ascii=False)}")
    
    def log_error(self, error_msg, input_info=None):
        """Log a prediction error."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'error': error_msg,
            'input_info': input_info
        }
        
        self.logger.error(f"Prediction error: {json.dumps(log_data, ensure_ascii=False)}")

Summary

PyTorch model deployment covers the full path from development to production:

  1. Model serialization: conversion to TorchScript, ONNX, and other formats
  2. Web services: API serving with frameworks such as Flask and FastAPI
  3. Containerization: deployment with Docker and Kubernetes
  4. Performance optimization: batched inference, TensorRT, and other acceleration techniques
  5. Edge deployment: mobile and quantized lightweight options
  6. Monitoring and logging: performance monitoring and a solid logging setup

Mastering these deployment techniques will help you put PyTorch models into production successfully.
