Skip to content

PyTorch 模型训练与验证

训练流程概述

深度学习模型的训练是一个迭代优化过程,包括前向传播、损失计算、反向传播和参数更新。PyTorch提供了灵活的工具来实现这个过程。

python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# 基本训练循环结构
def basic_training_loop():
    """Skeleton of the canonical PyTorch training loop.

    Relies on module-level globals (`num_epochs`, `dataloader`, `model`,
    `criterion`, `optimizer`); intended purely as a structural illustration.
    """
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(dataloader):
            data, target = batch

            # Forward pass and loss computation.
            output = model(data)
            loss = criterion(output, target)

            # Reset gradients, backpropagate, then take an optimizer step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

完整训练框架

1. 训练函数

python
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train the model for a single epoch.

    Args:
        model: network to train (switched to train mode here).
        dataloader: iterable yielding (data, target) batches.
        criterion: loss function taking (output, target).
        optimizer: optimizer updating ``model``'s parameters.
        device: device the batches are moved to.
        epoch: zero-based epoch index (used for the progress-bar label).

    Returns:
        (epoch_loss, epoch_acc): average batch loss and accuracy in percent.
    """
    model.train()  # enable dropout / batch-norm training behavior

    running_loss = 0.0
    correct = 0
    total = 0

    # The progress bar is optional: fall back to the bare dataloader when
    # tqdm is not installed, so training does not hard-require it.
    try:
        from tqdm import tqdm
        pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
    except ImportError:
        pbar = dataloader

    for batch_idx, (data, target) in enumerate(pbar):
        # Move the batch to the target device.
        data, target = data.to(device), target.to(device)

        # Reset gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass.
        output = model(data)
        loss = criterion(output, target)

        # Backward pass.
        loss.backward()

        # Gradient clipping guards against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Parameter update.
        optimizer.step()

        # Running statistics.
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)

        # Update the progress bar only when one is actually in use.
        if hasattr(pbar, 'set_postfix'):
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

2. 验证函数

python
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model over a validation dataloader.

    Returns:
        (val_loss, val_acc): average batch loss and accuracy in percent.
    """
    model.eval()  # inference behavior for dropout / batch-norm

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            batch_loss = criterion(output, target)

            loss_sum += batch_loss.item()
            predictions = output.argmax(dim=1, keepdim=True)
            n_correct += predictions.eq(target.view_as(predictions)).sum().item()
            n_seen += target.size(0)

    return loss_sum / len(dataloader), 100. * n_correct / n_seen

3. 完整训练流程

