Skip to content

PyTorch 模型优化

优化概述

模型优化是深度学习项目中的关键环节，涉及训练效率、推理速度、内存使用和模型大小等多个方面。本章将介绍 PyTorch 中的各种优化技术。

训练优化

1. 混合精度训练

python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

class MixedPrecisionTrainer:
    """Run single optimization steps under automatic mixed precision (AMP).

    The forward pass runs inside ``autocast()`` and the loss is scaled by a
    ``GradScaler`` before backpropagation. Gradients are unscaled before
    clipping so that ``max_norm`` applies to the true gradient magnitudes.

    Args:
        model: module to train.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion(output, target) -> loss``.
        max_norm: gradient-norm clipping threshold (default 1.0, the value
            that used to be hard-coded).
    """

    def __init__(self, model, optimizer, criterion, max_norm=1.0):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.max_norm = max_norm
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Perform one mixed-precision training step; return the loss as a float."""
        self.optimizer.zero_grad()

        # Forward pass under autocast (reduced precision where autocast allows it).
        with autocast():
            output = self.model(data)
            loss = self.criterion(output, target)

        # Backpropagate the scaled loss.
        self.scaler.scale(loss).backward()

        # Unscale first; otherwise clipping would act on the scaled gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_norm)

        # Optimizer step through the scaler, then update the scale factor.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item()

# Usage example
# NOTE(review): `dataloader` is not defined in this snippet; it must come
# from the surrounding code.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

trainer = MixedPrecisionTrainer(model, optimizer, criterion)

# Training loop: one mixed-precision step per batch
for data, target in dataloader:
    loss = trainer.train_step(data, target)

2. 梯度累积

python
class GradientAccumulator:
    """Accumulate gradients over several micro-batches before each optimizer step.

    Each micro-batch loss is divided by ``accumulation_steps`` so the summed
    gradient equals the gradient of the mean loss over the effective batch.

    Args:
        model, optimizer, criterion: as in an ordinary training loop.
        accumulation_steps: micro-batches per optimizer update.
        max_norm: gradient-norm clipping threshold applied before every
            update (default 1.0, the value that used to be hard-coded).
    """

    def __init__(self, model, optimizer, criterion, accumulation_steps=4, max_norm=1.0):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.accumulation_steps = accumulation_steps
        self.max_norm = max_norm
        self.step_count = 0

    def accumulate_step(self, data, target):
        """Forward/backward one micro-batch; step the optimizer every N calls.

        Returns:
            The (unscaled) loss of this micro-batch as a float.
        """
        output = self.model(data)
        # Scale so that N accumulated backward passes average rather than sum.
        loss = self.criterion(output, target) / self.accumulation_steps
        loss.backward()

        self.step_count += 1
        if self.step_count % self.accumulation_steps == 0:
            self._apply_update()

        return loss.item() * self.accumulation_steps

    def flush(self):
        """Apply gradients left over from an incomplete accumulation cycle.

        Call at the end of an epoch when the number of batches is not a
        multiple of ``accumulation_steps``. Otherwise those gradients silently
        leak into the first update of the next epoch. The partial gradient is
        applied as accumulated (i.e. still divided by ``accumulation_steps``).

        Returns:
            True if an update was applied, False if nothing was pending.
        """
        if self.step_count % self.accumulation_steps == 0:
            return False
        self._apply_update()
        # Restart the cycle so the next update again spans a full cycle.
        self.step_count = 0
        return True

    def _apply_update(self):
        """Clip gradients, step the optimizer and clear the gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()

# Usage example: one optimizer update every 8 micro-batches.
# NOTE(review): `dataloader` is not defined in this snippet. If the number of
# batches is not a multiple of 8, the gradients from the last incomplete
# cycle are not applied by this loop.
accumulator = GradientAccumulator(model, optimizer, criterion, accumulation_steps=8)

for data, target in dataloader:
    loss = accumulator.accumulate_step(data, target)

3. 学习率调度优化

python
import torch.optim.lr_scheduler as lr_scheduler
import math

