Skip to content

PyTorch 数据处理

数据处理概述

在深度学习中,数据处理是至关重要的一步。PyTorch提供了强大的数据处理工具,主要包括:

  • torch.utils.data.Dataset:数据集抽象类
  • torch.utils.data.DataLoader:数据加载器
  • torchvision.transforms:数据变换工具

Dataset类

1. 自定义Dataset

python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import os

class CustomDataset(Dataset):
    """Dataset backed by a CSV file with `image_path` and `label` columns.

    Args:
        data_file: Path to the CSV file describing the samples.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, data_file, transform=None):
        self.data = pd.read_csv(data_file)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and return the (image, label) pair at position `idx`."""
        # Samplers may hand us a tensor index; normalize to a plain int/list.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data.iloc[idx]

        # Load the image lazily from disk.
        image = Image.open(row['image_path'])
        if self.transform:
            image = self.transform(image)

        return image, row['label']

# 使用示例
# dataset = CustomDataset('data.csv', transform=transforms.ToTensor())

2. 图像数据集示例

python
class ImageDataset(Dataset):
    """Image dataset driven by a CSV annotations file.

    Column 0 of the annotations holds the image file name (relative to
    `root_dir`); column 1 holds the label.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return the (image, label) pair at position `idx`."""
        row = self.annotations.iloc[idx]
        # Force 3-channel RGB so grayscale/palette files behave uniformly.
        image = Image.open(os.path.join(self.root_dir, row.iloc[0])).convert('RGB')
        label = row.iloc[1]

        if self.transform:
            image = self.transform(image)

        return image, label

# 文本数据集示例
class TextDataset(Dataset):
    """Tokenized text-classification dataset.

    Each item is a dict with flattened `input_ids` and `attention_mask`
    tensors plus a long `label` tensor, padded/truncated to `max_length`.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize text `idx` on the fly and return a model-ready dict."""
        # Fixed-length padding/truncation keeps every item batchable.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

3. 内存数据集

python
class MemoryDataset(Dataset):
    """Dataset serving pre-loaded in-memory samples (suited to small data)."""

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples held in memory."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (sample, label) pair at `idx`, transformed if configured."""
        item, target = self.data[idx], self.labels[idx]
        if self.transform:
            item = self.transform(item)
        return item, target

# Build a small in-memory demo dataset.
data = torch.randn(1000, 3, 32, 32)  # 1000 RGB images of size 32x32
labels = torch.randint(0, 10, (1000,))  # integer labels for 10 classes

dataset = MemoryDataset(data, labels)

DataLoader

1. 基本用法

python
from torch.utils.data import DataLoader

# Create a data loader over `dataset`.
dataloader = DataLoader(
    dataset,
    batch_size=32,      # samples per batch
    shuffle=True,       # reshuffle the data every epoch
    num_workers=4,      # number of parallel worker processes
    pin_memory=True,    # copy batches into CUDA pinned (page-locked) memory
    drop_last=True      # drop the final batch if it is incomplete
)

# Iterate over batches.
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"批次 {batch_idx}: 数据形状 {data.shape}, 标签形状 {target.shape}")
    if batch_idx >= 2:  # show only the first 3 batches
        break

2. 自定义collate函数

python
def custom_collate_fn(batch):
    """Collate variable-length sequences into a zero-padded batch.

    Args:
        batch: List of (sequence, label) pairs, where each sequence is a
            tensor of shape (length, feature_dim) with varying length.

    Returns:
        Tuple (padded_data, labels, lengths):
            padded_data: (batch, max_len, feature_dim) tensor, zero-padded.
            labels: 1-D tensor of the labels.
            lengths: 1-D tensor of the original sequence lengths.
    """
    # Split the batch into sequences and labels.
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # Pad every sequence up to the longest one in the batch.
    lengths = [len(seq) for seq in sequences]
    max_length = max(lengths)

    padded_data = []
    for seq in sequences:
        # Fix: match the sequence dtype. The original always allocated
        # float32 zeros, silently casting integer sequences to float.
        padded = torch.zeros(max_length, seq.size(-1), dtype=seq.dtype)
        padded[:len(seq)] = seq
        padded_data.append(padded)

    return torch.stack(padded_data), torch.tensor(labels), torch.tensor(lengths)

