Skip to content

PyTorch 自定义操作

自定义操作概述

PyTorch提供了强大的扩展机制，允许开发者创建自定义操作、层和函数。这对于实现新算法、优化性能或集成第三方库非常有用。

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np
import math

自定义Function

1. 基础自定义Function

python
class SquareFunction(Function):
    """Element-wise square y = x ** 2 with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input ** 2``; the input is stashed for the backward pass."""
        ctx.save_for_backward(input)
        squared = input ** 2
        return squared

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule with d(x^2)/dx = 2x."""
        (saved_input,) = ctx.saved_tensors
        return grad_output * 2 * saved_input

# Functional interface for SquareFunction.
def square(input):
    """Differentiable element-wise square of ``input``."""
    result = SquareFunction.apply(input)
    return result

# Demo: square a random (3, 4) tensor and backpropagate the sum of the result.
# Since d(sum(x^2))/dx = 2x, x.grad should equal 2 * x.
x = torch.randn(3, 4, requires_grad=True)
y = square(x)
loss = y.sum()
loss.backward()

print(f"输入: {x}")
print(f"输出: {y}")
print(f"梯度: {x.grad}")

2. 多输入多输出Function

python
class LinearFunction(Function):
    """Affine map y = x W^T + b implemented as a custom autograd Function."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute ``input @ weight.T`` (+ ``bias`` broadcast over rows when given)."""
        ctx.save_for_backward(input, weight, bias)
        result = input.mm(weight.t())
        if bias is None:
            return result
        result += bias.unsqueeze(0).expand_as(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. (input, weight, bias); None where not required."""
        x, w, b = ctx.saved_tensors

        # dL/dx = dL/dy @ W
        grad_x = grad_output.mm(w) if ctx.needs_input_grad[0] else None
        # dL/dW = (dL/dy)^T @ x
        grad_w = grad_output.t().mm(x) if ctx.needs_input_grad[1] else None
        # dL/db = column sums of dL/dy. needs_input_grad[2] is only read when a
        # bias was passed, so a 2-argument apply() call never indexes it.
        grad_b = None
        if b is not None and ctx.needs_input_grad[2]:
            grad_b = grad_output.sum(0)

        return grad_x, grad_w, grad_b

def linear(input, weight, bias=None):
    """Functional wrapper: ``input @ weight.T + bias`` via LinearFunction."""
    apply_linear = LinearFunction.apply
    return apply_linear(input, weight, bias)

# Demo: multi-input function. input (5, 3), weight (4, 3), bias (4,) -> output (5, 4);
# after backward each gradient has the same shape as its tensor.
# NOTE: the module-level name `input` shadows Python's builtin input().
input = torch.randn(5, 3, requires_grad=True)
weight = torch.randn(4, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)

output = linear(input, weight, bias)
loss = output.sum()
loss.backward()

print(f"输入梯度形状: {input.grad.shape}")
print(f"权重梯度形状: {weight.grad.shape}")
print(f"偏置梯度形状: {bias.grad.shape}")

3. 带上下文的Function

python
class DropoutFunction(Function):
    """Inverted dropout implemented as a custom autograd Function.

    In training mode each element is kept with probability ``1 - p`` and the kept
    elements are scaled by ``1 / (1 - p)`` so the expected value is unchanged.
    With ``training=False`` the input is returned unchanged and gradients pass
    straight through.
    """

    @staticmethod
    def forward(ctx, input, p=0.5, training=True):
        """Apply dropout.

        Args:
            input: Input tensor.
            p: Probability of zeroing each element.
            training: If False, dropout is disabled (identity).
        """
        # Record the mode: backward must also work after an eval-mode forward,
        # where no mask is saved.
        ctx.training = training
        if not training:
            return input

        keep_prob = 1 - p
        # Random keep-mask: 1 with probability keep_prob, otherwise 0.
        mask = torch.bernoulli(torch.full_like(input, keep_prob))
        ctx.save_for_backward(mask)
        ctx.p = p
        if keep_prob == 0:
            # p == 1 drops everything; dividing by 0 would turn 0 * x into NaN.
            return input * mask
        # Rescale the kept elements so the expected value is unchanged.
        return input * mask / keep_prob

    @staticmethod
    def backward(ctx, grad_output):
        """Route gradients through the same mask and scale used in forward."""
        if not ctx.training:
            # Eval-mode forward was the identity, so the gradient passes through.
            return grad_output, None, None

        mask, = ctx.saved_tensors
        keep_prob = 1 - ctx.p
        if keep_prob == 0:
            return grad_output * mask, None, None
        return grad_output * mask / keep_prob, None, None

def dropout(input, p=0.5, training=True):
    """Functional wrapper around DropoutFunction (inverted dropout)."""
    apply_dropout = DropoutFunction.apply
    return apply_dropout(input, p, training)

# Demo: dropout with p=0.3 in training mode on a (10, 5) tensor. Because the
# loss is a plain sum, x.grad is 0 where an element was dropped and 1 / (1 - 0.3)
# where it was kept.
x = torch.randn(10, 5, requires_grad=True)
y = dropout(x, p=0.3, training=True)
loss = y.sum()
loss.backward()

print(f"输入: {x}")
print(f"Dropout输出: {y}")
print(f"梯度: {x.grad}")

自定义Module

1. 基础自定义Module

python
class CustomLinear(nn.Module):
    """Fully connected layer y = x W^T + b built on the custom ``linear`` function.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If False, the layer has no additive bias (``self.bias`` is None).
    """

    def __init__(self, in_features, out_features, bias=True):
        super(CustomLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Allocate storage only; reset_parameters() fills it right below, so
        # drawing random values here first would be wasted work.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias with the same scheme as torch.nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # in_features == 0 gives fan_in == 0; avoid ZeroDivisionError.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the affine map via ``linear``.

        ``linear`` uses ``mm``, so ``input`` is presumably (batch, in_features).
        """
        return linear(input, self.weight, self.bias)

    def extra_repr(self):
        """Layer configuration shown inside ``repr(module)``."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

# Demo: CustomLinear(10, 5) maps a (3, 10) batch to a (3, 5) output; printing the
# module includes the text returned by extra_repr().
custom_linear = CustomLinear(10, 5)
x = torch.randn(3, 10)
y = custom_linear(x)

print(f"自定义线性层: {custom_linear}")
print(f"输出形状: {y.shape}")

2. 复杂自定义Module

python
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention (custom implementation).

    Args:
        embed_dim: Model dimension E; must be divisible by ``num_heads``.
        num_heads: Number of heads H; each head has dimension E // H.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Separate query/key/value projections plus the output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim) (scaled dot-product attention).
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).
            mask: Optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
                positions where ``mask == 0`` receive (numerically) zero attention.

        Returns:
            (output, attn_weights): output is (batch, seq_len, embed_dim);
            attn_weights is (batch, num_heads, seq_len, seq_len) and is the
            matrix after dropout.
        """
        batch_size, seq_len, embed_dim = x.size()

        # (B, L, E) -> (B, L, H, D) -> (B, H, L, D)
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (B, H, L, L) attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Use the most negative finite value of the scores' dtype: a literal
            # -1e9 overflows float16, so masked_fill would raise under mixed precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)

        # (B, H, L, D) -> (B, L, H, D) -> (B, L, E), then the output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, embed_dim
        )
        output = self.out_proj(attn_output)

        return output, attn_weights