class CosineAnnealingWarmRestarts(lr_scheduler._LRScheduler):
    """Cosine annealing with warm restarts, preceded by a linear warmup.

    For the first ``warmup_steps`` scheduler steps the learning rate rises
    linearly from ``base_lr / warmup_steps`` to ``base_lr``. Afterwards it
    follows a cosine from ``base_lr`` down towards ``eta_min`` over cycles
    whose first length is ``T_0`` steps. Each later cycle is ``T_mult``
    times longer than the previous one (``T_mult=1``: all cycles equal).

    Note: this shadows ``torch.optim.lr_scheduler.CosineAnnealingWarmRestarts``,
    which has no warmup phase.

    Raises:
        ValueError: if ``T_0 <= 0`` or ``T_mult < 1``.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, warmup_steps=0, last_epoch=-1):
        if T_0 <= 0:
            raise ValueError(f"Expected positive T_0, but got {T_0}")
        if T_mult < 1:
            raise ValueError(f"Expected T_mult >= 1, but got {T_mult}")
        # These must be set before super().__init__, which already calls get_lr().
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def _cycle_position(self, epoch):
        """Return ``(T_cur, T_i)``: position inside the current cycle and its length."""
        if self.T_mult == 1:
            return epoch % self.T_0, self.T_0
        # Walk through cycles of growing length. Exact arithmetic avoids
        # floating-point log rounding exactly at cycle boundaries.
        T_cur, T_i = epoch, self.T_0
        while T_cur >= T_i:
            T_cur -= T_i
            T_i *= self.T_mult
        return T_cur, T_i

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Warmup phase: linear ramp up to base_lr.
            return [base_lr * (self.last_epoch + 1) / self.warmup_steps
                    for base_lr in self.base_lrs]
        # Cosine annealing phase, restarting at every cycle boundary.
        T_cur, T_i = self._cycle_position(self.last_epoch - self.warmup_steps)
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * T_cur / T_i)) / 2
                for base_lr in self.base_lrs]

# Usage example
scheduler = CosineAnnealingWarmRestarts(
    optimizer, T_0=50, T_mult=2, eta_min=1e-6, warmup_steps=10
)

# Inside the training loop.
# NOTE(review): `num_epochs` and `dataloader` are not defined in this snippet.
# The scheduler is stepped once per epoch here, so warmup_steps and T_0 are
# counted in epochs.
for epoch in range(num_epochs):
    for data, target in dataloader:
        # training step
        pass
    scheduler.step()

内存优化

1. 梯度检查点

python
import torch.utils.checkpoint as checkpoint

class CheckpointedModel(nn.Module):
    """Run a stack of layers with activation (gradient) checkpointing.

    In training mode each layer's intermediate activations are not kept
    during forward; they are recomputed during backward, trading compute
    for memory. In eval mode, or under ``torch.no_grad()``, the layers run
    normally because no backward pass will follow.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        use_checkpoint = self.training and torch.is_grad_enabled()
        for layer in self.layers:
            if use_checkpoint:
                # use_reentrant=False: the reentrant variant yields a
                # non-differentiable output (so parameters get no gradient)
                # when no *input* requires grad, e.g. raw data fed to the
                # first layer.
                x = checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

# Alternatively, a helper function (not a decorator) that checkpoints one module.
def checkpointed_forward(module, input):
    """Return ``module(input)`` computed under activation checkpointing.

    ``use_reentrant=False`` keeps the output differentiable, so parameter
    gradients still flow even when ``input`` itself does not require grad.
    """
    return checkpoint.checkpoint(module, input, use_reentrant=False)

