PyTorch 数据处理
数据处理概述
在深度学习中,数据处理是至关重要的一步。PyTorch提供了强大的数据处理工具,主要包括:
- torch.utils.data.Dataset:数据集抽象类
- torch.utils.data.DataLoader:数据加载器
- torchvision.transforms:数据变换工具
Dataset类
1. 自定义Dataset
python
import math
import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    """CSV-backed image dataset.

    Each row of the CSV must provide an 'image_path' and a 'label' column.
    """

    def __init__(self, data_file, transform=None):
        """
        Args:
            data_file: path to the CSV file describing the samples
            transform: optional callable applied to each loaded image
        """
        self.data = pd.read_csv(data_file)
        self.transform = transform

    def __len__(self):
        # Dataset size == number of CSV rows.
        return len(self.data)

    def __getitem__(self, idx):
        # Samplers may hand us a tensor index; normalise it first.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.data.iloc[idx]
        image = Image.open(row['image_path'])
        target = row['label']
        if self.transform:
            image = self.transform(image)
        return image, target
# 使用示例
# dataset = CustomDataset('data.csv', transform=transforms.ToTensor())
2. 图像数据集示例
python
class ImageDataset(Dataset):
    """Image dataset driven by an annotations CSV.

    Column 0 of the CSV holds the file name (relative to root_dir),
    column 1 holds the label.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        rel_path = self.annotations.iloc[idx, 0]
        target = self.annotations.iloc[idx, 1]
        # Force RGB so grayscale/palette files come out with 3 channels.
        img = Image.open(os.path.join(self.root_dir, rel_path)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, target
# 文本数据集示例
class TextDataset(Dataset):
    """Tokenised text-classification dataset.

    Wraps raw texts, labels and a HuggingFace-style tokenizer; every
    sample is padded/truncated to max_length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw = str(self.texts[idx])
        target = self.labels[idx]
        # return_tensors='pt' yields a leading batch dim of 1,
        # flattened away below.
        enc = self.tokenizer(
            raw,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'label': torch.tensor(target, dtype=torch.long)
        }

# 3. 内存数据集
python
class MemoryDataset(Dataset):
    """Dataset held entirely in memory (suited to small datasets)."""

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item, target = self.data[idx], self.labels[idx]
        if self.transform:
            item = self.transform(item)
        return item, target
# Build a toy in-memory dataset: 1000 random 32x32 RGB images, 10 classes.
data = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
dataset = MemoryDataset(data, labels)

# DataLoader
1. 基本用法
python
from torch.utils.data import DataLoader

# Create the data loader.
dataloader = DataLoader(
    dataset,
    batch_size=32,       # batch size
    shuffle=True,        # reshuffle every epoch
    num_workers=4,       # parallel loading processes
    pin_memory=True,     # page-locked memory for faster host->GPU copies
    drop_last=True       # drop the final incomplete batch
)

# Iterate over the first three batches only.
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"批次 {batch_idx}: 数据形状 {data.shape}, 标签形状 {target.shape}")
    if batch_idx >= 2:
        break

# 2. 自定义collate函数
python
def custom_collate_fn(batch):
    """Collate (sequence, label) pairs into a right-padded batch.

    Args:
        batch: list of (seq, label) where each seq is a 2-D tensor of
            shape (length, feature_dim); lengths may differ.

    Returns:
        (padded, labels, lengths): padded has shape
        (batch, max_len, feature_dim) and is zero-padded on the right;
        lengths holds each sequence's original length.
    """
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    lengths = [len(seq) for seq in sequences]
    max_length = max(lengths)

    padded_data = []
    for seq in sequences:
        # new_zeros preserves the dtype/device of the incoming sequence;
        # the original torch.zeros always produced float32 on CPU, silently
        # casting e.g. long/half sequences.
        padded = seq.new_zeros(max_length, seq.size(-1))
        padded[:len(seq)] = seq
        padded_data.append(padded)

    return torch.stack(padded_data), torch.tensor(labels), torch.tensor(lengths)
# Plug the custom collate function into a loader.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=custom_collate_fn
)

# 3. 分布式数据加载
python
from torch.utils.data.distributed import DistributedSampler
# 分布式训练时的数据加载
def create_distributed_dataloader(dataset, batch_size, world_size, rank):
    """Build a DataLoader for distributed training.

    Each of the world_size processes sees a disjoint shard of the dataset.
    The sampler is returned too so callers can invoke
    sampler.set_epoch(epoch) for proper per-epoch reshuffling.
    """
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )
    return loader, sampler

