# PyTorch Model Deployment

## Deployment Overview

Model deployment is the critical step of putting a trained deep learning model into a production environment. PyTorch offers a range of deployment options, from simple scripted deployment to high-performance, production-grade serving.
```python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.jit import script, trace
import onnx
import tensorrt as trt  # optional; only needed for the TensorRT section
```

## Model Serialization and Saving
### 1. Standard Model Saving
```python
# Save the complete model (not recommended for production)
torch.save(model, 'model_complete.pth')

# Save the model state dict (recommended)
torch.save(model.state_dict(), 'model_weights.pth')

# Save a training checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'config': model_config
}
torch.save(checkpoint, 'checkpoint.pth')

# Load the model
def load_model(model_class, weights_path, config):
    model = model_class(config)
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()
    return model
```
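The checkpoint dictionary also makes it easy to resume training. A minimal sketch, where `MyModel` and the Adam optimizer are placeholders standing in for whatever classes produced the checkpoint:

```python
# A minimal sketch of resuming from the checkpoint saved above; MyModel
# and the Adam optimizer are placeholders for the original classes.
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model = MyModel(checkpoint['config'])
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
```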
### 2. TorchScript Deployment

```python
class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Method 1: scripting
model = ImageClassifier()
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

# Method 2: tracing
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model_traced.pt')

# Load a TorchScript model
loaded_model = torch.jit.load('model_scripted.pt')
loaded_model.eval()

# Inference
with torch.no_grad():
    output = loaded_model(example_input)
    print(f"Output shape: {output.shape}")
```
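The two methods differ in how they handle data-dependent control flow: tracing records only the operations executed for the example input, while scripting compiles the Python source and keeps every branch. A minimal sketch (the `Gate` module is a contrived illustration):

```python
# Why scripting matters for data-dependent control flow: tracing bakes in
# whichever branch the example input happened to take.
class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent branch
            return x * 2
        return x + 10

gate = Gate()
scripted = torch.jit.script(gate)              # keeps both branches
traced = torch.jit.trace(gate, torch.ones(3))  # bakes in x * 2 (emits a TracerWarning)

neg = -torch.ones(3)
print(scripted(neg))  # tensor([9., 9., 9.])   -- correct branch taken
print(traced(neg))    # tensor([-2., -2., -2.]) -- wrong branch replayed
```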
### 3. ONNX Export

```python
import onnx
import onnxruntime as ort
def export_to_onnx(model, example_input, onnx_path):
    """Export a model to ONNX format."""
    model.eval()
    torch.onnx.export(
        model,                      # model
        example_input,              # example input
        onnx_path,                  # output path
        export_params=True,         # export parameters
        opset_version=11,           # ONNX opset version
        do_constant_folding=True,   # constant-folding optimization
        input_names=['input'],      # input names
        output_names=['output'],    # output names
        dynamic_axes={              # dynamic axes
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    # Validate the ONNX model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved to: {onnx_path}")

# Inference with ONNX Runtime
def onnx_inference(onnx_path, input_data):
    """Run inference with ONNX Runtime."""
    session = ort.InferenceSession(onnx_path)
    # Get input/output metadata
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    # Inference
    result = session.run([output_name], {input_name: input_data.numpy()})
    return result[0]

# Export and test
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(model, example_input, 'model.onnx')

# ONNX inference test
onnx_output = onnx_inference('model.onnx', example_input)
print(f"ONNX output shape: {onnx_output.shape}")
```
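After exporting, it is worth verifying that ONNX Runtime reproduces the PyTorch outputs within floating-point tolerance. A short sketch reusing `model` and `example_input` from above:

```python
import numpy as np

# Compare PyTorch and ONNX Runtime outputs on the same input.
model.eval()
with torch.no_grad():
    torch_output = model(example_input).numpy()
ort_output = onnx_inference('model.onnx', example_input)

# Small numerical differences are expected; fail only on large ones.
np.testing.assert_allclose(torch_output, ort_output, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match within tolerance")
```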
## Web Service Deployment

### 1. Flask Deployment
```python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

app = Flask(__name__)

# Globals
model = None
transform = None
# class_names must cover every index the model can output
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']

def load_model():
    """Load the model and preprocessing pipeline."""
    global model, transform
    # Load the model
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    # Define the preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

def preprocess_image(image_bytes):
    """Preprocess an image."""
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)
    return image_tensor

@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint."""
    try:
        # Get the image data
        if 'image' not in request.files:
            return jsonify({'error': 'No image uploaded'}), 400
        image_file = request.files['image']
        image_bytes = image_file.read()
        # Preprocess
        input_tensor = preprocess_image(image_bytes)
        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        # Top-5 predictions
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        results = []
        for i in range(5):
            results.append({
                'class': class_names[top5_idx[i].item()],
                'probability': top5_prob[i].item()
            })
        return jsonify({
            'success': True,
            'predictions': results
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Health check."""
    return jsonify({'status': 'healthy'})

if __name__ == '__main__':
    load_model()
    app.run(host='0.0.0.0', port=5000, debug=False)
```
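A quick way to exercise the endpoint is with the `requests` package, assuming the service is running locally on port 5000; `test.jpg` is a placeholder for any local image:

```python
import requests

# Send a local image to the /predict endpoint; 'test.jpg' is a placeholder.
with open('test.jpg', 'rb') as f:
    response = requests.post(
        'http://localhost:5000/predict',
        files={'image': ('test.jpg', f, 'image/jpeg')}
    )

print(response.status_code)
print(response.json())  # {'success': True, 'predictions': [...]}
```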
### 2. FastAPI Deployment

```python
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
from typing import List
import uvicorn

app = FastAPI(title="PyTorch Model API", version="1.0.0")

# Globals
model = None
transform = None
# class_names must cover every index the model can output
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']

# Note: newer FastAPI versions prefer lifespan handlers over on_event
@app.on_event("startup")
async def load_model():
    """Load the model at startup."""
    global model, transform
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

class Prediction(BaseModel):
    class_name: str
    probability: float

class PredictionResponse(BaseModel):
    predictions: List[Prediction]

@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Prediction endpoint."""
    try:
        # Validate the content type
        if not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read and preprocess the image
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        input_tensor = transform(image).unsqueeze(0)
        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        # Top-5 predictions
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        predictions = []
        for i in range(5):
            predictions.append({
                "class_name": class_names[top5_idx[i].item()],
                "probability": float(top5_prob[i].item())
            })
        return {"predictions": predictions}
    except HTTPException:
        raise  # keep 4xx errors instead of converting them to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
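FastAPI ships with a test client that lets you exercise the endpoints without starting a server. A minimal sketch, where `main` is assumed to be the module defining `app` above and `test.jpg` is again a placeholder:

```python
from fastapi.testclient import TestClient
from main import app  # 'main' is assumed to be the module defining `app`

# Using the client as a context manager runs the startup event,
# so the model actually gets loaded before the first request.
with TestClient(app) as client:
    response = client.get("/health")
    assert response.json() == {"status": "healthy"}

    with open("test.jpg", "rb") as f:  # 'test.jpg' is a placeholder
        response = client.post(
            "/predict",
            files={"file": ("test.jpg", f, "image/jpeg")}
        )
    print(response.status_code, response.json())
```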
## Containerized Deployment

### 1. Docker Deployment
```dockerfile
# Dockerfile
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Expose the port
EXPOSE 8000

# Startup command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
```

```yaml
# docker-compose.yml
version: '3.8'
services:
  pytorch-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models
    environment:
      - MODEL_PATH=/app/models/model_scripted.pt
    restart: unless-stopped

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - pytorch-api
    restart: unless-stopped
```
### 2. Kubernetes Deployment

```yaml
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: pytorch-model-deployment
spec:
  replicas: 3
  selector:
    matchLabels:
      app: pytorch-model
  template:
    metadata:
      labels:
        app: pytorch-model
    spec:
      containers:
      - name: pytorch-api
        image: pytorch-model:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "512Mi"
            cpu: "500m"
          limits:
            memory: "1Gi"
            cpu: "1000m"
        env:
        - name: MODEL_PATH
          value: "/app/models/model_scripted.pt"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: pytorch-model-service
spec:
  selector:
    app: pytorch-model
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: LoadBalancer
```
## High-Performance Inference

### 1. Batched Inference Optimization
```python
import asyncio
from concurrent.futures import Future

class BatchInferenceOptimizer:
    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.batch_queue = []

    async def predict(self, input_data):
        """Asynchronous batched inference."""
        future = Future()
        self.batch_queue.append((input_data, future))
        # Run inference once the batch is full, otherwise after a timeout
        if len(self.batch_queue) >= self.max_batch_size:
            await self._process_batch()
        else:
            asyncio.create_task(self._timeout_handler())
        return await asyncio.wrap_future(future)

    async def _process_batch(self):
        """Run the model on all queued inputs as one batch."""
        if not self.batch_queue:
            return
        # Collect the queued inputs
        batch_data = []
        futures = []
        for data, future in self.batch_queue:
            batch_data.append(data)
            futures.append(future)
        self.batch_queue.clear()
        # Batched inference
        try:
            batch_input = torch.stack(batch_data)
            with torch.no_grad():
                batch_output = self.model(batch_input)
            # Distribute results to the waiting callers
            for i, future in enumerate(futures):
                future.set_result(batch_output[i])
        except Exception as e:
            # Propagate the error to every caller in the batch
            for future in futures:
                future.set_exception(e)

    async def _timeout_handler(self):
        """Flush a partial batch once the timeout expires."""
        await asyncio.sleep(self.timeout)
        if self.batch_queue:
            await self._process_batch()
```
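A minimal usage sketch: issue several concurrent requests and let the optimizer stack them into a single forward pass (`model` is reused from the TorchScript section; each input is one unbatched sample):

```python
# Concurrent requests get collected into one batched forward pass.
async def main():
    optimizer = BatchInferenceOptimizer(model, max_batch_size=8, timeout=0.05)
    inputs = [torch.randn(3, 224, 224) for _ in range(8)]
    # Launch all requests concurrently; they are stacked into one batch
    outputs = await asyncio.gather(*(optimizer.predict(x) for x in inputs))
    print(len(outputs), outputs[0].shape)  # 8 tensors of shape (num_classes,)

asyncio.run(main())
```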
### 2. TensorRT Optimization

```python
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# Note: this targets the TensorRT 8.x Python API; newer releases deprecate
# max_workspace_size, build_engine, and the binding-based helpers used below.
class TensorRTInference:
    def __init__(self, onnx_path, trt_path=None, max_batch_size=1):
        self.onnx_path = onnx_path
        self.trt_path = trt_path or onnx_path.replace('.onnx', '.trt')
        self.max_batch_size = max_batch_size
        # Build the TensorRT engine
        self.engine = self._build_engine()
        self.context = self.engine.create_execution_context()
        # Allocate GPU memory
        self._allocate_buffers()

    def _build_engine(self):
        """Build a TensorRT engine from the ONNX model."""
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)
        # Parse the ONNX model
        with open(self.onnx_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        # Configure the builder
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30   # 1 GB
        config.set_flag(trt.BuilderFlag.FP16)  # enable FP16
        # The ONNX model was exported with a dynamic batch axis, so an
        # optimization profile is required to build the engine
        profile = builder.create_optimization_profile()
        profile.set_shape('input',
                          (1, 3, 224, 224),                    # min
                          (self.max_batch_size, 3, 224, 224),  # opt
                          (self.max_batch_size, 3, 224, 224))  # max
        config.add_optimization_profile(profile)
        # Build and serialize the engine
        engine = builder.build_engine(network, config)
        with open(self.trt_path, 'wb') as f:
            f.write(engine.serialize())
        return engine

    def _allocate_buffers(self):
        """Allocate page-locked host and device memory buffers."""
        self.inputs = []
        self.outputs = []
        self.bindings = []
        for binding in self.engine:
            shape = self.engine.get_binding_shape(binding)
            if shape[0] == -1:  # resolve the dynamic batch dimension
                shape = (self.max_batch_size,) + tuple(shape[1:])
            size = trt.volume(shape)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            # Allocate host and device memory
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self.bindings.append(int(device_mem))
            if self.engine.binding_is_input(binding):
                self.inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self.outputs.append({'host': host_mem, 'device': device_mem})

    def infer(self, input_data):
        """Run TensorRT inference on a numpy array."""
        # Bind the concrete input shape for this call
        self.context.set_binding_shape(0, input_data.shape)
        # Copy the input to the GPU
        np.copyto(self.inputs[0]['host'][:input_data.size], input_data.ravel())
        cuda.memcpy_htod(self.inputs[0]['device'], self.inputs[0]['host'])
        # Execute inference
        self.context.execute_v2(bindings=self.bindings)
        # Copy the output back to the CPU
        cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device'])
        return self.outputs[0]['host']
```
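A usage sketch under the same assumptions (the `model.onnx` exported earlier, a CUDA-capable GPU, and the tensorrt/pycuda packages installed):

```python
# Build the engine from the exported ONNX model and run one inference.
trt_model = TensorRTInference('model.onnx', max_batch_size=1)

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = trt_model.infer(input_data)
# The host buffer is flat; reshape to (batch, num_classes) as needed.
print(f"TensorRT output size: {output.shape}")
```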
## Edge Device Deployment

### 1. Mobile Deployment (PyTorch Mobile)
```python
# Optimize a model for mobile deployment (the helper is renamed to avoid
# shadowing torch's own optimize_for_mobile)
def prepare_mobile_model(model, example_input):
    """Optimize a model for mobile deployment."""
    model.eval()
    # Trace the model
    traced_model = torch.jit.trace(model, example_input)
    # Apply mobile-specific optimizations
    from torch.utils.mobile_optimizer import optimize_for_mobile
    optimized_model = optimize_for_mobile(traced_model)
    # Save for the lite interpreter
    optimized_model._save_for_lite_interpreter("model_mobile.ptl")
    return optimized_model

# Example usage
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
mobile_model = prepare_mobile_model(model, example_input)
```
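To see what this buys on disk, compare the regular TorchScript file with the lite-interpreter file; a quick sketch assuming both files from the sections above exist:

```python
import os

# Compare on-disk sizes of the regular TorchScript model and the mobile model.
for path in ['model_traced.pt', 'model_mobile.ptl']:
    size_mb = os.path.getsize(path) / 1024 / 1024
    print(f"{path}: {size_mb:.2f} MB")
```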
### 2. Quantized Deployment

```python
def quantize_model_for_deployment(model, calibration_loader):
    """Post-training static quantization for deployment."""
    # Set the quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # Insert observers
    model_prepared = torch.quantization.prepare(model, inplace=False)
    # Calibrate on representative data
    model_prepared.eval()
    with torch.no_grad():
        for data, _ in calibration_loader:
            model_prepared(data)
    # Convert to a quantized model
    model_quantized = torch.quantization.convert(model_prepared, inplace=False)
    return model_quantized

# Dynamic quantization (a simpler alternative)
def dynamic_quantize_model(model):
    """Dynamically quantize a model (weights only)."""
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model
```
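A quick sanity check on dynamic quantization: serialize both versions in memory and compare sizes. Note that the gain comes only from `nn.Linear` layers here, so a linear-heavy model (e.g. a transformer) shows it far more clearly than this small CNN:

```python
import io

def state_dict_size_mb(m):
    """Serialize a model's state dict in memory and report its size."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024 / 1024

model = ImageClassifier()
quantized = dynamic_quantize_model(model)
print(f"FP32 model:      {state_dict_size_mb(model):.2f} MB")
print(f"Quantized model: {state_dict_size_mb(quantized):.2f} MB")
```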
## Monitoring and Logging

### 1. Performance Monitoring
```python
import time
import psutil
import logging
from functools import wraps

class PerformanceMonitor:
    def __init__(self):
        self.metrics = {
            'request_count': 0,
            'total_inference_time': 0,
            'average_inference_time': 0,
            'memory_usage': 0,
            'cpu_usage': 0
        }

    def log_inference_time(self, func):
        """Decorator: record inference time."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            inference_time = end_time - start_time
            self.metrics['request_count'] += 1
            self.metrics['total_inference_time'] += inference_time
            self.metrics['average_inference_time'] = (
                self.metrics['total_inference_time'] / self.metrics['request_count']
            )
            logging.info(f"Inference time: {inference_time:.4f}s")
            return result
        return wrapper

    def update_system_metrics(self):
        """Refresh system-level metrics."""
        self.metrics['memory_usage'] = psutil.virtual_memory().percent
        self.metrics['cpu_usage'] = psutil.cpu_percent()

    def get_metrics(self):
        """Return all metrics."""
        self.update_system_metrics()
        return self.metrics

# Example usage
monitor = PerformanceMonitor()

@monitor.log_inference_time
def model_inference(input_data):
    with torch.no_grad():
        return model(input_data)
```
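These metrics are easy to expose over HTTP; a minimal sketch that attaches a JSON endpoint to the FastAPI `app` from the earlier section:

```python
# Expose the collected metrics as a JSON endpoint on the FastAPI app.
@app.get("/metrics")
async def get_metrics():
    return monitor.get_metrics()
```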
### 2. Logging Configuration

```python
import logging
import json
from datetime import datetime

class ModelLogger:
    def __init__(self, log_file='model_service.log'):
        self.logger = logging.getLogger('ModelService')
        self.logger.setLevel(logging.INFO)
        # File handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        # Formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def log_prediction(self, input_info, prediction, confidence, inference_time):
        """Log a completed prediction."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'input_info': input_info,
            'prediction': prediction,
            'confidence': confidence,
            'inference_time': inference_time
        }
        self.logger.info(f"Prediction complete: {json.dumps(log_data, ensure_ascii=False)}")

    def log_error(self, error_msg, input_info=None):
        """Log a prediction error."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'error': error_msg,
            'input_info': input_info
        }
        self.logger.error(f"Prediction error: {json.dumps(log_data, ensure_ascii=False)}")
```
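Usage is straightforward; the values below are placeholders:

```python
# Example usage with placeholder values.
model_logger = ModelLogger()
model_logger.log_prediction(
    input_info={'filename': 'test.jpg', 'size': (224, 224)},
    prediction='cat',
    confidence=0.93,
    inference_time=0.042
)
model_logger.log_error('Invalid image format', input_info={'filename': 'bad.bin'})
```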
## Summary

PyTorch model deployment covers the full path from development to production:
- Model serialization: format conversion with TorchScript, ONNX, and others
- Web services: API serving with frameworks such as Flask and FastAPI
- Containerization: deployment with Docker and Kubernetes
- Performance optimization: batched inference, TensorRT, and other acceleration techniques
- Edge deployment: lightweight options such as mobile builds and quantization
- Monitoring and logging: performance monitoring and a solid logging setup

Mastering these techniques will help you bring PyTorch models into production successfully!