# Use the custom collate function when building the loader.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=custom_collate_fn
)

3. 分布式数据加载

python
from torch.utils.data.distributed import DistributedSampler

# 分布式训练时的数据加载
def create_distributed_dataloader(dataset, batch_size, world_size, rank):
    """Build a DataLoader that shards `dataset` across distributed workers.

    Returns the loader together with its sampler so the caller can invoke
    `sampler.set_epoch(epoch)` each epoch to vary the shuffle.
    """
    shard_sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    # Note: `shuffle` is intentionally omitted — the sampler owns ordering.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=shard_sampler,
        num_workers=4,
        pin_memory=True
    )

    return loader, shard_sampler

数据变换 (Transforms)

1. 图像变换

python
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

# Basic image-augmentation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),          # resize to a fixed size
    transforms.RandomHorizontalFlip(0.5),   # random horizontal flip
    transforms.RandomRotation(10),          # random rotation
    transforms.ColorJitter(                 # color jitter
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.ToTensor(),                  # convert image to tensor
    transforms.Normalize(                   # per-channel normalization
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# More aggressive augmentation pipeline
advanced_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2)  # random erasing (applied after ToTensor)
])

2. 自定义变换

python
class AddGaussianNoise:
    """Transform that adds i.i.d. Gaussian noise to a tensor."""

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample noise with the input's shape, then scale and shift it.
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

class Cutout:
    """Transform that zeroes out `n_holes` square patches of side `length`.

    Expects `img` as a (C, H, W) tensor; each hole is centered at a random
    pixel and clipped to the image bounds.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        mask = np.ones((height, width), np.float32)

        half = self.length // 2
        for _ in range(self.n_holes):
            # Pick a random center, then clamp the square to the image.
            cy = np.random.randint(height)
            cx = np.random.randint(width)

            top = np.clip(cy - half, 0, height)
            bottom = np.clip(cy + half, 0, height)
            left = np.clip(cx - half, 0, width)
            right = np.clip(cx + half, 0, width)

            mask[top:bottom, left:right] = 0.

        return img * torch.from_numpy(mask).expand_as(img)

# Compose the custom transforms with the built-in ones
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.1),
    Cutout(n_holes=1, length=16)
])

3. 数据增强策略

python
# Augmentation used during training
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Deterministic preprocessing used during validation
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Test-time augmentation (TTA) variants
tta_transforms = [
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
]

内置数据集

1. 计算机视觉数据集

python
import torchvision.datasets as datasets

# CIFAR-10 (downloaded to ./data on first use)
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

cifar10_test = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform
)

# ImageNet — NOTE(review): no download flag here; presumably the data
# must already exist under root — confirm before use
imagenet_train = datasets.ImageNet(
    root='./data/imagenet',
    split='train',
    transform=train_transform
)

# MNIST with its single-channel normalization constants
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
)

2. 自然语言处理数据集

python
import torchtext.datasets as text_datasets
from torchtext.data.utils import get_tokenizer

# IMDB movie-review dataset
train_iter, test_iter = text_datasets.IMDB(split=('train', 'test'))

# Tokenizer for splitting raw text into tokens
tokenizer = get_tokenizer('basic_english')

def process_text(text_iter):
    """Tokenize every (label, text) pair and return a list of (tokens, label)."""
    data = []
    for label, text in text_iter:
        tokens = tokenizer(text)
        data.append((tokens, label))
    return data

train_data = process_text(train_iter)

数据预处理技巧

1. 数据标准化

python
def compute_mean_std(dataset):
    """Compute exact per-channel mean and std over every pixel of `dataset`.

    Fix over the original: the old code averaged per-image statistics,
    which yields the mean of per-image stds rather than the std of the
    whole dataset. This version accumulates pixel sums and squared sums
    and derives the exact statistics (unbiased variance).

    Args:
        dataset: Dataset yielding (image, label) pairs with 3-channel
            images, e.g. shape (3, H, W).

    Returns:
        Tuple (mean, std) of shape-(3,) tensors.
    """
    dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    pixel_count = 0

    for data, _ in dataloader:
        # Flatten to (batch, channels, pixels) and accumulate per channel.
        flat = data.view(data.size(0), data.size(1), -1)
        channel_sum += flat.sum(dim=(0, 2))
        channel_sq_sum += (flat ** 2).sum(dim=(0, 2))
        pixel_count += flat.size(0) * flat.size(2)

    mean = channel_sum / pixel_count
    # Unbiased variance: (sum(x^2) - n*mean^2) / (n - 1); clamp guards
    # against tiny negative values from floating-point cancellation.
    var = (channel_sq_sum - pixel_count * mean ** 2) / (pixel_count - 1)
    std = var.clamp_min(0).sqrt()

    return mean, std

# 使用示例
# mean, std = compute_mean_std(dataset)
# print(f"Mean: {mean}, Std: {std}")

2. 数据平衡

python
from torch.utils.data import WeightedRandomSampler
from collections import Counter

def create_balanced_sampler(dataset):
    """Build a WeightedRandomSampler that counteracts class imbalance.

    Each sample is weighted by the inverse frequency of its class, so
    rare classes are drawn as often as common ones (with replacement).
    """
    # Collect every sample's label (dataset items are (data, label) pairs).
    labels = [dataset[i][1] for i in range(len(dataset))]
    counts = Counter(labels)
    n = len(labels)

    # Inverse-frequency weight for each individual sample.
    sample_weights = [n / counts[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=n,
        replacement=True
    )

# 使用平衡采样器
# sampler = create_balanced_sampler(dataset)
# dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

3. 数据分割

python
from torch.utils.data import random_split

def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Randomly split `dataset` into train/val/test subsets.

    Fix over the original: the ratio check used exact float equality
    (`train + val + test == 1.0`), which spuriously failed for valid
    inputs such as 0.7/0.2/0.1 because of floating-point rounding. It
    now compares within a tolerance and raises a descriptive error
    instead of asserting (asserts vanish under `python -O`).

    Args:
        dataset: Any sized dataset accepted by `random_split`.
        train_ratio, val_ratio, test_ratio: Fractions summing to 1.0.

    Returns:
        Tuple (train_dataset, val_dataset, test_dataset).

    Raises:
        ValueError: If the ratios do not sum to 1.0.
    """
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError("train_ratio + val_ratio + test_ratio must sum to 1.0")

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    # Any rounding remainder goes to the test split.
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )

    return train_dataset, val_dataset, test_dataset