# Usage inside a larger model
class LargeTransformer(nn.Module):
    """Stack of ``config.num_layers`` TransformerLayer blocks, each checkpointed in training.

    NOTE(review): TransformerLayer is not defined in this file; this assumes
    it maps a single tensor to a single tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(config.num_layers)
        ])

    def forward(self, x):
        # Only checkpoint when a backward pass will follow.
        use_checkpoint = self.training and torch.is_grad_enabled()
        for layer in self.layers:
            if use_checkpoint:
                # Non-reentrant checkpointing keeps parameter gradients even
                # when the input does not require grad.
                x = checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

2. 内存映射数据集

python
import mmap
import numpy as np
from torch.utils.data import Dataset

class MemoryMappedDataset(Dataset):
    """Dataset of variable-length binary records read from a memory-mapped file.

    Args:
        data_file: path of the raw data file, mapped read-only.
        index_file: path of an array loaded with ``np.load``. Row ``i`` holds
            ``(offset, size)`` of record ``i`` in bytes.
    """

    def __init__(self, data_file, index_file):
        # Map the file instead of reading it into memory.
        self.data_file = open(data_file, 'rb')
        try:
            self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)
        except Exception:
            # e.g. an empty file cannot be mapped; don't leak the handle.
            self.data_file.close()
            raise

        # Load the (offset, size) index.
        self.indices = np.load(index_file)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        offset, size = (int(v) for v in self.indices[idx])
        if offset < 0 or size < 0 or offset + size > len(self.data_mmap):
            raise ValueError(
                f"record {idx} (offset={offset}, size={size}) is outside the "
                f"mapped file of {len(self.data_mmap)} bytes"
            )

        # Slice instead of seek()+read(): slicing does not move the shared
        # file position, so concurrent readers of this object cannot
        # interleave and read each other's bytes.
        data_bytes = self.data_mmap[offset:offset + size]

        return self._parse_data(data_bytes)

    def _parse_data(self, data_bytes):
        """Decode one record. Placeholder: override with real parsing (returns None)."""
        pass

    def __del__(self):
        # Close the mapping before the file it maps; attributes may be
        # missing if __init__ failed part-way.
        mapped = getattr(self, 'data_mmap', None)
        if mapped is not None:
            mapped.close()
        handle = getattr(self, 'data_file', None)
        if handle is not None:
            handle.close()

3. 动态批量大小

python
class DynamicBatchSampler:
    """Group consecutive dataset indices into batches bounded by a token budget.

    A batch is emitted when adding the next sample would push its total
    length above ``max_tokens``, or when it already holds
    ``max_batch_size`` samples. A single sample longer than ``max_tokens``
    still forms its own batch.

    Args:
        dataset: indexable dataset; ``len(dataset[i])`` is the length of sample ``i``.
        max_tokens: total-length budget per batch.
        max_batch_size: maximum number of samples per batch.
        lengths: optional precomputed per-sample lengths. When omitted they
            are measured once (on first use) and cached, instead of loading
            every sample again on every epoch.
    """

    def __init__(self, dataset, max_tokens=4096, max_batch_size=32, lengths=None):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size
        self._lengths = list(lengths) if lengths is not None else None

    def _sample_lengths(self):
        """Return (computing and caching on first call) the length of every sample."""
        if self._lengths is None:
            self._lengths = [len(self.dataset[i]) for i in range(len(self.dataset))]
        return self._lengths

    def __iter__(self):
        batch = []
        current_tokens = 0

        for idx, sample_length in enumerate(self._sample_lengths()):
            # Close the current batch if this sample would overflow it.
            if batch and (current_tokens + sample_length > self.max_tokens
                          or len(batch) >= self.max_batch_size):
                yield batch
                batch = []
                current_tokens = 0

            batch.append(idx)
            current_tokens += sample_length

        if batch:
            yield batch

    def __len__(self):
        """Number of batches per pass (lets ``len(DataLoader)`` work)."""
        return sum(1 for _ in self)

# Use the dynamic batch sampler. It yields whole lists of indices, so it is
# passed as `batch_sampler` (not `sampler`).
# NOTE(review): `dataset` is not defined in this snippet.
sampler = DynamicBatchSampler(dataset, max_tokens=4096)
dataloader = DataLoader(dataset, batch_sampler=sampler)

计算优化

1. 模型编译 (PyTorch 2.0+)

python
import torch._dynamo as dynamo

# 编译模型以获得更好的性能
@torch.compile
class OptimizedModel(nn.Module):
    def __init__(self, config):
        super(OptimizedModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(config.input_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.output_size)
        )
    
    def forward(self, x):
        return self.layers(x)

# Or compile an existing model instance.
# NOTE(review): `MyModel` is a placeholder not defined in this file.
model = MyModel()
compiled_model = torch.compile(model, mode='max-autotune')

# Compilation modes:
# 'default': balances compile time and runtime performance
# 'reduce-overhead': reduces Python overhead
# 'max-autotune': maximizes performance optimization

2. 算子融合

python
# Group Linear + ReLU into one module.
# NOTE: forward() still issues two separate ops (linear, then relu); actual
# kernel fusion needs a compiler such as torch.compile or TorchScript.
class FusedLinearReLU(nn.Module):
    """Compute ``relu(linear(x))`` in a single module."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        projected = self.linear(x)
        return projected.relu()