python
class Trainer:
    """End-to-end training driver.

    Runs the epoch loop via the module-level ``train_epoch`` /
    ``validate_epoch`` helpers, optionally steps a learning-rate
    scheduler, records per-epoch loss/accuracy history, saves periodic
    checkpoints, and tracks (then reloads) the best model by validation
    accuracy.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, 
                 device, save_dir='./checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.save_dir = save_dir
        
        # Create the checkpoint directory (requires a module-level `import os`).
        os.makedirs(save_dir, exist_ok=True)
        
        # Per-epoch training history.
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        
        # Best-model tracking (by validation accuracy).
        self.best_val_acc = 0.0
        self.best_epoch = 0
    
    def train(self, num_epochs, scheduler=None, early_stopping=None):
        """Run the full training loop.

        Args:
            num_epochs: number of epochs to train.
            scheduler: optional LR scheduler; ReduceLROnPlateau is stepped
                with the validation loss, every other type unconditionally.
            early_stopping: optional callable(val_loss, model) -> bool that
                returns True when training should stop.

        Returns:
            (train_losses, train_accs, val_losses, val_accs) histories.
        """
        print(f"开始训练,共{num_epochs}个epoch")
        print(f"设备: {self.device}")
        print(f"训练集大小: {len(self.train_loader.dataset)}")
        print(f"验证集大小: {len(self.val_loader.dataset)}")
        print("-" * 50)
        
        for epoch in range(num_epochs):
            # Train for one epoch.
            train_loss, train_acc = train_epoch(
                self.model, self.train_loader, self.criterion, 
                self.optimizer, self.device, epoch
            )
            
            # Validate.
            val_loss, val_acc = validate_epoch(
                self.model, self.val_loader, self.criterion, self.device
            )
            
            # Update the learning rate; ReduceLROnPlateau needs the metric.
            if scheduler:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
            
            # Record history.
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)
            
            # Report epoch results.
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'  Learning Rate: {current_lr:.6f}')
            
            # Save the best model so far (by validation accuracy).
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f'  ✓ 新的最佳模型! 验证准确率: {val_acc:.2f}%')
            
            # Save a regular checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                self.save_checkpoint(epoch)
            
            # Early-stopping check.
            if early_stopping:
                if early_stopping(val_loss, self.model):
                    print(f'早停触发,在第{epoch+1}个epoch停止训练')
                    break
            
            print("-" * 50)
        
        print(f'训练完成! 最佳验证准确率: {self.best_val_acc:.2f}% (Epoch {self.best_epoch+1})')
        
        # Reload the best weights before returning.
        self.load_best_model()
        
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs
    
    def save_checkpoint(self, epoch, is_best=False):
        """Serialize model/optimizer state plus training history to disk."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'train_accs': self.train_accs,
            'val_losses': self.val_losses,
            'val_accs': self.val_accs,
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch
        }
        
        # Per-epoch checkpoint file.
        checkpoint_path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)
        
        # Additionally overwrite the rolling "best model" file.
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
    
    def load_best_model(self):
        """Load the best saved weights into the model, if a file exists."""
        best_path = os.path.join(self.save_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"已加载最佳模型 (Epoch {checkpoint['best_epoch']+1})")

训练技巧和优化

1. 学习率调度

python
from torch.optim.lr_scheduler import *

def create_scheduler(optimizer, scheduler_type='cosine', **kwargs):
    """Build a learning-rate scheduler for ``optimizer``.

    Supported types: 'step', 'multistep', 'cosine', 'plateau',
    'warmup_cosine'. Extra keyword arguments override the per-type
    defaults. Raises ValueError for an unknown type.
    """
    builders = {
        'step': lambda: StepLR(
            optimizer,
            step_size=kwargs.get('step_size', 30),
            gamma=kwargs.get('gamma', 0.1),
        ),
        'multistep': lambda: MultiStepLR(
            optimizer,
            milestones=kwargs.get('milestones', [30, 60, 90]),
            gamma=kwargs.get('gamma', 0.1),
        ),
        'cosine': lambda: CosineAnnealingLR(
            optimizer, T_max=kwargs.get('T_max', 100)
        ),
        'plateau': lambda: ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5,
            patience=kwargs.get('patience', 10),
        ),
        'warmup_cosine': lambda: CosineAnnealingWarmRestarts(
            optimizer, T_0=kwargs.get('T_0', 10)
        ),
    }

    if scheduler_type not in builders:
        raise ValueError(f"不支持的调度器类型: {scheduler_type}")
    return builders[scheduler_type]()

# Usage example.
scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

2. 早停机制

python
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Tracks the best validation loss seen so far; after ``patience``
    consecutive epochs without an improvement of at least ``min_delta``
    it signals the caller to stop and (optionally) restores the best
    weights into the model.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=True):
        self.patience = patience              # epochs to wait after the last improvement
        self.min_delta = min_delta            # minimum decrease that counts as improvement
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose

        self.best_loss = None                 # best validation loss so far
        self.counter = 0                      # epochs since the last improvement
        self.best_weights = None              # snapshot of the best state_dict

    def __call__(self, val_loss, model):
        """Record this epoch's validation loss; return True to stop training."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
            if self.verbose:
                print(f'验证损失改善到 {val_loss:.6f}')
        else:
            self.counter += 1
            if self.verbose:
                print(f'验证损失未改善 ({self.counter}/{self.patience})')

        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
                if self.verbose:
                    print('恢复最佳权重')
            return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the model's current weights as the best weights.

        BUG FIX: the previous ``model.state_dict().copy()`` was a shallow
        copy — the saved values were the live parameter tensors, which the
        optimizer keeps mutating in place, so the "best weights" silently
        tracked the current weights and restoring was a no-op. Cloning each
        tensor takes a real snapshot.
        """
        self.best_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