# Demo: self-attention over 2 sequences of length 10 with 256-dim embeddings and
# 8 heads. Expected shapes: output (2, 10, 256), attention weights (2, 8, 10, 10).
attention = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
x = torch.randn(2, 10, 256)
output, weights = attention(x)

print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")

3. 带状态的Module

python
class RunningBatchNorm(nn.Module):
    """Batch normalization with running statistics (custom implementation).

    Normalizes each channel (dimension 1) of an (N, C) or (N, C, *) input over all
    other dimensions, then applies a learnable per-channel scale and shift.
    Training mode normalizes with the batch statistics and updates the running
    estimates; eval mode normalizes with the running estimates. This follows
    torch.nn.BatchNorm1d/2d/3d (affine=True, track_running_stats=True).

    Args:
        num_features: Number of channels C.
        eps: Value added to the variance for numerical stability.
        momentum: Weight of the newest batch in the running averages; ``None``
            switches to a cumulative (equal-weight) average.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(RunningBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel scale and shift.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        # Running statistics: buffers (saved in state_dict, not trained).
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, input):
        """Normalize ``input`` of shape (N, C) or (N, C, *).

        Raises:
            ValueError: If ``input`` has fewer than 2 dimensions.
        """
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

        # Reduce over every dim except the channel dim 1; per-channel vectors are
        # reshaped to (1, C, 1, ..., 1) so they broadcast against the input.
        reduce_dims = [0] + list(range(2, input.dim()))
        stat_shape = (1, -1) + (1,) * (input.dim() - 2)

        if self.training:
            mean = input.mean(dim=reduce_dims)
            var = input.var(dim=reduce_dims, unbiased=False)

            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average over all batches seen so far.
                    factor = 1.0 / float(self.num_batches_tracked)
                else:
                    factor = self.momentum

                # Normalization uses the biased variance, but the running estimate
                # stores the unbiased one (n / (n - 1)), as nn.BatchNorm does.
                n = math.prod(input.size(d) for d in reduce_dims)
                running_update = var * n / (n - 1) if n > 1 else var
                # In-place updates keep the registered buffer tensors.
                self.running_mean.mul_(1 - factor).add_(mean, alpha=factor)
                self.running_var.mul_(1 - factor).add_(running_update, alpha=factor)
        else:
            mean = self.running_mean
            var = self.running_var

        normalized = (input - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + self.eps)
        return self.weight.view(stat_shape) * normalized + self.bias.view(stat_shape)

# Demo: custom batch norm on a (32, 10) batch.
bn = RunningBatchNorm(10)
x = torch.randn(32, 10)

# Training mode: normalize with the batch statistics (output mean ~0, std ~1)
# and update the running statistics once.
bn.train()
y_train = bn(x)

# Eval mode: normalize with the running statistics. After a single update with
# momentum 0.1 they still mostly reflect their initial values (mean 0, var 1).
bn.eval()
y_eval = bn(x)

print(f"训练模式输出: {y_train.mean():.4f}, {y_train.std():.4f}")
print(f"评估模式输出: {y_eval.mean():.4f}, {y_eval.std():.4f}")

自定义损失函数

1. 基础自定义损失

python
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    loss_i = alpha * (1 - p_i) ** gamma * CE_i, where CE_i is the per-sample
    cross-entropy and p_i = exp(-CE_i) is the predicted probability of the true class.

    Args:
        alpha: Constant weighting factor.
        gamma: Focusing parameter; gamma=0 reduces to alpha * cross-entropy.
        reduction: 'mean' or 'sum'; any other value returns per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        # Down-weight well-classified samples (prob_true close to 1).
        weighted = self.alpha * (1 - prob_true) ** self.gamma
        losses = weighted * per_sample_ce

        if self.reduction == 'sum':
            return losses.sum()
        if self.reduction == 'mean':
            return losses.mean()
        return losses

class DiceLoss(nn.Module):
    """Soft Dice loss: 1 - Dice(sigmoid(inputs), targets).

    The Dice coefficient is computed over all elements of the batch at once
    (inputs and targets are flattened together), not per sample.

    Args:
        smooth: Added to numerator and denominator; avoids 0/0 when both the
            prediction and the target are empty.
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Raw logits of any shape.
            targets: Mask with the same number of elements as ``inputs``
                (expected to hold 0/1 values).

        Returns:
            Scalar tensor ``1 - dice``.
        """
        probs = torch.sigmoid(inputs)

        # reshape (not view) so non-contiguous tensors, e.g. permuted or
        # transposed masks, can be flattened too.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice

class ContrastiveLoss(nn.Module):
    """Contrastive loss for similarity learning.

    With Euclidean distance d between the paired embeddings:
        label == 0 (similar pair):    d ** 2
        label == 1 (dissimilar pair): max(margin - d, 0) ** 2
    and the result is the mean over the batch.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = F.pairwise_distance(output1, output2)
        pull_term = (1 - label) * torch.pow(dist, 2)
        hinge = torch.clamp(self.margin - dist, min=0.0)
        push_term = label * torch.pow(hinge, 2)
        return torch.mean(pull_term + push_term)

