PyTorch Transformer Models

Introduction to the Transformer

The Transformer is a neural network architecture built on the self-attention mechanism, proposed by Vaswani et al. in 2017. It fundamentally changed natural language processing and became the backbone of large language models such as BERT and GPT.

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

# Core Transformer components provided by torch.nn
multihead_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
transformer_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
transformer = nn.TransformerEncoder(transformer_layer, num_layers=6)
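
These built-in modules can be exercised directly on a random batch; the shapes below are purely illustrative:

python
# Toy input: a batch of 2 sequences, 10 tokens each, with 512-dimensional embeddings
x = torch.randn(2, 10, 512)

# Self-attention: query, key and value are the same tensor
attn_output, attn_weights = multihead_attn(x, x, x)
print(attn_output.shape)  # torch.Size([2, 10, 512])

# The stacked encoder preserves the input shape
encoded = transformer(x)
print(encoded.shape)      # torch.Size([2, 10, 512])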

Positional Encoding

1. Sinusoidal Positional Encoding

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # Create the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Compute the frequency divisor term
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices and cosine to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add a batch dimension and register as a non-trainable buffer
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x shape: (seq_len, batch_size, d_model)
        return x + self.pe[:x.size(0), :]
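
A minimal usage sketch; note that this module expects sequence-first input of shape (seq_len, batch_size, d_model), and the sizes here are illustrative:

python
pos_enc = PositionalEncoding(d_model=512)
x = torch.zeros(20, 2, 512)   # (seq_len, batch_size, d_model)
out = pos_enc(x)
print(out.shape)              # torch.Size([20, 2, 512])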

2. Learnable Positional Encoding

python
class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(LearnablePositionalEncoding, self).__init__()
        self.pe = nn.Parameter(torch.randn(max_len, d_model))
    
    def forward(self, x):
        # x shape: (seq_len, batch_size, d_model)
        seq_len = x.size(0)
        return x + self.pe[:seq_len, :].unsqueeze(1)
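
Unlike the fixed sinusoidal table, these position vectors are ordinary parameters updated by the optimizer. A quick illustrative shape check, assuming the same sequence-first layout:

python
learned_pe = LearnablePositionalEncoding(d_model=512)
x = torch.zeros(20, 2, 512)   # (seq_len, batch_size, d_model)
out = learned_pe(x)
print(out.shape)                    # torch.Size([20, 2, 512])
print(learned_pe.pe.requires_grad)  # True: the encoding is trained with the model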

Multi-Head Attention

1. Self-Attention Implementation

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projection layers for Q, K, V and the output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V shape: (batch_size, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context = torch.matmul(attention_weights, V)
        return context, attention_weights
    
    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, d_model = query.size()
        
        # Linear projections, then reshape to (batch_size, num_heads, seq_len, d_k);
        # -1 lets key/value keep their own sequence length (useful for cross-attention)
        Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Apply scaled dot-product attention
        context, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate the heads back into (batch_size, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )
        
        # Output projection
        output = self.w_o(context)
        
        return output, attention_weights
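
A short self-contained check of the custom attention module, including an illustrative causal mask (1 = attend, 0 = block) that broadcasts to (batch_size, num_heads, seq_len, seq_len):

python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch_size, seq_len, d_model)

# Lower-triangular causal mask, shaped (1, 1, seq_len, seq_len) for broadcasting
causal_mask = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(0)

out, weights = mha(x, x, x, mask=causal_mask)
print(out.shape)      # torch.Size([2, 10, 512])
print(weights.shape)  # torch.Size([2, 8, 10, 10])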

Transformer Encoder

1. Encoder Layer

python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        
        # Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention + residual connection + layer norm
        attn_output, attention_weights = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward network + residual connection + layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x, attention_weights
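
Because the layer preserves the input shape, layers can be stacked freely; a quick illustrative check:

python
layer = TransformerEncoderLayer(d_model=512, num_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)   # (batch_size, seq_len, d_model)
out, attn = layer(x)
print(out.shape)   # torch.Size([2, 10, 512])
print(attn.shape)  # torch.Size([2, 8, 10, 10])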

2. Complete Encoder

python
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        
        # Token embedding and positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Stack of encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
    
    def forward(self, x, mask=None):
        # Token embedding (scaled by sqrt(d_model)) + positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x.transpose(0, 1)  # (seq_len, batch_size, d_model)
        x = self.pos_encoding(x)
        x = x.transpose(0, 1)  # (batch_size, seq_len, d_model)
        x = self.dropout(x)
        
        # Pass through the encoder layers, collecting per-layer attention weights
        attention_weights = []
        for layer in self.layers:
            x, attn_weights = layer(x, mask)
            attention_weights.append(attn_weights)
        
        return x, attention_weights
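
Putting the pieces together on a toy batch of token ids; all sizes here are illustrative:

python
encoder = TransformerEncoder(
    vocab_size=10000, d_model=512, num_heads=8,
    num_layers=2, d_ff=2048, max_len=5000
)
tokens = torch.randint(0, 10000, (2, 10))  # (batch_size, seq_len)
output, attn_weights = encoder(tokens)
print(output.shape)       # torch.Size([2, 10, 512])
print(len(attn_weights))  # 2: one attention map per layer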

Practical Application Example

1. A Transformer for Text Classification

python
class TextClassificationTransformer(nn.Module):
    def __init__(self, vocab_size, num_classes, d_model=512, num_heads=8, num_layers=6):
        super(TextClassificationTransformer, self).__init__()
        
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff=2048, max_len=5000
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, num_classes)
        )
    
    def forward(self, x, mask=None):
        # Encode the token sequence
        encoder_output, attention_weights = self.encoder(x, mask)
        
        # Global average pooling (mask-aware when a padding mask is provided)
        if mask is not None:
            mask = mask.squeeze(1).squeeze(1)  # (batch_size, seq_len)
            masked_output = encoder_output * mask.unsqueeze(-1).float()
            pooled = masked_output.sum(dim=1) / mask.sum(dim=1, keepdim=True).float()
        else:
            pooled = encoder_output.mean(dim=1)
        
        # Classify
        logits = self.classifier(pooled)
        
        return logits, attention_weights
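
A minimal sketch of a forward pass and one training step with cross-entropy loss; the vocabulary size, batch, and hyperparameters here are placeholders:

python
model = TextClassificationTransformer(vocab_size=10000, num_classes=2, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

tokens = torch.randint(0, 10000, (4, 32))  # (batch_size, seq_len)
labels = torch.randint(0, 2, (4,))

logits, _ = model(tokens)
loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())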

Summary

The Transformer is a major breakthrough in modern deep learning. This chapter covered:

  1. Core components: positional encoding, multi-head attention, and the encoder architecture
  2. A complete implementation: from basic building blocks to a full Transformer encoder
  3. Practical application: a concrete task such as text classification

Mastering the Transformer will give you a solid foundation for understanding and working with modern large language models.