# Enable early stopping: stop after 10 epochs without >= 0.001 improvement.
early_stopping = EarlyStopping(patience=10, min_delta=0.001)

3. 梯度累积

python
def train_with_gradient_accumulation(model, dataloader, criterion, optimizer,
                                   device, accumulation_steps=4):
    """Train one epoch while accumulating gradients over several batches.

    Effectively multiplies the batch size by ``accumulation_steps`` without
    extra memory: the loss is scaled down, gradients accumulate across
    batches, and the optimizer steps once per group.

    Returns:
        Average (unscaled) batch loss over the epoch.
    """
    model.train()
    optimizer.zero_grad()

    running_loss = 0.0
    pending_steps = 0  # batches accumulated since the last optimizer step

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        # Forward pass.
        output = model(data)
        loss = criterion(output, target)

        # Scale the loss so accumulated gradients match a large-batch step.
        loss = loss / accumulation_steps

        # Backward pass (gradients accumulate across batches).
        loss.backward()

        running_loss += loss.item() * accumulation_steps
        pending_steps += 1

        # Step once per full accumulation group.
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Parameter update.
            optimizer.step()
            optimizer.zero_grad()
            pending_steps = 0

    # BUG FIX: flush the final partial group — previously the last
    # len(dataloader) % accumulation_steps batches never updated the model.
    if pending_steps:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return running_loss / len(dataloader)

4. 混合精度训练

python
from torch.cuda.amp import GradScaler, autocast

def train_with_mixed_precision(model, dataloader, criterion, optimizer, device):
    """One training epoch under automatic mixed precision (AMP).

    Forward passes run inside ``autocast`` (reduced precision where safe);
    ``GradScaler`` scales the loss to protect small gradients, then
    gradients are unscaled before clipping and stepping.

    Returns:
        Average batch loss over the epoch.
    """
    model.train()
    scaler = GradScaler()

    total_loss = 0.0

    for data, target in dataloader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        # Reduced-precision forward pass.
        with autocast():
            loss = criterion(model(data), target)

        # Scaled backward pass.
        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping by norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Optimizer step, then update the scale factor.
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(dataloader)

模型评估

1. 分类指标

python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np

def evaluate_classification(model, dataloader, device, num_classes):
    """Evaluate a classifier and report standard metrics.

    Runs the model over ``dataloader``, collects argmax predictions, and
    prints/returns accuracy, weighted precision/recall/F1, and the
    confusion matrix.

    Args:
        model: trained classifier producing per-class logits.
        dataloader: iterable of (data, target) batches.
        device: device for inference.
        num_classes: total number of classes; fixes the confusion-matrix size.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1',
        'confusion_matrix'.
    """
    model.eval()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)

            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Aggregate metrics (weighted by class support).
    accuracy = accuracy_score(all_targets, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_preds, average='weighted', zero_division=0
    )

    # BUG FIX: ``num_classes`` was accepted but never used, so the matrix
    # shrank whenever some classes were absent from the evaluation set.
    # Fixing the label set keeps it num_classes x num_classes.
    cm = confusion_matrix(all_targets, all_preds, labels=list(range(num_classes)))

    print(f"准确率: {accuracy:.4f}")
    print(f"精确率: {precision:.4f}")
    print(f"召回率: {recall:.4f}")
    print(f"F1分数: {f1:.4f}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm
    }

2. 可视化训练过程

python
import matplotlib.pyplot as plt