# 数据变换 (Transforms)
1. 图像变换
python
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
# Basic transform pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.ToTensor(),
    transforms.Normalize(  # ImageNet channel statistics
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# More aggressive augmentation pipeline.
advanced_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2)  # tensor-space op, must come after ToTensor
])

# 2. 自定义变换
python
class AddGaussianNoise:
    """Transform that adds Gaussian noise to a tensor."""

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        return tensor + noise

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'
class Cutout:
    """Randomly zero out square patches of a CHW image tensor."""

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        half = self.length // 2
        for _ in range(self.n_holes):
            # Pick a centre, then clip the square to the image bounds.
            cy = np.random.randint(h)
            cx = np.random.randint(w)
            y1, y2 = np.clip(cy - half, 0, h), np.clip(cy + half, 0, h)
            x1, x2 = np.clip(cx - half, 0, w), np.clip(cx + half, 0, w)
            mask[y1:y2, x1:x2] = 0.
        return img * torch.from_numpy(mask).expand_as(img)
# Chain the custom transforms after tensor conversion.
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.1),
    Cutout(n_holes=1, length=16)
])

# 3. 数据增强策略
python
# Augmentation used during training.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Deterministic preprocessing for validation.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Test-time augmentation (TTA): an identity view plus a flipped view.
tta_transforms = [
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
]

# 内置数据集
1. 计算机视觉数据集
python
import torchvision.datasets as datasets
# CIFAR-10
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)
cifar10_test = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform
)

# ImageNet (no download flag here: it must already be present on disk)
imagenet_train = datasets.ImageNet(
    root='./data/imagenet',
    split='train',
    transform=train_transform
)

# MNIST
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean/std
    ])
)

# 2. 自然语言处理数据集
python
import torchtext.datasets as text_datasets
from torchtext.data.utils import get_tokenizer
# IMDB movie-review dataset.
train_iter, test_iter = text_datasets.IMDB(split=('train', 'test'))

# Tokenizer for the raw reviews.
tokenizer = get_tokenizer('basic_english')

def process_text(text_iter):
    """Materialise an iterator of (label, text) into (tokens, label) pairs."""
    data = []
    for label, text in text_iter:
        tokens = tokenizer(text)
        data.append((tokens, label))
    return data

train_data = process_text(train_iter)

# 数据预处理技巧
1. 数据标准化
python
def compute_mean_std(dataset):
    """Compute per-channel mean/std of a 3-channel image dataset.

    NOTE(review): the returned std is the average of per-image stds, not
    the true pooled dataset std — a common, slightly biased approximation.
    """
    loader = DataLoader(dataset, batch_size=100, shuffle=False)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    total = 0
    for images, _ in loader:
        n = images.size(0)
        flat = images.view(n, images.size(1), -1)  # (N, C, H*W)
        mean += flat.mean(2).sum(0)
        std += flat.std(2).sum(0)
        total += n
    return mean / total, std / total