# Conv -> BN -> ReLU: a pattern TorchScript can optimize once the model is
# scripted and frozen (see the snippet below).
class ModelForFusion(nn.Module):
    """Conv2d(3 -> 64, kernel 3, padding 1) followed by BatchNorm2d(64) and ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

# Script the model so TorchScript optimization passes can run on it.
model = ModelForFusion()
scripted_model = torch.jit.script(model)

# Freeze the model for inference; switch to eval mode first.
scripted_model.eval()
frozen_model = torch.jit.freeze(scripted_model)

3. 并行计算优化

python
# Data parallelism: single process driving multiple GPUs.
# NOTE(review): nn.DataParallel and the DDP wrapper below are alternatives;
# a DataParallel-wrapped model should not also be wrapped in DDP.
model = nn.DataParallel(model)

# 分布式数据并行
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed(backend='nccl'):
    """Initialise the default process group for a torchrun-style launch.

    Reads ``LOCAL_RANK`` from the environment. ``init_process_group`` uses
    its default init method, so the usual launcher variables must be set.

    Args:
        backend: process-group backend; 'nccl' (default) for GPUs, 'gloo'
            for CPU-only runs.

    Returns:
        The local rank (int), e.g. for ``DDP(..., device_ids=[local_rank])``.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(backend=backend)
    # Pin this process to its GPU only when GPUs exist (gloo runs may be CPU-only).
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    return local_rank

def cleanup_distributed():
    """Destroy the default process group if one is active.

    Safe to call more than once, or when distributed training was never set
    up (e.g. from a ``finally`` block).
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# Use DDP: one process per GPU, launched with torchrun (which sets LOCAL_RANK);
# call setup_distributed() before this point.
# DDP must wrap the bare module, so drop a DataParallel wrapper if present.
local_rank = int(os.environ['LOCAL_RANK'])
if isinstance(model, nn.DataParallel):
    model = model.module
model = DDP(model.to(local_rank), device_ids=[local_rank])

# 模型并行(对于大型模型)
class ModelParallelNet(nn.Module):
    """Three-layer MLP split across two devices (simple model parallelism).

    ``layer1`` lives on ``devices[0]``; ``layer2`` and ``layer3`` live on
    ``devices[1]``. forward() moves activations between the two devices.

    Args:
        devices: pair of devices; defaults to ('cuda:0', 'cuda:1') as before.
    """

    def __init__(self, devices=('cuda:0', 'cuda:1')):
        super().__init__()
        self.devices = tuple(devices)
        first, second = self.devices
        self.layer1 = nn.Linear(1000, 1000).to(first)
        self.layer2 = nn.Linear(1000, 1000).to(second)
        self.layer3 = nn.Linear(1000, 10).to(second)

    def forward(self, x):
        first, second = self.devices
        x = self.layer1(x.to(first))
        # Hand the activations over to the second device.
        x = self.layer2(x.to(second))
        return self.layer3(x)

模型压缩

1. 知识蒸馏

python
class KnowledgeDistillation:
    """Train a student model to imitate a frozen teacher model.

    The loss mixes a soft-target term (KL divergence between the teacher's
    and student's temperature-softened distributions, scaled by T**2)
    weighted by ``alpha`` with the ordinary cross-entropy against the true
    labels, weighted by ``1 - alpha``.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha

        # The teacher is never updated: switch off gradients for all its parameters.
        self.teacher_model.requires_grad_(False)

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        """Return the combined soft/hard distillation loss (scalar tensor)."""
        T = self.temperature
        student_log_probs = F.log_softmax(student_logits / T, dim=1)
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        # Multiplying by T**2 keeps the soft-target gradient scale comparable
        # to the hard-label term as T changes.
        soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T ** 2)

        hard_loss = F.cross_entropy(student_logits, true_labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, data, target, optimizer):
        """Run one distillation update of the student; return the loss as a float."""
        self.teacher_model.eval()
        self.student_model.train()

        optimizer.zero_grad()

        # Teacher predictions serve as fixed targets: no graph needed.
        with torch.no_grad():
            teacher_logits = self.teacher_model(data)
        student_logits = self.student_model(data)

        loss = self.distillation_loss(student_logits, teacher_logits, target)
        loss.backward()
        optimizer.step()

        return loss.item()