def plot_training_history(train_losses, train_accs, val_losses, val_accs):
    """Plot training/validation loss and accuracy curves side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, train series, val series, train label, val label, title, y-label)
    panels = [
        (axes[0], train_losses, val_losses, '训练损失', '验证损失', '损失曲线', 'Loss'),
        (axes[1], train_accs, val_accs, '训练准确率', '验证准确率', '准确率曲线', 'Accuracy (%)'),
    ]

    for ax, train_series, val_series, train_label, val_label, title, ylabel in panels:
        ax.plot(train_series, label=train_label, color='blue')
        ax.plot(val_series, label=val_label, color='red')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(cm, class_names):
    """Render a confusion matrix as an annotated heatmap."""
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('混淆矩阵')
    plt.colorbar()

    # Class names on both axes.
    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, rotation=45)
    plt.yticks(positions, class_names)

    # Annotate every cell, switching text color on dark cells for contrast.
    color_cutoff = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            cell_color = "white" if cm[row, col] > color_cutoff else "black"
            plt.text(col, row, format(cm[row, col], 'd'),
                     horizontalalignment="center",
                     color=cell_color)

    plt.ylabel('真实标签')
    plt.xlabel('预测标签')
    plt.tight_layout()
    plt.show()

实际应用示例

1. CIFAR-10图像分类

python
import torchvision
import torchvision.transforms as transforms

# Data preparation: augmentation for training, normalization only for test.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Per-channel CIFAR-10 mean / std.
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Downloads CIFAR-10 to ./data on first run.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                       download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                         shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                                      download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, 
                                        shuffle=False, num_workers=2)

# 模型定义
class SimpleCNN(nn.Module):
    """Small VGG-style CNN: three conv stages followed by an MLP head."""

    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()

        # Three conv -> BN -> ReLU -> pool stages, doubling the channel
        # count at each stage while halving spatial resolution.
        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ])
            in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Global average pool + two-layer classifier head.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return raw class logits for a batch of 3-channel images."""
        return self.classifier(self.features(x))

# Training setup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

# Build the trainer.
trainer = Trainer(model, trainloader, testloader, criterion, optimizer, device)

# Run training.
train_losses, train_accs, val_losses, val_accs = trainer.train(
    num_epochs=100, 
    scheduler=scheduler,
    early_stopping=EarlyStopping(patience=15)
)

# Visualize the training curves.
plot_training_history(train_losses, train_accs, val_losses, val_accs)

# Final evaluation on the test set.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']
metrics = evaluate_classification(model, testloader, device, 10)
plot_confusion_matrix(metrics['confusion_matrix'], class_names)

调试和故障排除

1. 常见问题诊断

python
def diagnose_training_issues(model, dataloader, criterion, optimizer, device):
    """Run one forward/backward pass and print diagnostics for common
    training problems (bad data ranges, abnormal outputs, exploding
    gradients, suspicious learning rate)."""
    model.train()

    # Inspect a single batch.
    data_batch, target_batch = next(iter(dataloader))
    print(f"数据形状: {data_batch.shape}")
    print(f"标签形状: {target_batch.shape}")
    print(f"数据范围: [{data_batch.min():.3f}, {data_batch.max():.3f}]")
    print(f"标签范围: [{target_batch.min()}, {target_batch.max()}]")

    # Model-output sanity check.
    data_batch = data_batch.to(device)
    output = model(data_batch)
    print(f"模型输出形状: {output.shape}")
    print(f"输出范围: [{output.min():.3f}, {output.max():.3f}]")

    # Loss sanity check.
    target_batch = target_batch.to(device)
    loss = criterion(output, target_batch)
    print(f"初始损失: {loss.item():.4f}")

    # Gradient inspection: global norm plus per-parameter warnings.
    loss.backward()
    squared_norm_sum = 0
    param_count = 0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        param_norm = param.grad.data.norm(2)
        squared_norm_sum += param_norm.item() ** 2
        param_count += 1
        if param_norm > 10:
            print(f"警告: {name} 梯度范数过大: {param_norm:.4f}")

    total_norm = squared_norm_sum ** (1. / 2)
    print(f"总梯度范数: {total_norm:.4f}")
    print(f"参数数量: {param_count}")

    # Learning-rate check.
    print(f"当前学习率: {optimizer.param_groups[0]['lr']:.6f}")

