# PyTorch Text Classification Project

## Project Overview

This chapter walks through a complete text classification project that shows how to handle a natural language processing task with PyTorch. We will build a sentiment analysis system that classifies the sentiment of a piece of text as positive, negative, or neutral.

## Project Structure
```text
text_classification/
├── data/                          # Data directory
│   ├── raw/                       # Raw data
│   ├── processed/                 # Processed data
│   └── vocab/                     # Vocabulary files
├── models/                        # Model definitions
│   ├── __init__.py
│   ├── lstm_classifier.py
│   ├── transformer_classifier.py
│   └── cnn_classifier.py
├── utils/                         # Utility functions
│   ├── __init__.py
│   ├── data_loader.py
│   ├── text_processor.py
│   └── metrics.py
├── configs/                       # Configuration files
├── train.py                       # Training script
├── evaluate.py                    # Evaluation script
└── inference.py                   # Inference script
```
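The `configs/` directory holds experiment configuration. The project does not prescribe a format, so the sketch below is only an illustration of a Python-based config; the file name `configs/lstm_config.py` and every field are assumptions, with values mirroring those used later in this chapter.

```python
# configs/lstm_config.py -- hypothetical example; adjust names and values to your setup
CONFIG = {
    "model": "lstm",           # one of: lstm, cnn, transformer
    "max_vocab_size": 50_000,
    "min_freq": 2,
    "max_length": 128,
    "embed_dim": 128,
    "hidden_dim": 256,
    "num_classes": 3,          # negative / positive / neutral
    "batch_size": 32,
    "learning_rate": 1e-3,
    "num_epochs": 20,
}
```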
## Data Preprocessing

### 1. Text Preprocessor

```python
import re
import string
import pickle
import torch
from collections import Counter
from typing import List, Optional

import jieba  # Chinese word segmentation
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize


class TextPreprocessor:
    def __init__(self, language='en', max_vocab_size=50000, min_freq=2):
        self.language = language
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

        # Special tokens
        self.PAD_TOKEN = '<PAD>'
        self.UNK_TOKEN = '<UNK>'
        self.SOS_TOKEN = '<SOS>'
        self.EOS_TOKEN = '<EOS>'

        # Vocabulary
        self.vocab = {}
        self.idx2word = {}
        self.word_freq = Counter()

        # Stop words
        if language == 'en':
            try:
                self.stop_words = set(stopwords.words('english'))
            except LookupError:
                # NLTK stopword list has not been downloaded yet
                self.stop_words = set()
        else:
            self.stop_words = set()

    def clean_text(self, text: str) -> str:
        """Clean raw text."""
        # Lowercase
        text = text.lower()
        # Remove HTML tags
        text = re.sub(r'<[^>]+>', '', text)
        # Remove URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        # Remove e-mail addresses
        text = re.sub(r'\S+@\S+', '', text)
        # Remove digits (optional)
        # text = re.sub(r'\d+', '', text)
        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text."""
        if self.language == 'zh':
            # Chinese word segmentation
            tokens = list(jieba.cut(text))
        else:
            # English tokenization
            tokens = word_tokenize(text)
        # Remove punctuation and stop words
        tokens = [
            token for token in tokens
            if token not in string.punctuation and token not in self.stop_words
        ]
        return tokens

    def build_vocab(self, texts: List[str]):
        """Build the vocabulary."""
        print("Building vocabulary...")
        # Count word frequencies
        for text in texts:
            cleaned_text = self.clean_text(text)
            tokens = self.tokenize(cleaned_text)
            self.word_freq.update(tokens)
        # Start the vocabulary with the special tokens
        vocab_list = [self.PAD_TOKEN, self.UNK_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN]
        # Sort by frequency and keep at most max_vocab_size words
        sorted_words = sorted(self.word_freq.items(), key=lambda x: x[1], reverse=True)
        for word, freq in sorted_words:
            if freq >= self.min_freq and len(vocab_list) < self.max_vocab_size:
                vocab_list.append(word)
        # Build the lookup tables
        self.vocab = {word: idx for idx, word in enumerate(vocab_list)}
        self.idx2word = {idx: word for word, idx in self.vocab.items()}
        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Total token count: {sum(self.word_freq.values())}")

    def text_to_sequence(self, text: str, max_length: Optional[int] = None) -> List[int]:
        """Convert text to a sequence of token indices."""
        cleaned_text = self.clean_text(text)
        tokens = self.tokenize(cleaned_text)
        # Map tokens to indices, falling back to <UNK>
        sequence = [
            self.vocab.get(token, self.vocab[self.UNK_TOKEN])
            for token in tokens
        ]
        # Truncate or pad
        if max_length:
            if len(sequence) > max_length:
                sequence = sequence[:max_length]
            else:
                sequence.extend([self.vocab[self.PAD_TOKEN]] * (max_length - len(sequence)))
        return sequence

    def sequence_to_text(self, sequence: List[int]) -> str:
        """Convert a sequence of indices back to text."""
        tokens = [
            self.idx2word.get(idx, self.UNK_TOKEN)
            for idx in sequence
            if idx != self.vocab[self.PAD_TOKEN]
        ]
        return ' '.join(tokens)

    def save_vocab(self, filepath: str):
        """Save the vocabulary to disk."""
        vocab_data = {
            'vocab': self.vocab,
            'idx2word': self.idx2word,
            'word_freq': self.word_freq,
            'config': {
                'language': self.language,
                'max_vocab_size': self.max_vocab_size,
                'min_freq': self.min_freq
            }
        }
        with open(filepath, 'wb') as f:
            pickle.dump(vocab_data, f)

    def load_vocab(self, filepath: str):
        """Load a saved vocabulary."""
        with open(filepath, 'rb') as f:
            vocab_data = pickle.load(f)
        self.vocab = vocab_data['vocab']
        self.idx2word = vocab_data['idx2word']
        self.word_freq = vocab_data['word_freq']
        config = vocab_data['config']
        self.language = config['language']
        self.max_vocab_size = config['max_vocab_size']
        self.min_freq = config['min_freq']
# Usage example
preprocessor = TextPreprocessor(language='en')

# Sample texts
texts = [
    "I love this movie! It's absolutely fantastic.",
    "This film is terrible. I hate it.",
    "The movie is okay, not great but not bad either."
]

# Build the vocabulary
preprocessor.build_vocab(texts)

# Text to sequence
sequence = preprocessor.text_to_sequence(texts[0], max_length=20)
print(f"Original: {texts[0]}")
print(f"Sequence: {sequence}")
print(f"Reconstructed: {preprocessor.sequence_to_text(sequence)}")
```
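Note that `word_tokenize` and the English stopword list depend on NLTK data packages that are not installed with the library itself. If they are missing, a one-time download fixes the `LookupError`; `punkt` and `stopwords` are the standard NLTK resource names:

```python
import nltk

# One-time downloads required by word_tokenize and stopwords.words('english')
nltk.download('punkt')       # newer NLTK versions may additionally need 'punkt_tab'
nltk.download('stopwords')
```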
### 2. Dataset Class

```python
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, preprocessor, max_length=128):
        self.texts = texts
        self.labels = labels
        self.preprocessor = preprocessor
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Convert text to an index sequence
        sequence = self.preprocessor.text_to_sequence(text, self.max_length)
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long),
            'text': text
        }


def create_data_loaders(train_texts, train_labels, val_texts, val_labels,
                        preprocessor, batch_size=32, max_length=128):
    """Create training and validation data loaders."""
    train_dataset = TextClassificationDataset(
        train_texts, train_labels, preprocessor, max_length
    )
    val_dataset = TextClassificationDataset(
        val_texts, val_labels, preprocessor, max_length
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )
    return train_loader, val_loader


# Example data loading
def load_imdb_data():
    """Load an IMDB-style toy dataset."""
    # Placeholder data; a real project would load the actual dataset here
    train_texts = [
        "I love this movie! It's absolutely fantastic.",
        "This film is terrible. I hate it.",
        "The movie is okay, not great but not bad either.",
        "Amazing cinematography and great acting!",
        "Boring and predictable plot."
    ] * 1000  # replicate to enlarge the toy dataset
    train_labels = [1, 0, 2, 1, 0] * 1000  # 0: negative, 1: positive, 2: neutral
    val_texts = train_texts[:500]
    val_labels = train_labels[:500]
    return train_texts, train_labels, val_texts, val_labels
```
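The dataset above pads every example to a fixed `max_length`. An alternative, shown as a hedged sketch below rather than part of the original project code, is to keep variable-length sequences and pad each batch only to its longest member with a custom `collate_fn`:

```python
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch, pad_idx=0):
    """Pad each batch only to the length of its longest sequence."""
    input_ids = [item['input_ids'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])
    texts = [item['text'] for item in batch]
    padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_idx)
    return {'input_ids': padded, 'label': labels, 'text': texts}

# Usage sketch (assumes the dataset returns unpadded sequences, i.e. max_length=None):
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
#                           collate_fn=collate_batch)
```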
## Model Architectures

### 1. LSTM Classifier

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes,
                 num_layers=2, dropout=0.3, bidirectional=True):
        super(LSTMClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # LSTM layer
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # Attention mechanism
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.attention = nn.Linear(lstm_output_dim, 1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask=None):
        # Embed tokens
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)

        # LSTM
        lstm_out, (hidden, cell) = self.lstm(embedded)  # (batch_size, seq_len, lstm_output_dim)

        # Attention pooling
        if attention_mask is not None:
            attention_weights = self.attention(lstm_out).squeeze(-1)  # (batch_size, seq_len)
            # Mask out padding positions before the softmax
            attention_weights = attention_weights.masked_fill(attention_mask == 0, -1e9)
            attention_weights = F.softmax(attention_weights, dim=1)
            # Weighted sum over time steps
            context = torch.sum(attention_weights.unsqueeze(-1) * lstm_out, dim=1)
        else:
            # Simple mean pooling
            context = torch.mean(lstm_out, dim=1)

        # Classify
        logits = self.classifier(context)
        return logits


# Model factory
def create_lstm_model(vocab_size, num_classes):
    model = LSTMClassifier(
        vocab_size=vocab_size,
        embed_dim=128,
        hidden_dim=256,
        num_classes=num_classes,
        num_layers=2,
        dropout=0.3,
        bidirectional=True
    )
    return model
```
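As a quick sanity check (not part of the project scripts), you can push a random batch through the model and confirm that the output shape is `(batch_size, num_classes)`. The dummy vocabulary size and sequence length below are arbitrary:

```python
# Hypothetical smoke test with a dummy vocabulary of 1000 tokens and 3 classes
model = create_lstm_model(vocab_size=1000, num_classes=3)
dummy_input = torch.randint(0, 1000, (4, 32))   # (batch_size=4, seq_len=32)
attention_mask = (dummy_input != 0)
with torch.no_grad():
    logits = model(dummy_input, attention_mask)
print(logits.shape)  # expected: torch.Size([4, 3])
```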
### 2. CNN Classifier

```python
class CNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes,
                 num_classes, dropout=0.3):
        super(CNNClassifier, self).__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # One convolution per filter size
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len(filter_sizes) * num_filters, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, input_ids):
        # Embed tokens
        embedded = self.embedding(input_ids)   # (batch_size, seq_len, embed_dim)
        embedded = embedded.transpose(1, 2)    # (batch_size, embed_dim, seq_len)

        # Convolution + max-over-time pooling
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))                  # (batch_size, num_filters, conv_seq_len)
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # (batch_size, num_filters, 1)
            conv_outputs.append(pooled.squeeze(2))

        # Concatenate the pooled features from all filter sizes
        concat_output = torch.cat(conv_outputs, dim=1)  # (batch_size, len(filter_sizes) * num_filters)

        # Classify
        logits = self.classifier(concat_output)
        return logits


def create_cnn_model(vocab_size, num_classes):
    model = CNNClassifier(
        vocab_size=vocab_size,
        embed_dim=128,
        num_filters=100,
        filter_sizes=[3, 4, 5],
        num_classes=num_classes,
        dropout=0.3
    )
    return model
```

### 3. Transformer Classifier
```python
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers,
                 num_classes, max_length=512, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.max_length = max_length

        # Token embedding and learned positional encoding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoding = nn.Parameter(torch.randn(max_length, embed_dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def forward(self, input_ids, attention_mask=None):
        seq_len = input_ids.size(1)

        # Token embedding + positional encoding
        embedded = self.embedding(input_ids)
        embedded = embedded + self.pos_encoding[:seq_len, :].unsqueeze(0)

        # Build the padding mask if none was given (0 is the PAD index)
        if attention_mask is None:
            attention_mask = (input_ids != 0)

        # Transformer encoding; True in src_key_padding_mask marks padding positions
        transformer_out = self.transformer(
            embedded,
            src_key_padding_mask=~attention_mask
        )

        # Mean pooling over non-padding positions
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_embeddings = torch.sum(transformer_out * mask_expanded, dim=1)
        sum_mask = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)  # guard against all-padding rows
        pooled = sum_embeddings / sum_mask

        # Classify
        logits = self.classifier(pooled)
        return logits


def create_transformer_model(vocab_size, num_classes):
    model = TransformerClassifier(
        vocab_size=vocab_size,
        embed_dim=256,
        num_heads=8,
        num_layers=6,
        num_classes=num_classes,
        max_length=512,
        dropout=0.1
    )
    return model
```

## Training Framework
### 1. Trainer

```python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np


class TextClassificationTrainer:
    def __init__(self, model, train_loader, val_loader, num_classes, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_classes = num_classes
        self.device = device

        # Loss, optimizer and scheduler
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', patience=3, factor=0.5)

        # Training history
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []

        # Best model tracking
        self.best_val_acc = 0.0

    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []

        for batch in self.train_loader:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['label'].to(self.device)
            # Attention mask: 0 is the PAD index
            attention_mask = (input_ids != 0)

            self.optimizer.zero_grad()

            # Forward pass
            if isinstance(self.model, TransformerClassifier):
                logits = self.model(input_ids, attention_mask)
            else:
                logits = self.model(input_ids)
            loss = self.criterion(logits, labels)

            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Statistics
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        avg_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        return avg_loss, accuracy

    def validate_epoch(self):
        """Validate for one epoch."""
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch['input_ids'].to(self.device)
                labels = batch['label'].to(self.device)
                attention_mask = (input_ids != 0)

                if isinstance(self.model, TransformerClassifier):
                    logits = self.model(input_ids, attention_mask)
                else:
                    logits = self.model(input_ids)
                loss = self.criterion(logits, labels)

                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        return avg_loss, accuracy, all_preds, all_labels

    def train(self, num_epochs):
        """Full training loop."""
        print(f"Starting training for {num_epochs} epochs")

        for epoch in range(num_epochs):
            # Train
            train_loss, train_acc = self.train_epoch()
            # Validate
            val_loss, val_acc, val_preds, val_labels = self.validate_epoch()
            # Update the learning rate
            self.scheduler.step(val_acc)

            # Record history
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)

            # Log results
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Save the best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_model.pth')
                print(f'  ✓ New best model! Validation accuracy: {val_acc:.4f}')
            print('-' * 50)

        print(f'Training finished! Best validation accuracy: {self.best_val_acc:.4f}')
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs
```
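Because the trainer records loss and accuracy per epoch, the training run is easy to visualise. The matplotlib sketch below is illustrative and not part of the project files:

```python
import matplotlib.pyplot as plt

def plot_history(train_losses, val_losses, train_accs, val_accs):
    """Plot the loss and accuracy curves collected by the trainer."""
    epochs = range(1, len(train_losses) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(epochs, train_losses, label='train')
    ax1.plot(epochs, val_losses, label='val')
    ax1.set_xlabel('epoch'); ax1.set_ylabel('loss'); ax1.legend()
    ax2.plot(epochs, train_accs, label='train')
    ax2.plot(epochs, val_accs, label='val')
    ax2.set_xlabel('epoch'); ax2.set_ylabel('accuracy'); ax2.legend()
    fig.tight_layout()
    plt.show()

# Usage: plot_history(*trainer.train(num_epochs=20))
```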
## Model Evaluation

### 1. Detailed Evaluation

```python
def evaluate_model(model, test_loader, device, class_names):
    """Evaluate the model in detail."""
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            attention_mask = (input_ids != 0)

            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)

            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, target_names=class_names)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"Test accuracy: {accuracy:.4f}")
    print("\nClassification report:")
    print(report)

    return accuracy, report, cm, all_probs
```
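The confusion matrix returned above is easier to read as a heatmap. Below is a small sketch using seaborn (an extra dependency not used elsewhere in this project; plain matplotlib also works):

```python
import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(cm, class_names):
    """Render the confusion matrix from evaluate_model as an annotated heatmap."""
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.tight_layout()
    plt.show()
```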
### 2. Error Analysis

```python
def analyze_errors(model, test_loader, preprocessor, device, class_names, num_examples=10):
    """Inspect wrongly predicted examples."""
    model.eval()
    errors = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            texts = batch['text']
            attention_mask = (input_ids != 0)

            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)

            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)

            # Collect wrong predictions
            wrong_mask = preds != labels
            if wrong_mask.any():
                wrong_indices = torch.where(wrong_mask)[0]
                for idx in wrong_indices.tolist():
                    errors.append({
                        'text': texts[idx],
                        'true_label': class_names[labels[idx].item()],
                        'pred_label': class_names[preds[idx].item()],
                        'confidence': probs[idx].max().item(),
                        'all_probs': probs[idx].cpu().numpy()
                    })
                    if len(errors) >= num_examples:
                        break
            if len(errors) >= num_examples:
                break

    # Print the error analysis
    print("Error analysis:")
    print("=" * 80)
    for i, error in enumerate(errors):
        print(f"\nExample {i+1}:")
        print(f"Text: {error['text'][:200]}...")
        print(f"True label: {error['true_label']}")
        print(f"Predicted label: {error['pred_label']}")
        print(f"Confidence: {error['confidence']:.4f}")
        # Probabilities for every class
        for j, prob in enumerate(error['all_probs']):
            print(f"  {class_names[j]}: {prob:.4f}")

    return errors
```

## Inference and Application
### 1. Single-Text Inference

```python
def predict_single_text(model, text, preprocessor, device, class_names, max_length=128):
    """Predict the class of a single text."""
    model.eval()

    # Preprocess
    sequence = preprocessor.text_to_sequence(text, max_length)
    input_ids = torch.tensor([sequence], dtype=torch.long).to(device)
    attention_mask = (input_ids != 0)

    with torch.no_grad():
        if isinstance(model, TransformerClassifier):
            logits = model(input_ids, attention_mask)
        else:
            logits = model(input_ids)

        probs = F.softmax(logits, dim=1)
        pred_class = torch.argmax(logits, dim=1).item()
        confidence = probs[0][pred_class].item()

    # Probabilities for every class
    all_probs = probs[0].cpu().numpy()

    result = {
        'text': text,
        'predicted_class': class_names[pred_class],
        'confidence': confidence,
        'all_probabilities': {
            class_names[i]: float(prob) for i, prob in enumerate(all_probs)
        }
    }
    return result


# Usage example
text = "I absolutely love this movie! It's fantastic!"
result = predict_single_text(model, text, preprocessor, device, ['negative', 'positive', 'neutral'])
print(f"Prediction: {result}")
```

### 2. Batch Inference
```python
def batch_predict(model, texts, preprocessor, device, class_names, batch_size=32, max_length=128):
    """Predict a batch of texts."""
    model.eval()
    results = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]

        # Preprocess the batch (all sequences are padded/truncated to max_length)
        sequences = [preprocessor.text_to_sequence(text, max_length) for text in batch_texts]
        input_ids = torch.tensor(sequences, dtype=torch.long).to(device)
        attention_mask = (input_ids != 0)

        with torch.no_grad():
            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)

            probs = F.softmax(logits, dim=1)
            pred_classes = torch.argmax(logits, dim=1)

        # Collect the results
        for j, text in enumerate(batch_texts):
            pred_class = pred_classes[j].item()
            confidence = probs[j][pred_class].item()
            results.append({
                'text': text,
                'predicted_class': class_names[pred_class],
                'confidence': confidence
            })

    return results
```
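In a standalone inference script (e.g. the `inference.py` from the project layout), the saved weights and vocabulary need to be loaded back before calling `predict_single_text` or `batch_predict`. A minimal sketch, assuming the weights were saved with `torch.save(model.state_dict(), ...)` and the vocabulary with `preprocessor.save_vocab(...)` as done in this project; the helper name and file paths are illustrative:

```python
def load_for_inference(model_path, vocab_path, device):
    """Rebuild the preprocessor and model from saved artifacts (illustrative sketch)."""
    preprocessor = TextPreprocessor(language='en')
    preprocessor.load_vocab(vocab_path)  # restores vocab and preprocessing config
    # The architecture and hyperparameters must match the trained checkpoint
    model = create_lstm_model(len(preprocessor.vocab), num_classes=3)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model, preprocessor

# Example usage (paths are assumptions):
# model, preprocessor = load_for_inference('best_model.pth', 'vocab.pkl', device)
# print(predict_single_text(model, "Great movie!", preprocessor, device,
#                           ['negative', 'positive', 'neutral']))
```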
## Complete Training Script

```python
def main():
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the data
    train_texts, train_labels, val_texts, val_labels = load_imdb_data()

    # Create the preprocessor
    preprocessor = TextPreprocessor(language='en')
    preprocessor.build_vocab(train_texts)

    # Create the data loaders
    train_loader, val_loader = create_data_loaders(
        train_texts, train_labels, val_texts, val_labels,
        preprocessor, batch_size=32, max_length=128
    )

    # Create the model
    vocab_size = len(preprocessor.vocab)
    num_classes = 3  # negative, positive, neutral

    # Pick one of the model types
    model = create_lstm_model(vocab_size, num_classes)
    # model = create_cnn_model(vocab_size, num_classes)
    # model = create_transformer_model(vocab_size, num_classes)

    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create the trainer
    trainer = TextClassificationTrainer(model, train_loader, val_loader, num_classes, device)

    # Train the model
    train_losses, train_accs, val_losses, val_accs = trainer.train(num_epochs=20)

    # Load the best checkpoint
    model.load_state_dict(torch.load('best_model.pth', map_location=device))

    # Evaluate the model
    class_names = ['negative', 'positive', 'neutral']
    accuracy, report, cm, probs = evaluate_model(model, val_loader, device, class_names)

    # Error analysis
    errors = analyze_errors(model, val_loader, preprocessor, device, class_names)

    # Save the model and the preprocessor
    torch.save(model.state_dict(), 'final_model.pth')
    preprocessor.save_vocab('vocab.pkl')

    print("Training complete!")


if __name__ == '__main__':
    main()
```

## Summary
This chapter walked through a complete text classification project and covered:

- Text preprocessing: cleaning, tokenization, and vocabulary construction
- Model architectures: LSTM, CNN, and Transformer classifiers
- Training framework: a full training, validation, and checkpointing loop
- Model evaluation: accuracy, classification reports, and error analysis
- Practical application: single-text and batch inference

This project template can be adapted to many text classification tasks, such as sentiment analysis, topic classification, and spam detection.