# Demo: instantiate the three custom losses.
# NOTE(review): contrastive_loss is constructed but never evaluated below.
focal_loss = FocalLoss(alpha=1, gamma=2)
dice_loss = DiceLoss()
contrastive_loss = ContrastiveLoss(margin=1.0)

# Focal Loss: logits for 10 samples over 5 classes, integer class targets.
logits = torch.randn(10, 5)
targets = torch.randint(0, 5, (10,))
focal_loss_value = focal_loss(logits, targets)
print(f"Focal Loss: {focal_loss_value.item():.4f}")

# Dice Loss: raw logits and random 0/1 masks of shape (2, 1, 64, 64).
pred_masks = torch.randn(2, 1, 64, 64)
true_masks = torch.randint(0, 2, (2, 1, 64, 64)).float()
dice_loss_value = dice_loss(pred_masks, true_masks)
print(f"Dice Loss: {dice_loss_value.item():.4f}")

自定义优化器

1. 基础自定义优化器

python
class CustomSGD(torch.optim.Optimizer):
    """SGD with optional momentum, dampening and L2 weight decay.

    Follows the (non-Nesterov) update rule of torch.optim.SGD:

        g_t = grad + weight_decay * p
        b_1 = g_1;  b_t = momentum * b_{t-1} + (1 - dampening) * g_t   (t > 1)
        p   = p - lr * b_t          (p - lr * g_t when momentum == 0)

    Args:
        params: Iterable of parameters or param-group dicts.
        lr: Learning rate (>= 0).
        momentum: Momentum factor (>= 0).
        dampening: Dampening applied to the gradient in the momentum update.
        weight_decay: L2 penalty coefficient (>= 0).

    Raises:
        ValueError: If lr, momentum or weight_decay is negative.
    """

    def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay)
        super(CustomSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad
                if weight_decay != 0:
                    # Out-of-place so p.grad itself is left untouched.
                    d_p = d_p.add(p, alpha=weight_decay)

                if momentum != 0:
                    state = self.state[p]
                    buf = state.get('momentum_buffer')
                    if buf is None:
                        # First step: the buffer starts as the gradient itself
                        # (no dampening), matching torch.optim.SGD.
                        buf = torch.clone(d_p).detach()
                        state['momentum_buffer'] = buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    d_p = buf

                p.add_(d_p, alpha=-lr)

        return loss

class AdamW(torch.optim.Optimizer):
    """Adam with decoupled weight decay (AdamW), custom implementation.

    For each parameter p with gradient g at step t:
        m_t = beta1 * m_{t-1} + (1 - beta1) * g
        v_t = beta2 * v_{t-1} + (1 - beta2) * g * g
        p   = p * (1 - lr * weight_decay)                 # decoupled decay
        p   = p - lr / (1 - beta1**t) * m_t / (sqrt(v_t) / sqrt(1 - beta2**t) + eps)

    Args:
        params: Iterable of parameters or param-group dicts.
        lr: Learning rate (>= 0).
        betas: Coefficients of the first/second moment averages, each in [0, 1).
        eps: Term added to the denominator for numerical stability (>= 0).
        weight_decay: Decoupled weight-decay coefficient (>= 0).

    Raises:
        ValueError: If any hyper-parameter is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        # Same check as CustomSGD / torch.optim.AdamW: negative decay grows weights.
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[p]
                # Lazy state initialization on this parameter's first step.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Exponential moving averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction for the zero-initialized averages.
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = lr / bias_correction1
                bias_correction2_sqrt = math.sqrt(bias_correction2)

                # Decoupled weight decay: shrink weights directly, not via the gradient.
                p.mul_(1 - lr * group['weight_decay'])

                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(group['eps'])
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

# Demo: build both custom optimizers over the parameters of nn.Linear(10, 1) and
# print them (torch.optim.Optimizer's repr lists the param-group settings).
# NOTE(review): both optimizers are attached to the same parameters here, which is
# only meant for display; use one optimizer per parameter set when training.
model = nn.Linear(10, 1)
custom_sgd = CustomSGD(model.parameters(), lr=0.01, momentum=0.9)
custom_adamw = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

print(f"自定义SGD: {custom_sgd}")
print(f"自定义AdamW: {custom_adamw}")

自定义数据类型和操作

1. 自定义张量操作

python
class ComplexTensor:
    """Complex tensor stored as separate real and imaginary parts.

    Supports ``+`` and ``*`` with another ComplexTensor or a real value on either
    side, plus modulus and conjugate. PyTorch also has native complex dtypes
    (e.g. ``torch.complex64``); this class illustrates operator overloading.
    """

    def __init__(self, real, imag):
        self.real = real
        self.imag = imag

    def __add__(self, other):
        if isinstance(other, ComplexTensor):
            return ComplexTensor(self.real + other.real, self.imag + other.imag)
        else:
            # A real value only shifts the real part.
            return ComplexTensor(self.real + other, self.imag)

    def __radd__(self, other):
        # Supports `real + complex` (e.g. 1.0 + c); addition with a real commutes.
        return self.__add__(other)

    def __mul__(self, other):
        if isinstance(other, ComplexTensor):
            # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
            real = self.real * other.real - self.imag * other.imag
            imag = self.real * other.imag + self.imag * other.real
            return ComplexTensor(real, imag)
        else:
            return ComplexTensor(self.real * other, self.imag * other)

    def __rmul__(self, other):
        # Supports `real * complex` (e.g. 2 * c); scaling by a real commutes.
        return self.__mul__(other)

    def abs(self):
        """Return the modulus sqrt(real^2 + imag^2)."""
        real, imag = self.real, self.imag
        if (torch.is_tensor(real) and torch.is_tensor(imag)
                and real.is_floating_point() and imag.is_floating_point()):
            # hypot avoids the overflow of squaring large values (e.g. 1e30 in float32).
            return torch.hypot(real, imag)
        # Integer tensors: torch.hypot only supports floating dtypes.
        return torch.sqrt(real ** 2 + imag ** 2)

    def conjugate(self):
        """Return the complex conjugate (imaginary part negated)."""
        return ComplexTensor(self.real, -self.imag)

    def __repr__(self):
        return f"ComplexTensor(real={self.real}, imag={self.imag})"

# Demo: two complex vectors, 1+3i, 2+4i and 5+7i, 6+8i.
real1 = torch.tensor([1.0, 2.0])
imag1 = torch.tensor([3.0, 4.0])
complex1 = ComplexTensor(real1, imag1)

real2 = torch.tensor([5.0, 6.0])
imag2 = torch.tensor([7.0, 8.0])
complex2 = ComplexTensor(real2, imag2)

# Complex arithmetic: sum, product, modulus and conjugate.
result_add = complex1 + complex2
result_mul = complex1 * complex2
result_abs = complex1.abs()
result_conj = complex1.conjugate()

print(f"复数1: {complex1}")
print(f"复数2: {complex2}")
print(f"相加: {result_add}")
print(f"相乘: {result_mul}")
print(f"模: {result_abs}")
print(f"共轭: {result_conj}")

2. 自定义激活函数

python
class Swish(nn.Module):
    """Swish activation f(x) = x * sigmoid(beta * x) with a learnable ``beta``.

    Args:
        beta: Initial value of the learnable slope; ints are accepted.
    """

    def __init__(self, beta=1.0):
        super(Swish, self).__init__()
        # Cast to float: torch.tensor(1) would be int64, and nn.Parameter rejects
        # integer tensors because they cannot require gradients.
        self.beta = nn.Parameter(torch.tensor(float(beta)))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

class Mish(nn.Module):
    """Mish activation: f(x) = x * tanh(softplus(x))."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate

class GELU(nn.Module):
    """GELU activation, tanh approximation:

    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """

    def forward(self, x):
        cubic_term = 0.044715 * torch.pow(x, 3)
        tanh_arg = math.sqrt(2 / math.pi) * (x + cubic_term)
        return 0.5 * x * (1 + torch.tanh(tanh_arg))

class PReLU(nn.Module):
    """Parametric ReLU: max(0, x) + a * min(0, x) with learnable slope(s) ``a``.

    Args:
        num_parameters: Number of slopes; 1 (shared) or the number of channels
            (input dimension 1, as used by F.prelu).
        init: Initial slope value; ints are accepted.
    """

    def __init__(self, num_parameters=1, init=0.25):
        super(PReLU, self).__init__()
        self.num_parameters = num_parameters
        # Cast to float: torch.full with an int fill value yields an int64 tensor,
        # which nn.Parameter rejects.
        self.weight = nn.Parameter(torch.full((num_parameters,), float(init)))

    def forward(self, x):
        return F.prelu(x, self.weight)

# Demo: apply each custom activation to a (10, 5) batch and print the first row.
x = torch.randn(10, 5)

swish = Swish(beta=1.0)
mish = Mish()
gelu = GELU()
# One learnable slope per column, since x has 5 features in dimension 1.
prelu = PReLU(num_parameters=5)

print(f"输入: {x[0]}")
print(f"Swish: {swish(x)[0]}")
print(f"Mish: {mish(x)[0]}")
print(f"GELU: {gelu(x)[0]}")
print(f"PReLU: {prelu(x)[0]}")

C++扩展

1. 基础C++扩展

cpp
// custom_ops.cpp
#include <torch/extension.h>
#include <vector>

// Forward pass of the custom add: element-wise sum of the two input tensors.
torch::Tensor add_forward(torch::Tensor input1, torch::Tensor input2) {
    return torch::add(input1, input2);
}

// Backward pass of the custom add: d(a + b)/da = d(a + b)/db = 1, so both
// operands receive the upstream gradient unchanged.
// NOTE(review): this assumes both inputs had the output's shape; if add_forward
// broadcast them, each gradient would need summing back to its input's shape.
std::vector<torch::Tensor> add_backward(torch::Tensor grad_output) {
    std::vector<torch::Tensor> grads(2, grad_output);
    return grads;
}

// Python bindings: expose add_forward / add_backward from the extension module.
// The module name comes from the TORCH_EXTENSION_NAME macro, which is set by
// PyTorch's extension build tooling.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_forward", &add_forward, "Add forward");
    m.def("add_backward", &add_backward, "Add backward");
}
python
# setup.py for C++ extension
# Build with `python setup.py build_ext --inplace` (or `pip install .`).
# CppExtension/BuildExtension from torch.utils.cpp_extension are all that is
# needed; the previously imported pybind11.setup_helpers names were unused and
# made this script fail whenever pybind11 was not installed.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