2. 性能监控

python
import time
import psutil
import GPUtil

class PerformanceMonitor:
    """Track per-epoch wall-clock time and report system resource usage.

    Args:
        total_epochs: planned number of epochs, used only for the
            remaining-time estimate. Defaults to 100, preserving the
            previously hard-coded behavior.
    """

    def __init__(self, total_epochs=100):
        self.total_epochs = total_epochs  # drives the ETA estimate
        self.start_time = None            # wall-clock start of current epoch
        self.epoch_times = []             # durations of finished epochs (s)

    def start_epoch(self):
        """Mark the beginning of an epoch."""
        self.start_time = time.time()

    def end_epoch(self):
        """Record and return the current epoch's duration.

        Returns 0 if ``start_epoch`` was never called.
        """
        if self.start_time:
            epoch_time = time.time() - self.start_time
            self.epoch_times.append(epoch_time)
            return epoch_time
        return 0

    def get_system_info(self):
        """Collect CPU, memory, and (best-effort) GPU usage statistics."""
        # CPU utilization in percent.
        cpu_percent = psutil.cpu_percent()

        # Memory usage.
        memory = psutil.virtual_memory()
        memory_percent = memory.percent
        memory_used = memory.used / (1024**3)  # GB

        # GPU stats are best-effort: querying can fail on machines without
        # NVIDIA GPUs or a working GPUtil installation.
        gpu_info = []
        try:
            gpus = GPUtil.getGPUs()
            for gpu in gpus:
                gpu_info.append({
                    'id': gpu.id,
                    'name': gpu.name,
                    'load': gpu.load * 100,
                    'memory_used': gpu.memoryUsed,
                    'memory_total': gpu.memoryTotal,
                    'temperature': gpu.temperature
                })
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt and SystemExit.
            pass

        return {
            'cpu_percent': cpu_percent,
            'memory_percent': memory_percent,
            'memory_used_gb': memory_used,
            'gpu_info': gpu_info
        }

    def print_performance_summary(self):
        """Print average epoch time, an ETA, and current resource usage."""
        if self.epoch_times:
            avg_time = sum(self.epoch_times) / len(self.epoch_times)
            print(f"平均每epoch时间: {avg_time:.2f}秒")
            # ETA based on the configured total epoch count (previously a
            # hard-coded 100); clamped so it never goes negative.
            remaining = max(self.total_epochs - len(self.epoch_times), 0)
            print(f"预计剩余时间: {avg_time * remaining:.2f}秒")

        system_info = self.get_system_info()
        print(f"CPU使用率: {system_info['cpu_percent']:.1f}%")
        print(f"内存使用率: {system_info['memory_percent']:.1f}%")

        for gpu in system_info['gpu_info']:
            print(f"GPU {gpu['id']} ({gpu['name']}): "
                  f"负载 {gpu['load']:.1f}%, "
                  f"显存 {gpu['memory_used']}/{gpu['memory_total']}MB, "
                  f"温度 {gpu['temperature']}°C")

# Using the performance monitor.
monitor = PerformanceMonitor()

# Inside the training loop:
for epoch in range(num_epochs):
    monitor.start_epoch()
    
    # ... training code ...
    
    epoch_time = monitor.end_epoch()
    print(f"Epoch {epoch+1} 用时: {epoch_time:.2f}秒")
    
    if epoch % 10 == 0:
        monitor.print_performance_summary()

总结

模型训练是深度学习的核心环节,需要掌握:

  1. 训练流程:理解前向传播、损失计算、反向传播、参数更新的完整流程
  2. 训练技巧:学习率调度、早停、梯度累积、混合精度等优化技术
  3. 模型评估:使用合适的指标评估模型性能
  4. 可视化分析:通过图表分析训练过程和结果
  5. 调试技能:诊断和解决训练中的常见问题

掌握这些技能将帮助你训练出高质量的深度学习模型!

本站内容仅供学习和研究使用。