# Usage example
# NOTE(review): LargeModel, SmallModel and dataloader are placeholders that
# are not defined in this file.
teacher = LargeModel()  # large teacher model
student = SmallModel()  # small student model

distiller = KnowledgeDistillation(teacher, student)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for data, target in dataloader:
    loss = distiller.train_step(data, target, optimizer)

2. 模型剪枝

python
import torch.nn.utils.prune as prune

class ModelPruner:
    """Prune the Conv2d/Linear weights of a model with ``torch.nn.utils.prune``.

    Pruning is applied as a reparametrization: while active, each pruned
    module holds ``weight_orig`` (a parameter) and ``weight_mask`` (a buffer),
    and ``weight`` is their product. :meth:`remove_pruning` makes it permanent.
    """

    # Layer types whose ``weight`` is pruned.
    _PRUNABLE = (nn.Conv2d, nn.Linear)

    def __init__(self, model):
        self.model = model

    def _prunable_modules(self):
        """Return the Conv2d/Linear submodules of the model, in module order."""
        return [m for m in self.model.modules() if isinstance(m, self._PRUNABLE)]

    def structured_pruning(self, pruning_ratio=0.2):
        """Structured L2 pruning of whole output channels/neurons (dim 0) in every layer."""
        for module in self._prunable_modules():
            prune.ln_structured(
                module, name='weight', amount=pruning_ratio,
                n=2, dim=0  # dim 0 = output channels (Conv2d) / neurons (Linear)
            )

    def unstructured_pruning(self, pruning_ratio=0.2):
        """Global L1 unstructured pruning across all Conv2d/Linear weights.

        For a float ratio, it is a fraction of the weights that are still
        unpruned.
        """
        parameters_to_prune = [(m, 'weight') for m in self._prunable_modules()]
        if not parameters_to_prune:
            return  # nothing to prune; global_unstructured cannot handle an empty list

        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_ratio,
        )

    def _weight_pruned_fraction(self):
        """Fraction of prunable weight entries currently masked out (0.0 if none)."""
        total = masked = 0
        for module in self._prunable_modules():
            total += module.weight.numel()
            mask = getattr(module, 'weight_mask', None)
            if mask is not None:
                masked += mask.numel() - int(mask.sum().item())
        return masked / total if total else 0.0

    def gradual_pruning(self, initial_sparsity=0.0, final_sparsity=0.8,
                        pruning_steps=100, pruning_frequency=10):
        """Prune in several rounds, ramping weight sparsity linearly to ``final_sparsity``.

        One round happens every ``pruning_frequency`` of the ``pruning_steps``
        schedule steps. Round targets go linearly from ``initial_sparsity``
        (first round) to ``final_sparsity`` (last round). The pruning amount
        is relative to the weights that are still unpruned, so each round
        converts its cumulative target into that relative amount.

        Raises:
            ValueError: if ``pruning_frequency`` is not positive.
        """
        if pruning_frequency <= 0:
            raise ValueError("pruning_frequency must be positive")

        num_rounds = len(range(0, pruning_steps, pruning_frequency))
        for k in range(num_rounds):
            progress = k / (num_rounds - 1) if num_rounds > 1 else 1.0
            target = initial_sparsity + (final_sparsity - initial_sparsity) * progress
            # Measure what is actually pruned so rounding errors don't accumulate.
            current = self._weight_pruned_fraction()
            if target <= current or current >= 1.0:
                continue
            self.unstructured_pruning((target - current) / (1.0 - current))

    def remove_pruning(self):
        """Make pruning permanent by removing the reparametrization (skips unpruned layers)."""
        for module in self._prunable_modules():
            try:
                prune.remove(module, 'weight')
            except ValueError:
                # prune.remove raises ValueError when the layer was never pruned.
                pass

    def calculate_sparsity(self):
        """Return the fraction of zero entries over all (distinct) model parameters.

        While pruning is active, the effective masked tensor (``<name>``) is
        counted instead of the unmasked ``<name>_orig`` parameter, so the
        result reflects pruning both before and after :meth:`remove_pruning`.
        Returns 0.0 for a model without parameters.
        """
        total_params = 0
        zero_params = 0
        seen = set()  # shared (tied) parameters are counted once

        for module in self.model.modules():
            for name, param in module.named_parameters(recurse=False):
                if id(param) in seen:
                    continue
                seen.add(id(param))
                base = name[:-len('_orig')]
                if name.endswith('_orig') and hasattr(module, base + '_mask'):
                    tensor = getattr(module, base)
                else:
                    tensor = param
                total_params += tensor.numel()
                zero_params += (tensor == 0).sum().item()

        return zero_params / total_params if total_params else 0.0