ext_modules = [
    CppExtension(
        "custom_ops",
        ["custom_ops.cpp"],
    ),
]

setup(
    name="custom_ops",
    ext_modules=ext_modules,
    # BuildExtension applies the compiler settings PyTorch extensions require.
    cmdclass={"build_ext": BuildExtension},
)

2. CUDA扩展示例

python
# JIT-compiled CUDA extension built from inline source strings.
# load_inline compiles the strings directly (load() would need a cuda_ops.cu file
# on disk), and functions=[...] generates the pybind11 bindings automatically.
from torch.utils.cpp_extension import load_inline

cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

__global__ void add_kernel(const float* a, const float* b, float* c, int64_t n) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}

torch::Tensor cuda_add(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.is_cuda() && b.is_cuda(), "cuda_add: inputs must be CUDA tensors");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 && b.scalar_type() == torch::kFloat32,
                "cuda_add: inputs must be float32");
    TORCH_CHECK(a.sizes() == b.sizes(), "cuda_add: inputs must have the same shape");

    // The kernel indexes memory linearly, so it needs contiguous buffers.
    auto a_c = a.contiguous();
    auto b_c = b.contiguous();
    auto c = torch::empty_like(a_c);

    const int64_t n = a_c.numel();
    if (n == 0) {
        return c;  // a launch with 0 blocks is an invalid configuration
    }
    const int threads = 256;
    const int blocks = static_cast<int>((n + threads - 1) / threads);

    add_kernel<<<blocks, threads>>>(
        a_c.data_ptr<float>(),
        b_c.data_ptr<float>(),
        c.data_ptr<float>(),
        n
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return c;
}
"""

# Declaration for the C++ side; load_inline prepends <torch/extension.h> and
# generates the bindings for the names listed in `functions`.
cpp_source = "torch::Tensor cuda_add(torch::Tensor a, torch::Tensor b);"

if torch.cuda.is_available():
    # Compile only when a GPU is present: building requires the CUDA toolkit.
    cuda_ops = load_inline(
        name="cuda_ops",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["cuda_add"],
        verbose=True,
    )

    a = torch.randn(1000, device='cuda')
    b = torch.randn(1000, device='cuda')
    c = cuda_ops.cuda_add(a, b)
    print(f"CUDA加法结果: {c[:5]}")

性能优化技巧

1. 内存优化

python
class MemoryEfficientFunction(Function):
    """Matrix product y = input @ weight.T that saves only what backward needs.

    grad_input depends only on ``weight`` and grad_weight depends only on ``input``.
    Each tensor is therefore saved for backward only when the gradient that uses
    it will actually be computed. (The previous version returned all-zero
    placeholder gradients, which silently broke training.)
    """

    @staticmethod
    def forward(ctx, input, weight):
        needs_grad_input, needs_grad_weight = ctx.needs_input_grad[:2]
        # Save only the tensors the required gradients depend on; None keeps the
        # positions stable for unpacking in backward.
        ctx.save_for_backward(
            input if needs_grad_weight else None,
            weight if needs_grad_input else None,
        )
        return torch.mm(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        if ctx.needs_input_grad[0]:
            # d(input @ W^T) / d input  ->  grad_output @ W
            grad_input = grad_output.mm(weight)

        if ctx.needs_input_grad[1]:
            # d(input @ W^T) / d W  ->  grad_output^T @ input
            grad_weight = grad_output.t().mm(input)

        return grad_input, grad_weight

class CheckpointFunction(Function):
    """Activation checkpointing: run ``run_function`` without storing activations.

    Forward runs ``run_function(*args)`` under ``torch.no_grad()`` and saves only
    the inputs. Backward re-runs ``run_function`` with gradients enabled and
    backpropagates through the recomputed graph. When ``preserve_rng_state`` is
    true, the forward RNG state is replayed so random ops (e.g. dropout) recompute
    identically. Only tensor positional arguments are supported, because they are
    stored with ``ctx.save_for_backward``.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        # These helpers live in torch.utils.checkpoint; they were previously used
        # without being imported, which raised NameError.
        from torch.utils.checkpoint import check_backward_validity, get_device_states

        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        if preserve_rng_state:
            # Snapshot the RNG state so backward can replay the same random draws.
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        ctx.save_for_backward(*args)

        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        from torch.utils.checkpoint import detach_variable, set_device_states

        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad()")

        inputs = ctx.saved_tensors

        restore_cuda = ctx.preserve_rng_state and ctx.had_cuda_in_fwd
        rng_devices = ctx.fwd_gpu_devices if restore_cuda else []
        # fork_rng restores the caller's RNG state on exit, so replaying the forward
        # state here does not disturb random numbers drawn after backward.
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if restore_cuda:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # Detach so the recomputation builds a fresh local graph whose leaves
            # collect gradients in .grad, instead of backpropagating into the outer
            # graph (or accumulating twice into leaf inputs).
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Only outputs that participate in autograd can be backpropagated through.
        outputs_with_grad = []
        grads_for_outputs = []
        for out, grad in zip(outputs, grad_outputs):
            if torch.is_tensor(out) and out.requires_grad:
                outputs_with_grad.append(out)
                grads_for_outputs.append(grad)
        if not outputs_with_grad:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, grads_for_outputs)

        # No gradients for run_function and preserve_rng_state.
        return (None, None) + tuple(inp.grad for inp in detached_inputs)