# 使用示例
# mean, std = compute_mean_std(dataset)
# print(f"Mean: {mean}, Std: {std}")
2. 数据平衡
python
from torch.utils.data import WeightedRandomSampler
from collections import Counter
def create_balanced_sampler(dataset):
    """Build a WeightedRandomSampler that rebalances class frequencies.

    Samples from rare classes get proportionally larger weights, so every
    class is drawn with roughly equal probability.
    """
    # Count samples per class.
    labels = [dataset[i][1] for i in range(len(dataset))]
    class_counts = Counter(labels)

    # weight(class) = total / count(class); assign per-sample weights.
    n = len(labels)
    class_weights = {cls: n / count for cls, count in class_counts.items()}
    sample_weights = [class_weights[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
# 使用平衡采样器
# sampler = create_balanced_sampler(dataset)
# dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
3. 数据分割
python
from torch.utils.data import random_split
def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Randomly split a dataset into train/val/test subsets.

    Args:
        dataset: any sized dataset accepted by torch random_split.
        train_ratio, val_ratio, test_ratio: fractions that must sum to 1.

    Returns:
        (train_dataset, val_dataset, test_dataset)

    Raises:
        ValueError: if the three ratios do not sum to 1.
    """
    # Float addition is inexact: 0.8 + 0.1 + 0.1 != 1.0 exactly, so the
    # original `assert ... == 1.0` rejected its own default arguments.
    # Compare with a tolerance and raise a real error (asserts vanish
    # under `python -O`).
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=1e-9):
        raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    # Give the remainder to the test split so sizes always sum to total.
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )
    return train_dataset, val_dataset, test_dataset
# 使用示例
# train_data, val_data, test_data = split_dataset(dataset)
高级数据处理
1. 多模态数据处理
python
class MultiModalDataset(Dataset):
    """Paired image + text dataset for multi-modal models."""

    def __init__(self, image_paths, texts, labels, image_transform=None, text_tokenizer=None):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Image branch.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)

        # Text branch: tokenise only when a tokenizer was supplied.
        text = self.texts[idx]
        text_tokens = self.text_tokenizer(text) if self.text_tokenizer else text

        return {
            'image': image,
            'text': text_tokens,
            'label': self.labels[idx]
        }

# 2. 在线数据增强
python
class OnlineAugmentation:
    """Apply each transform independently with its own probability."""

    def __init__(self, transforms_list, probabilities):
        self.transforms = transforms_list
        self.probs = probabilities

    def __call__(self, image):
        for tf, p in zip(self.transforms, self.probs):
            # torch.rand(1) is in [0, 1): p=1.0 always fires, p=0.0 never.
            if torch.rand(1) < p:
                image = tf(image)
        return image
# Compose online augmentations with per-transform probabilities.
online_aug = OnlineAugmentation(
    transforms_list=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
    ],
    probabilities=[0.5, 0.3, 0.4]
)

# 3. 缓存机制
python
class CachedDataset(Dataset):
    """Wrap a dataset with a bounded in-memory cache.

    Eviction removes the least-frequently-accessed entry — the policy is
    frequency-based (LFU), not recency-based, despite the original's
    'lru' naming.
    """

    def __init__(self, dataset, cache_size=1000):
        self.dataset = dataset
        self.cache = {}
        self.cache_size = cache_size
        self.access_count = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Cache hit: bump the counter and return immediately.
        if idx in self.cache:
            self.access_count[idx] += 1
            return self.cache[idx]

        data = self.dataset[idx]

        # Cache full: evict the least-frequently-used entry.
        if len(self.cache) >= self.cache_size:
            lfu_idx = min(self.access_count, key=self.access_count.get)
            del self.cache[lfu_idx]
            del self.access_count[lfu_idx]

        self.cache[idx] = data
        self.access_count[idx] = 1
        return data

# 性能优化
1. 数据加载优化
python
# Throughput-oriented data loading.
def create_optimized_dataloader(dataset, batch_size, num_workers=None):
    """Create a DataLoader tuned for throughput (workers, pinning, prefetch)."""
    if num_workers is None:
        # Cap the worker count so small machines are not oversubscribed.
        num_workers = min(8, os.cpu_count())
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # only useful with CUDA
        persistent_workers=True,  # keep worker processes alive across epochs
        prefetch_factor=2,        # batches prefetched per worker
    )

# 2. 内存映射
python
import mmap
class MemoryMappedDataset(Dataset):
    """Dataset backed by a memory-mapped binary file.

    index_file lists one byte offset per sample; the record layout and
    parsing are left unimplemented (see __getitem__).
    """

    def __init__(self, data_file, index_file):
        # mmap the (potentially huge) data file instead of reading it in.
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)
        # One integer offset per line of the index file.
        with open(index_file, 'r') as f:
            self.indices = [int(line.strip()) for line in f]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        self.data_mmap.seek(self.indices[idx])
        # Record reading/parsing logic goes here.
        # ...
        pass

    def __del__(self):
        # hasattr guards: __init__ may have failed part-way through.
        if hasattr(self, 'data_mmap'):
            self.data_mmap.close()
        if hasattr(self, 'data_file'):
            self.data_file.close()

# 实际应用示例
1. 完整的图像分类数据管道
python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
def create_image_dataloaders(data_dir, batch_size=32, num_workers=4):
    """Build train/val loaders from an ImageFolder directory tree.

    Expects {data_dir}/train and {data_dir}/val, each containing one
    sub-directory per class.

    Returns:
        (train_loader, val_loader, num_classes)
    """
    # Augmentation for training, deterministic preprocessing for validation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder(
        root=f'{data_dir}/train',
        transform=train_transform
    )
    val_dataset = datasets.ImageFolder(
        root=f'{data_dir}/val',
        transform=val_transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep validation order deterministic
        num_workers=num_workers,
        pin_memory=True
    )
    return train_loader, val_loader, len(train_dataset.classes)

# 使用示例
# train_loader, val_loader, num_classes = create_image_dataloaders('./data')

# 总结
数据处理是深度学习项目的基础,掌握PyTorch的数据处理工具至关重要:
- Dataset类:学会创建自定义数据集,处理不同类型的数据
- DataLoader:掌握批量加载、并行处理、数据打乱等技巧
- 数据变换:熟练使用内置变换和自定义变换进行数据增强
- 性能优化:了解缓存、内存映射、并行加载等优化技术
- 实际应用:能够构建完整的数据处理管道
良好的数据处理不仅能提高模型性能,还能显著加速训练过程!