# Usage example
pruner = ModelPruner(model)

# Apply pruning
pruner.unstructured_pruning(pruning_ratio=0.3)

# Compute sparsity
sparsity = pruner.calculate_sparsity()
print(f"模型稀疏度: {sparsity:.2%}")

# Fine-tune the pruned model
# NOTE(review): `fine_tune_epochs` is not defined in this file.
for epoch in range(fine_tune_epochs):
    # training loop
    pass

# Remove the pruning reparametrization (makes the pruned weights permanent)
pruner.remove_pruning()

3. 量化

python
import torch.quantization as quantization

class ModelQuantizer:
    """Produce int8-quantized copies of a model: dynamic, post-training static, or QAT.

    NOTE(review): static and QAT quantization presumably expect the model to
    contain QuantStub/DeQuantStub at its float/quantized boundaries; this
    class does not check that.
    """

    def __init__(self, model):
        self.model = model

    def post_training_quantization(self, calibration_loader, backend='fbgemm'):
        """Post-training static quantization.

        Args:
            calibration_loader: iterable of ``(data, _)`` batches run through
                the prepared model so its observers record activation ranges.
            backend: qconfig backend name passed to ``get_default_qconfig``.

        Returns:
            A quantized copy. Sets ``self.model.qconfig`` as a side effect.
        """
        self.model.qconfig = quantization.get_default_qconfig(backend)

        # Insert observers into a copy of the model.
        model_prepared = quantization.prepare(self.model, inplace=False)

        # Calibrate.
        model_prepared.eval()
        with torch.no_grad():
            for data, _ in calibration_loader:
                model_prepared(data)

        return quantization.convert(model_prepared, inplace=False)

    def quantization_aware_training(self, train_loader, num_epochs=5, backend='fbgemm', lr=0.0001):
        """Quantization-aware training (QAT), then conversion to a quantized model.

        Args:
            train_loader: iterable of ``(data, target)`` batches.
            num_epochs: passes over ``train_loader``.
            backend: qconfig backend name passed to ``get_default_qat_qconfig``.
            lr: Adam learning rate (default 1e-4, the value previously hard-coded).

        Returns:
            A quantized copy. Sets ``self.model.qconfig`` as a side effect;
            the train/eval mode of ``self.model`` is left as it was.
        """
        self.model.qconfig = quantization.get_default_qat_qconfig(backend)

        # prepare_qat only accepts a model in training mode. Switch temporarily
        # so an eval-mode model doesn't fail, then restore the caller's mode.
        was_training = self.model.training
        self.model.train()
        try:
            model_prepared = quantization.prepare_qat(self.model, inplace=False)
        finally:
            self.model.train(was_training)

        optimizer = torch.optim.Adam(model_prepared.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in range(num_epochs):
            model_prepared.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model_prepared(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        # Convert the fake-quantized model into a real quantized model.
        model_prepared.eval()
        return quantization.convert(model_prepared, inplace=False)

    def dynamic_quantization(self):
        """Dynamic quantization of all nn.Linear layers to qint8 weights."""
        return quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )

# Usage example
quantizer = ModelQuantizer(model)

# Dynamic quantization (simplest)
quantized_model = quantizer.dynamic_quantization()

# Post-training static quantization
# quantized_model = quantizer.post_training_quantization(calibration_loader)

# Quantization-aware training
# quantized_model = quantizer.quantization_aware_training(train_loader)

# Compare model sizes
def get_model_size(model):
    """Return the size in bytes of ``model.state_dict()`` serialized with ``torch.save``.

    Serializes into an in-memory buffer instead of a fixed ``temp.p`` file in
    the working directory. Concurrent calls therefore cannot clobber each
    other's file, and nothing is left on disk if saving fails. Sizes can
    differ by a few bytes from an on-disk save because torch.save embeds an
    archive name.
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes

# Serialized state_dict sizes before/after quantization, reported in MB.
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"原始模型大小: {original_size / 1024 / 1024:.2f} MB")
print(f"量化模型大小: {quantized_size / 1024 / 1024:.2f} MB")
print(f"压缩比: {original_size / quantized_size:.2f}x")

推理优化

1. TorchScript优化

python
# Optimize a model for inference with TorchScript
def optimize_torchscript_model(model, example_input):
    """Trace, freeze and inference-optimize ``model``.

    Args:
        model: nn.Module; switched to eval mode in place before tracing.
        example_input: input used by ``torch.jit.trace`` to record the graph.

    Returns:
        A frozen, inference-optimized ScriptModule.
    """
    model.eval()

    traced_model = torch.jit.trace(model, example_input)

    # Freeze first, then optimize (documented order). optimize_for_inference
    # returns a frozen module, so freezing its result again, as before, was
    # redundant.
    frozen_model = torch.jit.freeze(traced_model)
    return torch.jit.optimize_for_inference(frozen_model)

# Usage example
# NOTE(review): `model` must accept an input shaped like example_input (1, 3, 224, 224).
example_input = torch.randn(1, 3, 224, 224)
optimized_model = optimize_torchscript_model(model, example_input)

# Save the optimized model
optimized_model.save('optimized_model.pt')

2. 批量推理优化

python
class BatchInferenceOptimizer:
    """Collect concurrent ``predict`` calls into batched model invocations.

    A batch runs as soon as ``max_batch_size`` requests are queued, or
    ``timeout`` seconds after a partial batch started waiting. Inputs are
    combined with ``torch.stack`` (so they must share a shape), and each
    caller receives its own row of the batched output.

    Note: the model runs synchronously inside the event loop and blocks it
    for the duration of each forward pass.
    """

    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        # Pending (input, asyncio.Future) pairs awaiting the next batch.
        self.batch_queue = []
        # The single pending timeout task. Keeping a reference stops it from
        # being garbage-collected while it sleeps.
        self._timer = None

    async def predict(self, input_data):
        """Queue one input and return its model output once its batch has run."""
        import asyncio

        future = asyncio.get_running_loop().create_future()
        self.batch_queue.append((input_data, future))

        if len(self.batch_queue) >= self.max_batch_size:
            await self._process_batch()
        elif self._timer is None or self._timer.done():
            # One timer per partial batch rather than one task per request.
            self._timer = asyncio.create_task(self._timeout_handler())

        return await future

    async def _process_batch(self):
        """Run the model on every queued input and resolve the matching futures."""
        import asyncio

        if not self.batch_queue:
            return

        # A pending timer belongs to the batch being flushed now; cancel it
        # unless it is the task running this very flush.
        if self._timer is not None and self._timer is not asyncio.current_task():
            self._timer.cancel()
        self._timer = None

        pending = list(self.batch_queue)
        self.batch_queue.clear()
        futures = [future for _, future in pending]

        try:
            batch_input = torch.stack([data for data, _ in pending])
            with torch.no_grad():
                batch_output = self.model(batch_input)

            for i, future in enumerate(futures):
                # A caller cancelled while waiting already has a done future;
                # resolving it again would raise and starve the other callers.
                if not future.done():
                    future.set_result(batch_output[i])
        except Exception as e:
            for future in futures:
                if not future.done():
                    future.set_exception(e)

    async def _timeout_handler(self):
        """Flush a partial batch once ``self.timeout`` seconds have passed."""
        import asyncio

        await asyncio.sleep(self.timeout)
        if self.batch_queue:
            await self._process_batch()

性能监控

1. 性能分析器

python
import torch.profiler as profiler

def profile_model(model, input_data, num_steps=100, trace_dir='./log/profiler'):
    """Profile repeated inference of ``model`` on ``input_data`` and print a summary.

    Uses the profiler schedule wait=1, warmup=1, active=3, repeat=2. Traces
    are written with ``tensorboard_trace_handler`` into ``trace_dir``. CUDA
    activity is recorded (and used as the sort key of the printed table)
    only when CUDA is available; otherwise CPU time is used.

    Returns:
        The profiler object.
    """
    model.eval()

    cuda_available = torch.cuda.is_available()
    activities = [profiler.ProfilerActivity.CPU]
    if cuda_available:
        activities.append(profiler.ProfilerActivity.CUDA)

    with profiler.profile(
        activities=activities,
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=profiler.tensorboard_trace_handler(trace_dir),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        with torch.no_grad():
            for _ in range(num_steps):
                model(input_data)
                prof.step()  # advance the profiler schedule

    # Print the performance report
    sort_key = "cuda_time_total" if cuda_available else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

    return prof

# Usage example (moves the model and a batch of 32 image-shaped inputs to CUDA;
# requires a CUDA device)
input_data = torch.randn(32, 3, 224, 224).cuda()
prof = profile_model(model.cuda(), input_data)

2. 内存分析

python
def analyze_memory_usage(model, input_data):
    """Measure CUDA memory allocated around one inference pass of ``model``.

    Statistics are read for the CUDA device that ``input_data`` lives on,
    rather than whatever device is current, so multi-GPU setups report the
    right GPU. Prints and returns initial, peak and final allocated bytes.

    Raises:
        ValueError: if ``input_data`` is not a CUDA tensor.
    """
    device = input_data.device
    if device.type != 'cuda':
        raise ValueError(f"analyze_memory_usage requires a CUDA tensor, got device '{device}'")

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    # Record the initial memory
    initial_memory = torch.cuda.memory_allocated(device)

    # Forward pass. `output` stays referenced until return, so it is
    # included in the final measurement.
    model.eval()
    with torch.no_grad():
        output = model(input_data)

    # Record peak and final memory
    peak_memory = torch.cuda.max_memory_allocated(device)
    final_memory = torch.cuda.memory_allocated(device)

    print(f"初始内存: {initial_memory / 1024**2:.2f} MB")
    print(f"峰值内存: {peak_memory / 1024**2:.2f} MB")
    print(f"最终内存: {final_memory / 1024**2:.2f} MB")
    print(f"内存增长: {(final_memory - initial_memory) / 1024**2:.2f} MB")

    return {
        'initial': initial_memory,
        'peak': peak_memory,
        'final': final_memory
    }

# Usage example (requires a CUDA device)
memory_stats = analyze_memory_usage(model.cuda(), input_data.cuda())

总结

PyTorch 模型优化涵盖了训练和推理的各个方面：

  1. 训练优化：混合精度、梯度累积、学习率调度
  2. 内存优化：梯度检查点、内存映射、动态批量
  3. 计算优化：模型编译、算子融合、并行计算
  4. 模型压缩：知识蒸馏、剪枝、量化
  5. 推理优化：TorchScript、批量推理
  6. 性能监控：性能分析、内存分析

掌握这些优化技术将显著提升你的深度学习项目的效率和性能!

本站内容仅供学习和研究使用。