
PyTorch Distributed Training

Overview of Distributed Training

Distributed training is a key technique for large-scale deep learning: it uses multiple GPUs or multiple machines to accelerate training. PyTorch provides several distributed training approaches.

Distributed Training Basics

1. Basic Concepts

python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Key concepts in distributed training:
# - World Size: total number of processes
# - Rank: global rank of the current process
# - Local Rank: rank of the current process within its node
# - Backend: communication backend (nccl, gloo, mpi)

def setup_distributed(rank, world_size, backend='nccl'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    
    # Initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    
    # Bind this process to its GPU
    torch.cuda.set_device(rank)

def cleanup_distributed():
    """Tear down the distributed environment."""
    dist.destroy_process_group()
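
When the processes are started by a launcher such as torchrun, the rank, local rank, and world size are provided as environment variables, so they do not need to be passed in explicitly. Below is a minimal sketch of an environment-variable based setup, assuming the launcher has set RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT; the function name setup_from_env is just an illustrative choice.

python
def setup_from_env(backend='nccl'):
    """Minimal sketch: initialize from variables set by a launcher such as torchrun."""
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    
    # init_method='env://' reads MASTER_ADDR and MASTER_PORT from the environment
    dist.init_process_group(backend=backend, init_method='env://',
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size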

2. Data Parallelism (DataParallel)

python
import torch.nn as nn

# Simple data parallelism (single node, multiple GPUs).
# Note: nn.DataParallel is single-process and multi-threaded; for new code,
# DistributedDataParallel (covered below) is generally preferred.
class SimpleDataParallel:
    def __init__(self, model, device_ids=None):
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        
        # Move the model to the first device before wrapping it
        self.model = nn.DataParallel(model.cuda(device_ids[0]), device_ids=device_ids)
        self.device_ids = device_ids
    
    def train_step(self, data, target, optimizer, criterion):
        """Single training step."""
        # Inputs are automatically scattered across the GPUs
        data, target = data.cuda(), target.cuda()
        
        optimizer.zero_grad()
        output = self.model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        return loss.item()

# Usage example (MyModel and dataloader are placeholders defined elsewhere)
model = MyModel()
dp_trainer = SimpleDataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for data, target in dataloader:
    loss = dp_trainer.train_step(data, target, optimizer, criterion)

Distributed Data Parallel (DDP)

1. Basic DDP Implementation

python
def train_ddp(rank, world_size, model_class, train_dataset, num_epochs):
    """DDP training function."""
    # Set up the distributed environment
    setup_distributed(rank, world_size)
    
    # Create the model and move it to this process's GPU
    model = model_class().cuda(rank)
    model = DDP(model, device_ids=[rank])
    
    # Distributed sampler so each process sees a distinct shard of the data
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank
    )
    
    # Data loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )
    
    # Optimizer and loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    for epoch in range(num_epochs):
        # Set the sampler's epoch so shuffling differs between epochs
        train_sampler.set_epoch(epoch)
        
        model.train()
        total_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(rank), target.cuda(rank)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 100 == 0 and rank == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Log and save only on the main process
        if rank == 0:
            avg_loss = total_loss / len(train_loader)
            print(f'Epoch {epoch} completed, Average Loss: {avg_loss:.4f}')
            
            # Save the underlying model (unwrap the DDP module)
            torch.save(model.module.state_dict(), f'model_epoch_{epoch}.pth')
    
    # Clean up
    cleanup_distributed()

# Launch one process per GPU
def main():
    world_size = torch.cuda.device_count()
    mp.spawn(
        train_ddp,
        args=(world_size, MyModel, train_dataset, 10),
        nprocs=world_size,
        join=True
    )

if __name__ == '__main__':
    main()

2. Advanced DDP Trainer

python
class DistributedTrainer:
    def __init__(self, model, train_dataset, val_dataset, config):
        self.config = config
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        
        # Select the device for this process
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device(f'cuda:{self.local_rank}')
        
        # Initialize the process group (only when running with more than one process)
        if self.world_size > 1:
            dist.init_process_group(backend='nccl')
        
        # Build the model
        self.model = model.to(self.device)
        if self.world_size > 1:
            self.model = DDP(self.model, device_ids=[self.local_rank])
        
        # Data loaders
        self.train_loader, self.val_loader = self._create_dataloaders(
            train_dataset, val_dataset
        )
        
        # Optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'] * self.world_size,  # linear learning-rate scaling
            weight_decay=config['weight_decay']
        )
        
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config['num_epochs']
        )
        
        self.criterion = nn.CrossEntropyLoss()
        
        # Training state
        self.current_epoch = 0
        self.best_val_acc = 0.0
    
    def _create_dataloaders(self, train_dataset, val_dataset):
        """创建分布式数据加载器"""
        if self.world_size > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=self.world_size, rank=self.rank
            )
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False
            )
        else:
            train_sampler = None
            val_sampler = None
        
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            sampler=train_sampler,
            shuffle=(train_sampler is None),
            num_workers=self.config['num_workers'],
            pin_memory=True
        )
        
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.config['batch_size'],
            sampler=val_sampler,
            shuffle=False,
            num_workers=self.config['num_workers'],
            pin_memory=True
        )
        
        return train_loader, val_loader
    
    def train_epoch(self):
        """训练一个epoch"""
        self.model.train()
        
        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.current_epoch)
        
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            # Statistics
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            
            if batch_idx % 100 == 0 and self.is_main_process():
                print(f'Epoch {self.current_epoch}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
        
        # Aggregate statistics across all processes
        avg_loss = self._reduce_metric(total_loss / len(self.train_loader))
        accuracy = self._reduce_metric(correct / total)
        
        return avg_loss, accuracy
    
    def validate_epoch(self):
        """验证一个epoch"""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                
                total_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
        
        # Aggregate validation metrics
        avg_loss = self._reduce_metric(total_loss / len(self.val_loader))
        accuracy = self._reduce_metric(correct / total)
        
        return avg_loss, accuracy
    
    def _reduce_metric(self, metric):
        """聚合多个进程的指标"""
        if self.world_size > 1:
            metric_tensor = torch.tensor(metric, device=self.device)
            dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)
            return metric_tensor.item() / self.world_size
        return metric
    
    def is_main_process(self):
        """判断是否为主进程"""
        return self.rank == 0
    
    def save_checkpoint(self, filename):
        """保存检查点"""
        if self.is_main_process():
            checkpoint = {
                'epoch': self.current_epoch,
                'model_state_dict': self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'best_val_acc': self.best_val_acc,
                'config': self.config
            }
            torch.save(checkpoint, filename)
    
    def train(self, num_epochs):
        """完整训练流程"""
        if self.is_main_process():
            print(f"开始分布式训练,使用{self.world_size}个GPU")
        
        for epoch in range(num_epochs):
            self.current_epoch = epoch
            
            # Train
            train_loss, train_acc = self.train_epoch()
            
            # Validate
            val_loss, val_acc = self.validate_epoch()
            
            # Update the learning rate
            self.scheduler.step()
            
            # Log results on the main process
            if self.is_main_process():
                print(f'Epoch {epoch+1}/{num_epochs}:')
                print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
                print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
                print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
                
                # Save the best model
                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    self.save_checkpoint('best_model.pth')
                    print(f'  ✓ New best model! Validation accuracy: {val_acc:.4f}')
                
                print('-' * 60)
        
        # Tear down the distributed environment
        if self.world_size > 1:
            dist.destroy_process_group()
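
For reference, a minimal usage sketch of the trainer above. MyModel, train_dataset and val_dataset are placeholders, the config keys are exactly the ones the class reads, and the process is assumed to have been started by a launcher that sets RANK, LOCAL_RANK and WORLD_SIZE.

python
# Hypothetical usage of DistributedTrainer
config = {
    'learning_rate': 1e-3,
    'weight_decay': 1e-2,
    'num_epochs': 10,
    'batch_size': 32,
    'num_workers': 4,
}

trainer = DistributedTrainer(MyModel(), train_dataset, val_dataset, config)
trainer.train(config['num_epochs'])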

Model Parallelism

1. Pipeline Parallelism

python
class PipelineParallelModel(nn.Module):
    def __init__(self, num_layers, hidden_size, num_devices):
        super(PipelineParallelModel, self).__init__()
        self.num_devices = num_devices
        self.layers_per_device = num_layers // num_devices
        
        # Assign a contiguous slice of layers to each device
        self.device_layers = nn.ModuleList()
        for device_id in range(num_devices):
            device_layers = nn.ModuleList()
            for _ in range(self.layers_per_device):
                device_layers.append(
                    nn.Linear(hidden_size, hidden_size).to(f'cuda:{device_id}')
                )
            self.device_layers.append(device_layers)
    
    def forward(self, x):
        # Sequential forward pass across devices.
        # Note: without micro-batching this is plain (naive) model parallelism;
        # a true pipeline schedule (GPipe-style) splits each batch into
        # micro-batches so that all devices can work concurrently.
        for device_id in range(self.num_devices):
            x = x.to(f'cuda:{device_id}')
            for layer in self.device_layers[device_id]:
                x = torch.relu(layer(x))
        
        return x

# Build the model split across 4 devices
pipeline_model = PipelineParallelModel(
    num_layers=12, hidden_size=512, num_devices=4
)
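
A short, hypothetical forward/backward pass with the 4-device model above: the forward method moves the activations from device to device, so only the targets need to be placed on the last device explicitly.

python
# Hypothetical end-to-end pass through the naive model-parallel module
x = torch.randn(64, 512)                      # moved to cuda:0 inside forward()
target = torch.randn(64, 512).to('cuda:3')    # targets must live on the last device

output = pipeline_model(x)                    # output ends up on cuda:3
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()                               # autograd tracks the cross-device graph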

2. Tensor Parallelism

python
import torch.nn.functional as F

class TensorParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super(TensorParallelLinear, self).__init__()
        self.world_size = world_size
        self.rank = rank
        
        # Each device stores only its slice of the weight matrix (column parallelism)
        self.out_features_per_device = out_features // world_size
        self.weight = nn.Parameter(
            torch.randn(in_features, self.out_features_per_device)
        )
        self.bias = nn.Parameter(
            torch.randn(self.out_features_per_device)
        )
    
    def forward(self, x):
        # Local shard of the output
        local_output = F.linear(x, self.weight.t(), self.bias)
        
        # Gather the output shards from all devices.
        # Note: dist.all_gather does not backpropagate to the other ranks;
        # production tensor-parallel layers use autograd-aware collectives.
        output_list = [torch.zeros_like(local_output) for _ in range(self.world_size)]
        dist.all_gather(output_list, local_output)
        
        # Concatenate along the feature dimension
        output = torch.cat(output_list, dim=-1)
        
        return output
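
A minimal usage sketch, assuming one process per GPU, that the process group from setup_distributed is already initialized on every rank, and that out_features is divisible by the world size:

python
# Hypothetical usage: each rank holds one shard of a 512 -> 1024 linear layer
rank = dist.get_rank()
world_size = dist.get_world_size()

tp_linear = TensorParallelLinear(512, 1024, world_size, rank).cuda(rank)
x = torch.randn(8, 512, device=f'cuda:{rank}')
y = tp_linear(x)    # shape (8, 1024), identical layout on every rank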

Mixed-Precision Distributed Training

1. DDP Training with AMP

python
from torch.cuda.amp import GradScaler, autocast

class AMPDistributedTrainer(DistributedTrainer):
    def __init__(self, model, train_dataset, val_dataset, config):
        super().__init__(model, train_dataset, val_dataset, config)
        self.scaler = GradScaler()
    
    def train_epoch(self):
        """带混合精度的训练epoch"""
        self.model.train()
        
        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.current_epoch)
        
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            
            # Forward pass under autocast
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)
            
            # Scale the loss and run backward
            self.scaler.scale(loss).backward()
            
            # Gradient clipping (unscale first)
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Optimizer step through the scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            # Statistics
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
        
        avg_loss = self._reduce_metric(total_loss / len(self.train_loader))
        accuracy = self._reduce_metric(correct / total)
        
        return avg_loss, accuracy

Large-Scale Training Optimizations

1. Gradient Accumulation and Synchronization

python
from contextlib import nullcontext

class GradientAccumulationDDP:
    def __init__(self, model, accumulation_steps=4):
        self.model = model                      # expected to be wrapped in DDP
        self.accumulation_steps = accumulation_steps
        self.step_count = 0
        self.scaler = GradScaler()              # AMP loss scaler
    
    def train_step(self, data, target, optimizer, criterion):
        """Training step with gradient accumulation."""
        self.step_count += 1
        sync_step = (self.step_count % self.accumulation_steps == 0)
        
        # Skip DDP's gradient all-reduce on non-sync steps to save communication
        context = self.model.no_sync() if not sync_step else nullcontext()
        
        with context:
            # Forward pass; scale the loss down so accumulated gradients average out
            with autocast():
                output = self.model(data)
                loss = criterion(output, target) / self.accumulation_steps
            
            # Backward pass
            self.scaler.scale(loss).backward()
        
        # Every accumulation_steps steps, synchronize gradients and update
        if sync_step:
            # Gradient clipping
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Parameter update
            self.scaler.step(optimizer)
            self.scaler.update()
            optimizer.zero_grad()
        
        return loss.item() * self.accumulation_steps
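
A brief usage sketch of the helper above, assuming ddp_model is already wrapped in DistributedDataParallel and that optimizer, criterion and train_loader exist:

python
# Hypothetical usage with an existing DDP model
accum_trainer = GradientAccumulationDDP(ddp_model, accumulation_steps=4)

for data, target in train_loader:
    data, target = data.cuda(), target.cuda()
    loss = accum_trainer.train_step(data, target, optimizer, criterion)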

2. Dynamic Loss Scaling

python
class DynamicLossScaling:
    """Educational re-implementation of the dynamic loss scaling that
    torch.cuda.amp.GradScaler performs internally."""
    def __init__(self, init_scale=2**16, scale_factor=2.0, scale_window=2000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.unskipped_steps = 0
    
    def scale_loss(self, loss):
        """Scale the loss before backward."""
        return loss * self.scale
    
    def unscale_gradients(self, optimizer):
        """Unscale gradients before the optimizer step."""
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    param.grad.data.div_(self.scale)
    
    def update_scale(self, found_inf):
        """Update the scale factor."""
        if found_inf:
            # Overflow detected: shrink the scale and skip this step
            self.scale /= self.scale_factor
            self.unskipped_steps = 0
        else:
            # Successful step: grow the scale after enough clean steps
            self.unskipped_steps += 1
            if self.unskipped_steps >= self.scale_window:
                self.scale *= self.scale_factor
                self.unskipped_steps = 0
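
A minimal sketch of a training step built on the scaler above. The overflow check (any non-finite gradient) is a simplified stand-in for the inf/nan detection that GradScaler performs internally.

python
scaler = DynamicLossScaling()

def scaled_train_step(model, data, target, optimizer, criterion):
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    scaler.scale_loss(loss).backward()
    scaler.unscale_gradients(optimizer)
    
    # Skip the update if any gradient overflowed, then adjust the scale
    found_inf = any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for p in model.parameters()
    )
    if not found_inf:
        optimizer.step()
    scaler.update_scale(found_inf)
    return loss.item()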

Launch Scripts

1. Single-Node Multi-GPU Launch

bash
#!/bin/bash
# launch_single_node.sh
# torchrun replaces the deprecated torch.distributed.launch and exports
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process.

export CUDA_VISIBLE_DEVICES=0,1,2,3

torchrun \
    --nproc_per_node=4 \
    --master_addr=localhost \
    --master_port=12355 \
    train_distributed.py \
    --batch_size=32 \
    --learning_rate=0.001 \
    --num_epochs=100

2. Multi-Node Multi-GPU Launch

bash
#!/bin/bash
# launch_multi_node.sh
# torchrun derives MASTER_ADDR/MASTER_PORT from the flags below, so no extra exports are needed.

# Node 0 (master node)
export CUDA_VISIBLE_DEVICES=0,1,2,3

torchrun \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=0 \
    --master_addr=192.168.1.100 \
    --master_port=12355 \
    train_distributed.py

# Node 1
export CUDA_VISIBLE_DEVICES=0,1,2,3

torchrun \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=1 \
    --master_addr=192.168.1.100 \
    --master_port=12355 \
    train_distributed.py

3. Slurm Cluster Launch

bash
#!/bin/bash
#SBATCH --job-name=pytorch_distributed
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --time=24:00:00

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=12355
export WORLD_SIZE=$SLURM_NTASKS

# SLURM_PROCID and SLURM_LOCALID differ for every task, so they must be read
# inside the srun-launched command rather than exported once in the batch script.
srun bash -c 'RANK=$SLURM_PROCID LOCAL_RANK=$SLURM_LOCALID \
    python train_distributed.py \
        --batch_size=32 \
        --learning_rate=0.001 \
        --num_epochs=100'

Performance Monitoring and Debugging

1. Distributed Performance Monitoring

python
import time

class DistributedProfiler:
    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.communication_times = []
        self.computation_times = []
    
    def profile_communication(self, tensor):
        """Measure the time of a collective communication."""
        torch.cuda.synchronize()
        start_time = time.time()
        
        # Time an all-reduce over the given tensor
        dist.all_reduce(tensor)
        
        # NCCL collectives are asynchronous on the host; synchronize before timing
        torch.cuda.synchronize()
        end_time = time.time()
        comm_time = end_time - start_time
        self.communication_times.append(comm_time)
        
        return comm_time
    
    def profile_computation(self, model, data):
        """Measure the time of a forward pass."""
        torch.cuda.synchronize()
        start_time = time.time()
        
        with torch.no_grad():
            output = model(data)
        
        # CUDA kernels are launched asynchronously; synchronize before timing
        torch.cuda.synchronize()
        end_time = time.time()
        comp_time = end_time - start_time
        self.computation_times.append(comp_time)
        
        return comp_time
    
    def get_statistics(self):
        """Print timing statistics on the main process."""
        if self.rank == 0:
            avg_comm_time = sum(self.communication_times) / len(self.communication_times)
            avg_comp_time = sum(self.computation_times) / len(self.computation_times)
            
            print(f"Average communication time: {avg_comm_time:.4f}s")
            print(f"Average computation time: {avg_comp_time:.4f}s")
            print(f"Communication/computation ratio: {avg_comm_time/avg_comp_time:.2f}")

2. Distributed Debugging Tools

python
def debug_distributed_setup():
    """Sanity-check the distributed setup."""
    if dist.is_initialized():
        print(f"Rank: {dist.get_rank()}")
        print(f"World Size: {dist.get_world_size()}")
        print(f"Backend: {dist.get_backend()}")
        
        # Test all-reduce
        tensor = torch.ones(1).cuda()
        dist.all_reduce(tensor)
        print(f"All-reduce result: {tensor.item()}")
        
        # Test broadcast
        if dist.get_rank() == 0:
            broadcast_tensor = torch.tensor([1.0, 2.0, 3.0]).cuda()
        else:
            broadcast_tensor = torch.zeros(3).cuda()
        
        dist.broadcast(broadcast_tensor, src=0)
        print(f"Broadcast result: {broadcast_tensor}")
    else:
        print("Distributed environment is not initialized")

# Call before training starts
debug_distributed_setup()

Summary

PyTorch distributed training covers several parallelism strategies:

  1. Data parallelism: DataParallel and DistributedDataParallel
  2. Model parallelism: pipeline parallelism and tensor parallelism
  3. Mixed precision: distributed training combined with AMP
  4. Large-scale optimizations: gradient accumulation and dynamic loss scaling
  5. Deployment options: single-node multi-GPU, multi-node multi-GPU, and cluster jobs
  6. Performance tuning: communication and computation optimization, monitoring, and debugging

Mastering these distributed training techniques will help you handle large-scale deep learning workloads and significantly improve training efficiency.