# 使用示例
# train_data, val_data, test_data = split_dataset(dataset)

高级数据处理

1. 多模态数据处理

python
class MultiModalDataset(Dataset):
    """Paired image/text dataset for multi-modal models.

    Each item is a dict with keys 'image', 'text' and 'label'. The text
    is passed through `text_tokenizer` when one is supplied; otherwise
    the raw string is returned unchanged.
    """

    def __init__(self, image_paths, texts, labels, image_transform=None, text_tokenizer=None):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer

    def __len__(self):
        """Return the number of (image, text, label) triples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the multi-modal sample at position `idx` as a dict."""
        # Image branch: load from disk, force RGB, optionally transform.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)

        # Text branch: tokenize only when a tokenizer was provided.
        raw_text = self.texts[idx]
        tokens = self.text_tokenizer(raw_text) if self.text_tokenizer else raw_text

        return {
            'image': image,
            'text': tokens,
            'label': self.labels[idx]
        }

2. 在线数据增强

python
class OnlineAugmentation:
    """Apply each transform independently with its own probability."""

    def __init__(self, transforms_list, probabilities):
        self.transforms = transforms_list
        self.probs = probabilities

    def __call__(self, image):
        # Roll a fresh coin per transform; application order is preserved.
        result = image
        for fn, p in zip(self.transforms, self.probs):
            if torch.rand(1) < p:
                result = fn(result)
        return result

# Example: three transforms, each applied with its own probability
online_aug = OnlineAugmentation(
    transforms_list=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
    ],
    probabilities=[0.5, 0.3, 0.4]
)

3. 缓存机制

