PyTorch Image Classification Project

Project Overview

This chapter walks through a complete image classification project to show how to build an end-to-end deep learning solution with PyTorch. Using the CIFAR-10 dataset, we will build a classifier that recognizes 10 different object categories.

Project Structure

image_classification/
├── data/                   # data directory
├── models/                 # model definitions
│   ├── __init__.py
│   ├── resnet.py
│   └── densenet.py
├── utils/                  # utility functions
│   ├── __init__.py
│   ├── data_loader.py
│   ├── transforms.py
│   └── metrics.py
├── configs/                # configuration files
│   └── config.yaml
├── checkpoints/            # model checkpoints
├── logs/                   # training logs
├── train.py               # training script
├── test.py                # evaluation script
├── inference.py           # inference script
└── requirements.txt       # dependencies
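
The chapter does not reproduce requirements.txt itself; a plausible minimal version, inferred from the libraries the code below imports (version bounds are illustrative, not pinned by the original project), might look like:

torch>=2.0
torchvision>=0.15
numpy
matplotlib
seaborn
scikit-learn
Pillow
PyYAML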

Data Preparation

1. Data Loading and Preprocessing

python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np

class CIFAR10DataModule:
    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        # CIFAR-10 class names
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
        
        # Per-channel statistics used for normalization
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)
        
        self.setup_transforms()
    
    def setup_transforms(self):
        """设置数据变换"""
        # 训练时的数据增强
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
            transforms.RandomErasing(p=0.1)  # random erasing
        ])
        
        # Transforms for validation and testing
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])
    
    def prepare_data(self):
        """下载数据"""
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        """设置数据集"""
        if stage == 'fit' or stage is None:
            # 训练集
            full_train = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.train_transform
            )
            
            # Split into training and validation sets
            train_size = int(0.9 * len(full_train))
            val_size = len(full_train) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, val_size]
            )
            
            # Point the validation Subset at a copy of the dataset that uses
            # the validation transform; the indices produced by random_split
            # still index into the same underlying data
            self.val_dataset.dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.val_transform
            )
        
        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, transform=self.val_transform
            )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers,
            pin_memory=True
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers,
            pin_memory=True
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers,
            pin_memory=True
        )

# Create the data module
data_module = CIFAR10DataModule(batch_size=128)
data_module.prepare_data()
data_module.setup()
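
The normalization constants above are the commonly quoted CIFAR-10 channel statistics. As a sanity check, they can be recomputed from the raw training images; a minimal sketch follows. (Note: the widely copied std values (0.2023, 0.1994, 0.2010) were evidently obtained by a different, per-image computation; the raw pixel-level std comes out closer to (0.247, 0.243, 0.261). Both choices work in practice.)

python
# Recompute CIFAR-10 channel statistics from the raw training set (sanity check)
raw_train = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transforms.ToTensor()
)
loader = DataLoader(raw_train, batch_size=1000, num_workers=2)

n_pixels = 0
channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
for images, _ in loader:
    # images has shape (B, 3, 32, 32)
    n_pixels += images.numel() // 3
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print(mean, std)  # mean should land close to (0.4914, 0.4822, 0.4465)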

2. Data Visualization

python
def visualize_dataset(dataloader, classes, num_samples=16):
    """Visualize sample images from the dataset."""
    # Grab one batch
    data_iter = iter(dataloader)
    images, labels = next(data_iter)
    num_samples = min(num_samples, len(images), 16)  # the grid below is 4x4
    
    # Un-normalize back to [0, 1] for display
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)
    
    # Create a 4x4 grid of subplots
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    for i in range(num_samples):
        row, col = i // 4, i % 4
        axes[row, col].imshow(images[i].permute(1, 2, 0))
        axes[row, col].set_title(f'{classes[labels[i]]}')
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize training samples
visualize_dataset(data_module.train_dataloader(), data_module.classes)

Model Definition

1. An Improved ResNet

python
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.dropout1 = nn.Dropout2d(dropout_rate)
        
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout2 = nn.Dropout2d(dropout_rate)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        out = self.dropout1(F.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.dropout2(F.relu(out))
        return out

class ImprovedResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, dropout_rate=0.1):
        super(ImprovedResNet, self).__init__()
        self.in_planes = 64
        self.dropout_rate = dropout_rate
        
        # Initial convolution (stem)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        # ResNet stages
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # Weight initialization
        self._initialize_weights()
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; the rest use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, self.dropout_rate))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

def create_resnet18(num_classes=10, dropout_rate=0.1):
    return ImprovedResNet(BasicBlock, [2, 2, 2, 2], num_classes, dropout_rate)

def create_resnet34(num_classes=10, dropout_rate=0.1):
    return ImprovedResNet(BasicBlock, [3, 4, 6, 3], num_classes, dropout_rate)
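
Before training, a quick smoke test with a CIFAR-sized dummy batch confirms the model wires up and emits one logit per class:

python
# Smoke test: a (batch, 3, 32, 32) input should map to (batch, num_classes)
model = create_resnet18(num_classes=10, dropout_rate=0.1)
dummy = torch.randn(4, 3, 32, 32)
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])
print(f'{sum(p.numel() for p in model.parameters()):,} parameters')  # roughly 11M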

2. Ensemble Model

python
class EnsembleModel(nn.Module):
    def __init__(self, models):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList(models)
    
    def forward(self, x):
        outputs = []
        for model in self.models:
            outputs.append(F.softmax(model(x), dim=1))
        
        # Average the member probabilities
        ensemble_output = torch.stack(outputs).mean(dim=0)
        # Return log-probabilities for NLLLoss; the epsilon guards against log(0)
        return torch.log(ensemble_output + 1e-12)

# Build an ensemble of models
def create_ensemble():
    model1 = create_resnet18(dropout_rate=0.1)
    model2 = create_resnet18(dropout_rate=0.2)
    model3 = create_resnet34(dropout_rate=0.1)
    return EnsembleModel([model1, model2, model3])
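
Because EnsembleModel returns log-probabilities rather than raw logits, it must be paired with nn.NLLLoss; nn.CrossEntropyLoss would apply a second log-softmax. A minimal sketch:

python
# The ensemble outputs log-probabilities, so pair it with NLLLoss
ensemble = create_ensemble()
criterion = nn.NLLLoss()

x = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))
loss = criterion(ensemble(x), target)
print(loss.item())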

Training Framework

1. Trainer Class

python
import os
import time
from collections import defaultdict
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau

class ImageClassificationTrainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        
        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        
        # Loss function (with label smoothing)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        
        # Optimizer
        self.optimizer = self._create_optimizer()
        
        # Learning-rate scheduler
        self.scheduler = self._create_scheduler()
        
        # Training history
        self.history = defaultdict(list)
        
        # Best-model tracking
        self.best_val_acc = 0.0
        self.best_epoch = 0
        
        # Create the checkpoint directory
        os.makedirs(config['save_dir'], exist_ok=True)
    
    def _create_optimizer(self):
        if self.config['optimizer'] == 'adamw':
            return optim.AdamW(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                weight_decay=self.config['weight_decay'],
                betas=(0.9, 0.999)
            )
        elif self.config['optimizer'] == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                momentum=0.9,
                weight_decay=self.config['weight_decay'],
                nesterov=True
            )
        else:
            raise ValueError(f"Unknown optimizer: {self.config['optimizer']}")
    
    def _create_scheduler(self):
        if self.config['scheduler'] == 'cosine':
            return CosineAnnealingLR(
                self.optimizer, 
                T_max=self.config['epochs'],
                eta_min=1e-6
            )
        elif self.config['scheduler'] == 'step':
            return StepLR(
                self.optimizer,
                step_size=self.config['step_size'],
                gamma=0.1
            )
        elif self.config['scheduler'] == 'plateau':
            return ReduceLROnPlateau(
                self.optimizer,
                mode='max',
                factor=0.5,
                patience=10
            )
        else:
            raise ValueError(f"Unknown scheduler: {self.config['scheduler']}")
    
    def train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            
            # Forward pass
            output = self.model(data)
            loss = self.criterion(output, target)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            # Statistics
            running_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            
            # Print progress
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        
        return epoch_loss, epoch_acc
    
    def validate_epoch(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                
                running_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
        
        val_loss = running_loss / len(self.val_loader)
        val_acc = 100. * correct / total
        
        return val_loss, val_acc
    
    def train(self):
        print(f"开始训练,设备: {self.device}")
        print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}")
        
        start_time = time.time()
        
        for epoch in range(self.config['epochs']):
            epoch_start = time.time()
            
            # Train
            train_loss, train_acc = self.train_epoch(epoch)
            
            # Validate
            val_loss, val_acc = self.validate_epoch()
            
            # Update the learning rate
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_acc)
            else:
                self.scheduler.step()
            
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            
            epoch_time = time.time() - epoch_start
            
            # Print results
            print(f'Epoch {epoch+1}/{self.config["epochs"]}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            print(f'  Time: {epoch_time:.2f}s')
            
            # Save the best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f'  ✓ New best model! Val accuracy: {val_acc:.2f}%')
            
            # Periodic checkpoint
            if (epoch + 1) % 20 == 0:
                self.save_checkpoint(epoch)
            
            print('-' * 60)
        
        total_time = time.time() - start_time
        print(f'Training finished! Total time: {total_time/3600:.2f} hours')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% (epoch {self.best_epoch+1})')
        
        return self.history
    
    def save_checkpoint(self, epoch, is_best=False):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'history': dict(self.history)
        }
        
        if is_best:
            torch.save(checkpoint, os.path.join(self.config['save_dir'], 'best_model.pth'))
        
        torch.save(checkpoint, os.path.join(self.config['save_dir'], f'checkpoint_epoch_{epoch+1}.pth'))
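
Since save_checkpoint stores the optimizer and scheduler state alongside the weights, an interrupted run can be resumed. The load_checkpoint helper below is a sketch, not part of the trainer above:

python
def load_checkpoint(trainer, path):
    """Restore model/optimizer/scheduler state from a checkpoint (sketch)."""
    checkpoint = torch.load(path, map_location=trainer.device)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    trainer.best_val_acc = checkpoint['best_val_acc']
    return checkpoint['epoch'] + 1  # the next epoch to run

# start_epoch = load_checkpoint(trainer, './checkpoints/checkpoint_epoch_20.pth')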

2. Configuration

python
# Training configuration
config = {
    'epochs': 200,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'optimizer': 'adamw',  # 'adamw' or 'sgd'
    'scheduler': 'cosine',  # 'cosine', 'step', 'plateau'
    'step_size': 50,
    'save_dir': './checkpoints',
    'log_dir': './logs'
}
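
The project layout reserves configs/config.yaml for these settings; the dictionary above could equally be kept there and loaded with PyYAML (the chapter itself defines the config inline, so this is optional):

python
import yaml  # PyYAML

# configs/config.yaml would mirror the dictionary above, e.g.:
#   epochs: 200
#   learning_rate: 0.001
#   weight_decay: 0.01
#   optimizer: adamw
#   scheduler: cosine
#   step_size: 50
#   save_dir: ./checkpoints
#   log_dir: ./logs
with open('configs/config.yaml') as f:
    config = yaml.safe_load(f)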

Model Evaluation

1. Detailed Evaluation

python
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

def evaluate_model(model, test_loader, device, classes):
    """详细评估模型性能"""
    model.eval()
    
    all_preds = []
    all_targets = []
    all_probs = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            
            # Get predictions and probabilities
            probs = F.softmax(output, dim=1)
            pred = output.argmax(dim=1)
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Overall accuracy
    accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
    
    # Classification report
    report = classification_report(all_targets, all_preds, target_names=classes)
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_preds)
    
    print(f"Test accuracy: {accuracy:.4f}")
    print("\nClassification report:")
    print(report)
    
    # Plot the confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
    
    return accuracy, report, cm, all_probs

# Evaluate the model (device here is the training device, e.g. trainer.device)
accuracy, report, cm, probs = evaluate_model(
    model, data_module.test_dataloader(), device, data_module.classes
)
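
The confusion matrix returned above also yields per-class accuracy directly: row i of cm counts every sample whose true label is i, so the diagonal divided by the row sums gives each class's recall:

python
# Per-class accuracy: correct predictions (diagonal) over all samples of each class (row sums)
per_class_acc = cm.diagonal() / cm.sum(axis=1)
for name, acc in zip(data_module.classes, per_class_acc):
    print(f'{name:>10}: {acc:.4f}')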

2. Error Analysis

python
def analyze_errors(model, test_loader, device, classes, num_errors=20):
    """分析错误预测的样本"""
    model.eval()
    
    errors = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            
            # Find the wrong predictions
            wrong_mask = pred != target
            if wrong_mask.any():
                wrong_data = data[wrong_mask]
                wrong_pred = pred[wrong_mask]
                wrong_target = target[wrong_mask]
                wrong_probs = F.softmax(output[wrong_mask], dim=1)
                
                for i in range(len(wrong_data)):
                    errors.append({
                        'image': wrong_data[i],
                        'predicted': wrong_pred[i].item(),
                        'actual': wrong_target[i].item(),
                        'confidence': wrong_probs[i].max().item()
                    })
                    
                    if len(errors) >= num_errors:
                        break
            
            if len(errors) >= num_errors:
                break
    
    # Visualize the misclassified samples
    fig, axes = plt.subplots(4, 5, figsize=(15, 12))
    for i, error in enumerate(errors):
        if i >= 20:
            break
        
        row, col = i // 5, i % 5
        
        # Un-normalize the image
        img = error['image'].cpu()
        mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
        std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
        img = img * std + mean
        img = torch.clamp(img, 0, 1)
        
        axes[row, col].imshow(img.permute(1, 2, 0))
        axes[row, col].set_title(
            f'Pred: {classes[error["predicted"]]}\n'
            f'True: {classes[error["actual"]]}\n'
            f'Conf: {error["confidence"]:.2f}',
            fontsize=8
        )
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.show()

# Analyze misclassifications
analyze_errors(model, data_module.test_dataloader(), device, data_module.classes)

Model Inference

1. Single-Image Inference

python
def predict_single_image(model, image_path, transform, classes, device):
    """预测单张图片"""
    from PIL import Image
    
    # Load and preprocess the image (for inputs that are not 32x32,
    # prepend transforms.Resize((32, 32)) to the transform)
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(device)
    
    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        predicted_class = output.argmax(dim=1).item()
        confidence = probabilities[0][predicted_class].item()
    
    # Top-5 predictions
    top5_prob, top5_idx = torch.topk(probabilities, 5)
    top5_classes = [classes[idx] for idx in top5_idx[0]]
    top5_probs = top5_prob[0].tolist()
    
    return {
        'predicted_class': classes[predicted_class],
        'confidence': confidence,
        'top5_predictions': list(zip(top5_classes, top5_probs))
    }

# Usage example
# result = predict_single_image(model, 'test_image.jpg', data_module.val_transform, 
#                              data_module.classes, device)
# print(f"预测类别: {result['predicted_class']}")
# print(f"置信度: {result['confidence']:.4f}")

2. Batch Inference

python
def batch_inference(model, image_folder, transform, classes, device, batch_size=32):
    """批量推理"""
    from PIL import Image
    import glob
    
    # Collect all image paths
    image_paths = glob.glob(os.path.join(image_folder, '*.jpg')) + \
                  glob.glob(os.path.join(image_folder, '*.png'))
    
    results = []
    model.eval()
    
    # Process in batches
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []
        
        # Load this batch of images
        for path in batch_paths:
            image = Image.open(path).convert('RGB')
            tensor = transform(image)
            batch_images.append(tensor)
        
        # Stack into a batch tensor
        batch_tensor = torch.stack(batch_images).to(device)
        
        # Inference (model.eval() was already set before the loop)
        with torch.no_grad():
            outputs = model(batch_tensor)
            probabilities = F.softmax(outputs, dim=1)
            predictions = outputs.argmax(dim=1)
        
        # Collect the results
        for j, path in enumerate(batch_paths):
            results.append({
                'image_path': path,
                'predicted_class': classes[predictions[j].item()],
                'confidence': probabilities[j][predictions[j]].item()
            })
    
    return results

# Usage example
# results = batch_inference(model, './test_images', data_module.val_transform, 
#                          data_module.classes, device)
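
The returned list of dicts is straightforward to persist; a sketch writing it to CSV (the output file name is illustrative):

python
import csv

def save_results_csv(results, out_path='predictions.csv'):
    """Write batch inference results to a CSV file (sketch)."""
    with open(out_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['image_path', 'predicted_class', 'confidence'])
        writer.writeheader()
        writer.writerows(results)

# save_results_csv(results)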

Complete Training Script

python
def main():
    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Create the data module
    data_module = CIFAR10DataModule(batch_size=128)
    data_module.prepare_data()
    data_module.setup()
    
    # Create the model
    model = create_resnet18(num_classes=10, dropout_rate=0.1)
    
    # Training configuration
    config = {
        'epochs': 200,
        'learning_rate': 0.001,
        'weight_decay': 0.01,
        'optimizer': 'adamw',
        'scheduler': 'cosine',
        'save_dir': './checkpoints'
    }
    
    # Create the trainer
    trainer = ImageClassificationTrainer(
        model, 
        data_module.train_dataloader(), 
        data_module.val_dataloader(), 
        config
    )
    
    # Train
    history = trainer.train()
    
    # Load the best model for testing
    checkpoint = torch.load(os.path.join(config['save_dir'], 'best_model.pth'),
                            map_location=trainer.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Evaluate the model
    accuracy, report, cm, probs = evaluate_model(
        model, data_module.test_dataloader(), trainer.device, data_module.classes
    )
    
    print(f"最终测试准确率: {accuracy:.4f}")

if __name__ == '__main__':
    main()

Summary

Through a complete image classification project, this chapter has demonstrated:

  1. Project structure: how to organize the code of a deep learning project
  2. Data handling: the complete pipeline of loading, preprocessing, and augmentation
  3. Model design: an improved ResNet architecture and an ensemble approach
  4. Training framework: a complete train/validate/checkpoint workflow
  5. Model evaluation: multiple evaluation metrics and error-analysis methods
  6. Model inference: single-image and batch inference

This project template can serve as a foundation for other image classification tasks; adapting the data loading and model definition is usually all that a new application requires.