def checkpoint(function, *args, **kwargs):
    """Run ``function(*args)`` with activation checkpointing (CheckpointFunction).

    Args:
        function: Callable whose intermediate activations should not be stored.
        *args: Tensor arguments forwarded to ``function``.
        preserve_rng_state (keyword, default True): Replay the forward RNG state
            during recomputation.

    Raises:
        ValueError: If any other keyword argument is passed. Previously such
            arguments (e.g. a misspelled option) were silently ignored.
    """
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
    return CheckpointFunction.apply(function, preserve, *args)

2. 数值稳定性

python
class NumericallyStableFunction(Function):
    """log-sum-exp over the last dimension with a hand-written backward.

    Computes log(sum(exp(x))) along dim -1 (the result keeps that dim with size 1).
    The per-row maximum is subtracted first, so large inputs do not overflow exp().
    """

    @staticmethod
    def forward(ctx, input):
        max_val = input.max(dim=-1, keepdim=True)[0]
        # A row whose max is +/-inf would give inf - inf = NaN after shifting.
        # Shift such rows by 0 instead (torch.logsumexp does the same); this yields
        # the correct -inf for an all -inf row and +inf when any entry is +inf.
        max_val = torch.where(torch.isinf(max_val), torch.zeros_like(max_val), max_val)
        shifted_input = input - max_val
        exp_shifted = torch.exp(shifted_input)
        sum_exp = exp_shifted.sum(dim=-1, keepdim=True)
        log_sum_exp = torch.log(sum_exp) + max_val

        ctx.save_for_backward(exp_shifted, sum_exp)
        return log_sum_exp

    @staticmethod
    def backward(ctx, grad_output):
        exp_shifted, sum_exp = ctx.saved_tensors
        # d logsumexp / dx = softmax(x); grad_output (..., 1) broadcasts over dim -1.
        softmax = exp_shifted / sum_exp
        grad_input = grad_output * softmax
        return grad_input

def stable_log_sum_exp(input):
    """Return log(sum(exp(input))) over the last dim (kept with size 1)."""
    lse = NumericallyStableFunction.apply
    return lse(input)

# Demo: numerical stability of the custom log-sum-exp.
x = torch.tensor([1000.0, 1001.0, 1002.0])  # large values: a naive exp() would overflow float32
stable_result = stable_log_sum_exp(x)
print(f"数值稳定的结果: {stable_result}")

总结

自定义操作是PyTorch的强大特性，本章介绍了：

  1. 自定义Function：实现前向和反向传播的自定义函数
  2. 自定义Module：创建可重用的神经网络组件
  3. 自定义损失函数：实现特定任务的损失函数
  4. 自定义优化器：实现新的优化算法
  5. 自定义数据类型：用Python类封装张量并重载运算符（注意：PyTorch已原生支持torch.complex64等复数类型）
  6. C++/CUDA扩展：高性能的底层实现
  7. 性能优化：内存效率和数值稳定性技巧

掌握这些技术将帮助你扩展PyTorch的功能，实现创新的深度学习算法！

本站内容仅供学习和研究使用。