python
class CachedDataset(Dataset):
    """Wrapper that caches up to `cache_size` items of `dataset` in memory.

    When the cache is full, the entry with the fewest accesses is evicted
    (a least-frequently-used policy).
    """

    def __init__(self, dataset, cache_size=1000):
        self.dataset = dataset
        self.cache = {}
        self.cache_size = cache_size
        self.access_count = {}

    def __len__(self):
        """Delegate the length to the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return item `idx`, serving from the cache when possible."""
        # Fast path: cache hit — bump the usage counter and return.
        if idx in self.cache:
            self.access_count[idx] += 1
            return self.cache[idx]

        item = self.dataset[idx]

        # Evict the least-frequently-used entry when the cache is full.
        if len(self.cache) >= self.cache_size:
            victim = min(self.access_count, key=self.access_count.get)
            self.cache.pop(victim)
            self.access_count.pop(victim)

        self.cache[idx] = item
        self.access_count[idx] = 1
        return item

性能优化

1. 数据加载优化

python
# 优化数据加载性能
def create_optimized_dataloader(dataset, batch_size, num_workers=None):
    """Create a DataLoader tuned for throughput.

    Worker count defaults to min(8, CPU count); pinned memory is enabled
    only when CUDA is available.
    """
    worker_count = min(8, os.cpu_count()) if num_workers is None else num_workers

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=worker_count,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True,  # keep worker processes alive between epochs
        prefetch_factor=2,        # batches prefetched per worker
    )

2. 内存映射

python
import mmap

class MemoryMappedDataset(Dataset):
    """Dataset that reads records from a large binary file via mmap.

    A separate index file lists one byte offset per record, so samples
    can be located without loading the whole data file into memory.
    NOTE(review): `__getitem__` is a stub — the record parsing logic is
    not implemented and currently returns None.
    """

    def __init__(self, data_file, index_file):
        # Memory-map the data file for random access without reading it all.
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)
        
        # Load the per-record byte offsets (one integer per line).
        with open(index_file, 'r') as f:
            self.indices = [int(line.strip()) for line in f]
    
    def __len__(self):
        # One index entry per record.
        return len(self.indices)
    
    def __getitem__(self, idx):
        offset = self.indices[idx]
        # Seek to the record's offset within the memory-mapped file.
        self.data_mmap.seek(offset)
        # Reading and parsing of the record would go here.
        # ...
        pass
    
    def __del__(self):
        # Close handles only if __init__ got far enough to create them.
        if hasattr(self, 'data_mmap'):
            self.data_mmap.close()
        if hasattr(self, 'data_file'):
            self.data_file.close()

实际应用示例

1. 完整的图像分类数据管道

python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

def create_image_dataloaders(data_dir, batch_size=32, num_workers=4):
    """Build train/val DataLoaders from an ImageFolder directory layout.

    Expects `{data_dir}/train` and `{data_dir}/val` subdirectories with
    one folder per class. Returns (train_loader, val_loader, num_classes).
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Augmentation for training; deterministic resize/crop for validation.
    pipelines = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    folders = {
        split: datasets.ImageFolder(root=f'{data_dir}/{split}', transform=tf)
        for split, tf in pipelines.items()
    }

    train_loader = DataLoader(
        folders['train'],
        batch_size=batch_size,
        shuffle=True,          # reshuffle training data every epoch
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        folders['val'],
        batch_size=batch_size,
        shuffle=False,         # keep validation order deterministic
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, len(folders['train'].classes)

# 使用示例
# train_loader, val_loader, num_classes = create_image_dataloaders('./data')

总结

数据处理是深度学习项目的基础,掌握PyTorch的数据处理工具至关重要:

  1. Dataset类:学会创建自定义数据集,处理不同类型的数据
  2. DataLoader:掌握批量加载、并行处理、数据打乱等技巧
  3. 数据变换:熟练使用内置变换和自定义变换进行数据增强
  4. 性能优化:了解缓存、内存映射、并行加载等优化技术
  5. 实际应用:能够构建完整的数据处理管道

良好的数据处理不仅能提高模型性能,还能显著加速训练过程!

本站内容仅供学习和研究使用。