PyTorch Text Classification Project

Project Overview

This chapter walks through a complete text classification project to show how to handle a natural language processing task with PyTorch. We will build a sentiment analysis system that classifies the sentiment of a piece of text as positive, negative, or neutral.

Project Structure

text_classification/
├── data/                   # Data directory
│   ├── raw/               # Raw data
│   ├── processed/         # Processed data
│   └── vocab/             # Vocabulary
├── models/                # Model definitions
│   ├── __init__.py
│   ├── lstm_classifier.py
│   ├── transformer_classifier.py
│   └── cnn_classifier.py
├── utils/                 # Utility functions
│   ├── __init__.py
│   ├── data_loader.py
│   ├── text_processor.py
│   └── metrics.py
├── configs/               # Configuration files (a sketch follows below)
├── train.py              # Training script
├── evaluate.py           # Evaluation script
└── inference.py          # Inference script
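
The configuration files under configs/ are not shown in this chapter. As a rough sketch (the file name and keys below are hypothetical, not a fixed schema), a configuration could be as simple as a Python dict that train.py imports:

python

# configs/lstm_base.py (hypothetical file) -- a minimal configuration sketch.
# Keys and values are illustrative defaults, not a fixed schema.
CONFIG = {
    'model': 'lstm',          # one of: 'lstm', 'cnn', 'transformer'
    'embed_dim': 128,
    'hidden_dim': 256,
    'num_classes': 3,         # negative / positive / neutral
    'max_length': 128,
    'batch_size': 32,
    'lr': 1e-3,
    'weight_decay': 0.01,
    'num_epochs': 20,
}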

Data Preprocessing

1. Text Preprocessor

python
import re
import string
import torch
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import jieba  # Chinese word segmentation
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

class TextPreprocessor:
    def __init__(self, language='en', max_vocab_size=50000, min_freq=2):
        self.language = language
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq
        
        # Special tokens
        self.PAD_TOKEN = '<PAD>'
        self.UNK_TOKEN = '<UNK>'
        self.SOS_TOKEN = '<SOS>'
        self.EOS_TOKEN = '<EOS>'
        
        # Vocabulary
        self.vocab = {}
        self.idx2word = {}
        self.word_freq = Counter()
        
        # Stop words (requires the NLTK stopwords corpus to be downloaded)
        if language == 'en':
            try:
                self.stop_words = set(stopwords.words('english'))
            except LookupError:
                self.stop_words = set()
        else:
            self.stop_words = set()
    
    def clean_text(self, text: str) -> str:
        """清理文本"""
        # 转换为小写
        text = text.lower()
        
        # 移除HTML标签
        text = re.sub(r'<[^>]+>', '', text)
        
        # 移除URL
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        
        # 移除邮箱
        text = re.sub(r'\S+@\S+', '', text)
        
        # 移除数字(可选)
        # text = re.sub(r'\d+', '', text)
        
        # 移除多余的空白字符
        text = re.sub(r'\s+', ' ', text).strip()
        
        return text
    
    def tokenize(self, text: str) -> List[str]:
        """分词"""
        if self.language == 'zh':
            # 中文分词
            tokens = list(jieba.cut(text))
        else:
            # 英文分词
            tokens = word_tokenize(text)
        
        # 移除标点符号和停用词
        tokens = [
            token for token in tokens 
            if token not in string.punctuation and token not in self.stop_words
        ]
        
        return tokens
    
    def build_vocab(self, texts: List[str]):
        """构建词汇表"""
        print("构建词汇表...")
        
        # 统计词频
        for text in texts:
            cleaned_text = self.clean_text(text)
            tokens = self.tokenize(cleaned_text)
            self.word_freq.update(tokens)
        
        # 创建词汇表
        vocab_list = [self.PAD_TOKEN, self.UNK_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN]
        
        # 按频率排序,取前max_vocab_size个词
        sorted_words = sorted(self.word_freq.items(), key=lambda x: x[1], reverse=True)
        for word, freq in sorted_words:
            if freq >= self.min_freq and len(vocab_list) < self.max_vocab_size:
                vocab_list.append(word)
        
        # 构建词汇表映射
        self.vocab = {word: idx for idx, word in enumerate(vocab_list)}
        self.idx2word = {idx: word for word, idx in self.vocab.items()}
        
        print(f"词汇表大小: {len(self.vocab)}")
        print(f"总词频: {sum(self.word_freq.values())}")
    
    def text_to_sequence(self, text: str, max_length: int = None) -> List[int]:
        """将文本转换为序列"""
        cleaned_text = self.clean_text(text)
        tokens = self.tokenize(cleaned_text)
        
        # 转换为索引
        sequence = [
            self.vocab.get(token, self.vocab[self.UNK_TOKEN]) 
            for token in tokens
        ]
        
        # 截断或填充
        if max_length:
            if len(sequence) > max_length:
                sequence = sequence[:max_length]
            else:
                sequence.extend([self.vocab[self.PAD_TOKEN]] * (max_length - len(sequence)))
        
        return sequence
    
    def sequence_to_text(self, sequence: List[int]) -> str:
        """将序列转换为文本"""
        tokens = [
            self.idx2word.get(idx, self.UNK_TOKEN) 
            for idx in sequence
            if idx != self.vocab[self.PAD_TOKEN]
        ]
        return ' '.join(tokens)
    
    def save_vocab(self, filepath: str):
        """保存词汇表"""
        import pickle
        vocab_data = {
            'vocab': self.vocab,
            'idx2word': self.idx2word,
            'word_freq': self.word_freq,
            'config': {
                'language': self.language,
                'max_vocab_size': self.max_vocab_size,
                'min_freq': self.min_freq
            }
        }
        with open(filepath, 'wb') as f:
            pickle.dump(vocab_data, f)
    
    def load_vocab(self, filepath: str):
        """加载词汇表"""
        import pickle
        with open(filepath, 'rb') as f:
            vocab_data = pickle.load(f)
        
        self.vocab = vocab_data['vocab']
        self.idx2word = vocab_data['idx2word']
        self.word_freq = vocab_data['word_freq']
        config = vocab_data['config']
        self.language = config['language']
        self.max_vocab_size = config['max_vocab_size']
        self.min_freq = config['min_freq']

# Usage example
preprocessor = TextPreprocessor(language='en')

# Sample texts
texts = [
    "I love this movie! It's absolutely fantastic.",
    "This film is terrible. I hate it.",
    "The movie is okay, not great but not bad either."
]

# Build the vocabulary
preprocessor.build_vocab(texts)

# Text to sequence
sequence = preprocessor.text_to_sequence(texts[0], max_length=20)
print(f"Original: {texts[0]}")
print(f"Sequence: {sequence}")
print(f"Restored: {preprocessor.sequence_to_text(sequence)}")

2. Dataset Class

python
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, preprocessor, max_length=128):
        self.texts = texts
        self.labels = labels
        self.preprocessor = preprocessor
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Convert the text to a fixed-length sequence of token ids
        sequence = self.preprocessor.text_to_sequence(text, self.max_length)
        
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long),
            'text': text
        }

def create_data_loaders(train_texts, train_labels, val_texts, val_labels, 
                       preprocessor, batch_size=32, max_length=128):
    """创建数据加载器"""
    
    train_dataset = TextClassificationDataset(
        train_texts, train_labels, preprocessor, max_length
    )
    val_dataset = TextClassificationDataset(
        val_texts, val_labels, preprocessor, max_length
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, 
        num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )
    
    return train_loader, val_loader

# Example data loading
def load_imdb_data():
    """Load a toy IMDB-style dataset (replace with real data in practice)."""
    # Toy data for demonstration; a real project would read the actual dataset here
    train_texts = [
        "I love this movie! It's absolutely fantastic.",
        "This film is terrible. I hate it.",
        "The movie is okay, not great but not bad either.",
        "Amazing cinematography and great acting!",
        "Boring and predictable plot."
    ] * 1000  # repeat to enlarge the toy dataset
    
    train_labels = [1, 0, 2, 1, 0] * 1000  # 0: negative, 1: positive, 2: neutral
    
    val_texts = train_texts[:500]
    val_labels = train_labels[:500]
    
    return train_texts, train_labels, val_texts, val_labels
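
Note that the toy loader above reuses the first training samples as validation data, which leaks information. For a real dataset, a held-out split avoids this; here is a minimal sketch using scikit-learn (already a dependency for the metrics later on):

python

from sklearn.model_selection import train_test_split

def split_data(texts, labels, val_ratio=0.1, seed=42):
    """Hold out a stratified validation set that never overlaps with the training data."""
    return train_test_split(
        texts, labels,
        test_size=val_ratio,
        random_state=seed,
        stratify=labels,   # keep the class distribution the same in both splits
    )

# train_texts, val_texts, train_labels, val_labels = split_data(all_texts, all_labels)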

Model Architectures

1. LSTM Classifier

python
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, 
                 num_layers=2, dropout=0.3, bidirectional=True):
        super(LSTMClassifier, self).__init__()
        
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # LSTM layers
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        
        # Attention layer
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.attention = nn.Linear(lstm_output_dim, 1)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )
    
    def forward(self, input_ids, attention_mask=None):
        # Embed tokens
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        
        # LSTM
        lstm_out, (hidden, cell) = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim*2)
        
        # Attention pooling
        if attention_mask is not None:
            # Compute attention scores and mask out padding positions
            attention_weights = self.attention(lstm_out).squeeze(-1)  # (batch_size, seq_len)
            attention_weights = attention_weights.masked_fill(attention_mask == 0, -1e9)
            attention_weights = F.softmax(attention_weights, dim=1)
            
            # Weighted average over time steps
            context = torch.sum(attention_weights.unsqueeze(-1) * lstm_out, dim=1)
        else:
            # Simple mean pooling
            context = torch.mean(lstm_out, dim=1)
        
        # Classify
        logits = self.classifier(context)
        
        return logits

# Create the model
def create_lstm_model(vocab_size, num_classes):
    model = LSTMClassifier(
        vocab_size=vocab_size,
        embed_dim=128,
        hidden_dim=256,
        num_classes=num_classes,
        num_layers=2,
        dropout=0.3,
        bidirectional=True
    )
    return model
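
Before training, a quick shape check on a dummy batch helps catch dimension errors early. A minimal sketch (batch size and sequence length are arbitrary):

python

import torch

# Sanity check: forward a dummy batch and verify the output shape.
lstm_model = create_lstm_model(vocab_size=5000, num_classes=3)
dummy_ids = torch.randint(1, 5000, (4, 128))   # (batch_size=4, seq_len=128), no padding ids
dummy_mask = (dummy_ids != 0)
logits = lstm_model(dummy_ids, attention_mask=dummy_mask)
print(logits.shape)  # expected: torch.Size([4, 3])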

2. CNN Classifier

python
class CNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, 
                 num_classes, dropout=0.3):
        super(CNNClassifier, self).__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # Parallel convolutions with different kernel sizes
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len(filter_sizes) * num_filters, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, input_ids):
        # Embed tokens
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        embedded = embedded.transpose(1, 2)   # (batch_size, embed_dim, seq_len)
        
        # Convolution followed by max-over-time pooling
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # (batch_size, num_filters, conv_seq_len)
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # (batch_size, num_filters, 1)
            conv_outputs.append(pooled.squeeze(2))
        
        # Concatenate the pooled features of all kernel sizes
        concat_output = torch.cat(conv_outputs, dim=1)  # (batch_size, len(filter_sizes) * num_filters)
        
        # Classify
        logits = self.classifier(concat_output)
        
        return logits

def create_cnn_model(vocab_size, num_classes):
    model = CNNClassifier(
        vocab_size=vocab_size,
        embed_dim=128,
        num_filters=100,
        filter_sizes=[3, 4, 5],
        num_classes=num_classes,
        dropout=0.3
    )
    return model
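
The same kind of smoke test works for the CNN; note that every input must be at least max(filter_sizes) = 5 tokens long, which padding to max_length already guarantees:

python

import torch

# Sanity check: the pooled features concatenate to len(filter_sizes) * num_filters = 300.
# Inputs must be at least max(filter_sizes) = 5 tokens long; padding to max_length ensures this.
cnn_model = create_cnn_model(vocab_size=5000, num_classes=3)
dummy_ids = torch.randint(1, 5000, (4, 128))   # (batch_size=4, seq_len=128)
print(cnn_model(dummy_ids).shape)  # expected: torch.Size([4, 3])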

3. Transformer Classifier

python
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, 
                 num_classes, max_length=512, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        
        self.embed_dim = embed_dim
        self.max_length = max_length
        
        # Token embedding and learned positional encoding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoding = nn.Parameter(torch.randn(max_length, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )
    
    def forward(self, input_ids, attention_mask=None):
        seq_len = input_ids.size(1)
        
        # Token embedding + positional encoding
        embedded = self.embedding(input_ids)
        embedded = embedded + self.pos_encoding[:seq_len, :].unsqueeze(0)
        
        # Build the padding mask if none was provided (True = real token)
        if attention_mask is None:
            attention_mask = (input_ids != 0)
        
        # Transformer encoding (src_key_padding_mask expects True at padding positions)
        transformer_out = self.transformer(
            embedded, 
            src_key_padding_mask=~attention_mask
        )
        
        # Masked mean pooling over the non-padding positions
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_embeddings = torch.sum(transformer_out * mask_expanded, dim=1)
        sum_mask = torch.sum(mask_expanded, dim=1)
        pooled = sum_embeddings / sum_mask
        
        # Classify
        logits = self.classifier(pooled)
        
        return logits

def create_transformer_model(vocab_size, num_classes):
    model = TransformerClassifier(
        vocab_size=vocab_size,
        embed_dim=256,
        num_heads=8,
        num_layers=6,
        num_classes=num_classes,
        max_length=512,
        dropout=0.1
    )
    return model
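
The three architectures differ considerably in size. A small sketch that compares their parameter counts under the default hyperparameters (the exact numbers depend on the vocabulary size):

python

# Compare model sizes under the default hyperparameters (counts depend on vocab_size).
vocab_size, num_classes = 20000, 3
for name, factory in [('LSTM', create_lstm_model),
                      ('CNN', create_cnn_model),
                      ('Transformer', create_transformer_model)]:
    m = factory(vocab_size, num_classes)
    n_params = sum(p.numel() for p in m.parameters())
    print(f'{name}: {n_params:,} parameters')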

Training Framework

1. Trainer

python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np

class TextClassificationTrainer:
    def __init__(self, model, train_loader, val_loader, num_classes, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_classes = num_classes
        self.device = device
        
        # Loss function, optimizer, and LR scheduler
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', patience=3, factor=0.5)
        
        # Training history
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        
        # Best model tracking
        self.best_val_acc = 0.0
    
    def train_epoch(self):
        """训练一个epoch"""
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        
        for batch in self.train_loader:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['label'].to(self.device)
            
            # 创建注意力掩码
            attention_mask = (input_ids != 0)
            
            self.optimizer.zero_grad()
            
            # 前向传播
            if isinstance(self.model, TransformerClassifier):
                logits = self.model(input_ids, attention_mask)
            else:
                logits = self.model(input_ids)
            
            loss = self.criterion(logits, labels)
            
            # 反向传播
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            # 统计
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
        
        avg_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        
        return avg_loss, accuracy
    
    def validate_epoch(self):
        """验证一个epoch"""
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch['input_ids'].to(self.device)
                labels = batch['label'].to(self.device)
                
                attention_mask = (input_ids != 0)
                
                if isinstance(self.model, TransformerClassifier):
                    logits = self.model(input_ids, attention_mask)
                else:
                    logits = self.model(input_ids)
                
                loss = self.criterion(logits, labels)
                
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        
        return avg_loss, accuracy, all_preds, all_labels
    
    def train(self, num_epochs):
        """完整训练流程"""
        print(f"开始训练,共{num_epochs}个epoch")
        
        for epoch in range(num_epochs):
            # 训练
            train_loss, train_acc = self.train_epoch()
            
            # 验证
            val_loss, val_acc, val_preds, val_labels = self.validate_epoch()
            
            # 更新学习率
            self.scheduler.step(val_acc)
            
            # 记录历史
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)
            
            # 打印结果
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            
            # 保存最佳模型
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_model.pth')
                print(f'  ✓ 新的最佳模型! 验证准确率: {val_acc:.4f}')
            
            print('-' * 50)
        
        print(f'训练完成! 最佳验证准确率: {self.best_val_acc:.4f}')
        
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs
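
The trainer records per-epoch losses and accuracies, which are convenient to visualize. A minimal plotting sketch, assuming matplotlib is installed:

python

import matplotlib.pyplot as plt

def plot_history(train_losses, val_losses, train_accs, val_accs, out_path='training_curves.png'):
    """Plot the per-epoch metrics recorded by TextClassificationTrainer."""
    epochs = range(1, len(train_losses) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    
    ax1.plot(epochs, train_losses, label='train')
    ax1.plot(epochs, val_losses, label='val')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss')
    ax1.legend()
    
    ax2.plot(epochs, train_accs, label='train')
    ax2.plot(epochs, val_accs, label='val')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('accuracy')
    ax2.legend()
    
    fig.tight_layout()
    fig.savefig(out_path)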

Model Evaluation

1. Detailed Evaluation

python
def evaluate_model(model, test_loader, device, class_names):
    """详细评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            
            attention_mask = (input_ids != 0)
            
            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)
            
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, target_names=class_names)
    cm = confusion_matrix(all_labels, all_preds)
    
    print(f"Test accuracy: {accuracy:.4f}")
    print("\nClassification report:")
    print(report)
    
    return accuracy, report, cm, all_probs
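
The confusion matrix returned by evaluate_model is easier to read as a heatmap. A minimal sketch, again assuming matplotlib is installed:

python

import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, class_names, out_path='confusion_matrix.png'):
    """Render the confusion matrix returned by evaluate_model as a heatmap."""
    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(cm, cmap='Blues')
    ax.set_xticks(range(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    
    # Annotate each cell with its count
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha='center', va='center')
    
    fig.colorbar(im)
    fig.tight_layout()
    fig.savefig(out_path)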

2. Error Analysis

python
def analyze_errors(model, test_loader, preprocessor, device, class_names, num_examples=10):
    """分析错误预测"""
    model.eval()
    errors = []
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            texts = batch['text']
            
            attention_mask = (input_ids != 0)
            
            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)
            
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            
            # Collect misclassified samples
            wrong_mask = preds != labels
            if wrong_mask.any():
                wrong_indices = torch.where(wrong_mask)[0]
                
                for idx in wrong_indices:
                    errors.append({
                        'text': texts[idx],
                        'true_label': class_names[labels[idx].item()],
                        'pred_label': class_names[preds[idx].item()],
                        'confidence': probs[idx].max().item(),
                        'all_probs': probs[idx].cpu().numpy()
                    })
                    
                    if len(errors) >= num_examples:
                        break
            
            if len(errors) >= num_examples:
                break
    
    # Print the error analysis
    print("Misclassified examples:")
    print("=" * 80)
    
    for i, error in enumerate(errors):
        print(f"\nExample {i+1}:")
        print(f"Text: {error['text'][:200]}...")
        print(f"True label: {error['true_label']}")
        print(f"Predicted label: {error['pred_label']}")
        print(f"Confidence: {error['confidence']:.4f}")
        
        # Show the probability for every class
        for j, prob in enumerate(error['all_probs']):
            print(f"  {class_names[j]}: {prob:.4f}")
    
    return errors

Inference and Application

1. Single-Text Inference

python
def predict_single_text(model, text, preprocessor, device, class_names, max_length=128):
    """预测单个文本"""
    model.eval()
    
    # 预处理
    sequence = preprocessor.text_to_sequence(text, max_length)
    input_ids = torch.tensor([sequence], dtype=torch.long).to(device)
    attention_mask = (input_ids != 0)
    
    with torch.no_grad():
        if isinstance(model, TransformerClassifier):
            logits = model(input_ids, attention_mask)
        else:
            logits = model(input_ids)
        
        probs = F.softmax(logits, dim=1)
        pred_class = torch.argmax(logits, dim=1).item()
        confidence = probs[0][pred_class].item()
    
    # Probabilities for every class
    all_probs = probs[0].cpu().numpy()
    
    result = {
        'text': text,
        'predicted_class': class_names[pred_class],
        'confidence': confidence,
        'all_probabilities': {
            class_names[i]: float(prob) for i, prob in enumerate(all_probs)
        }
    }
    
    return result

# Usage example
text = "I absolutely love this movie! It's fantastic!"
result = predict_single_text(model, text, preprocessor, device, ['negative', 'positive', 'neutral'])
print(f"Prediction: {result}")

2. Batch Inference

python
def batch_predict(model, texts, preprocessor, device, class_names, batch_size=32, max_length=128):
    """批量预测"""
    model.eval()
    results = []
    
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        
        # Preprocess the batch
        sequences = [preprocessor.text_to_sequence(text, max_length) for text in batch_texts]
        input_ids = torch.tensor(sequences, dtype=torch.long).to(device)
        attention_mask = (input_ids != 0)
        
        with torch.no_grad():
            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)
            
            probs = F.softmax(logits, dim=1)
            pred_classes = torch.argmax(logits, dim=1)
        
        # Collect results
        for j, text in enumerate(batch_texts):
            pred_class = pred_classes[j].item()
            confidence = probs[j][pred_class].item()
            
            results.append({
                'text': text,
                'predicted_class': class_names[pred_class],
                'confidence': confidence
            })
    
    return results
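
A possible way to use batch_predict is to score a list of texts and persist the results as a report. A small sketch using pandas (imported earlier for the dataset section); the texts and file name here are purely illustrative:

python

import pandas as pd

# Hypothetical usage: score a few texts and save the predictions as a CSV report.
texts_to_score = [
    "Great soundtrack and a touching story.",
    "I want my two hours back.",
]
results = batch_predict(model, texts_to_score, preprocessor, device,
                        class_names=['negative', 'positive', 'neutral'])
pd.DataFrame(results).to_csv('predictions.csv', index=False)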

Complete Training Script

python
def main():
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load data
    train_texts, train_labels, val_texts, val_labels = load_imdb_data()
    
    # Create the preprocessor and build the vocabulary
    preprocessor = TextPreprocessor(language='en')
    preprocessor.build_vocab(train_texts)
    
    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        train_texts, train_labels, val_texts, val_labels,
        preprocessor, batch_size=32, max_length=128
    )
    
    # Create the model
    vocab_size = len(preprocessor.vocab)
    num_classes = 3  # negative, positive, neutral
    
    # Pick a model architecture
    model = create_lstm_model(vocab_size, num_classes)
    # model = create_cnn_model(vocab_size, num_classes)
    # model = create_transformer_model(vocab_size, num_classes)
    
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Create the trainer
    trainer = TextClassificationTrainer(model, train_loader, val_loader, num_classes, device)
    
    # Train
    train_losses, train_accs, val_losses, val_accs = trainer.train(num_epochs=20)
    
    # Load the best checkpoint
    model.load_state_dict(torch.load('best_model.pth'))
    
    # Evaluate
    class_names = ['negative', 'positive', 'neutral']
    accuracy, report, cm, probs = evaluate_model(model, val_loader, device, class_names)
    
    # Error analysis
    errors = analyze_errors(model, val_loader, preprocessor, device, class_names)
    
    # Save the final model and vocabulary
    torch.save(model.state_dict(), 'final_model.pth')
    preprocessor.save_vocab('vocab.pkl')
    
    print("Training complete!")

if __name__ == '__main__':
    main()
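
The project layout lists an inference.py script. A minimal sketch of what it might contain: reload the saved vocabulary and weights, then call predict_single_text. The paths match those written by main(); everything else is an assumption:

python

def load_for_inference(vocab_path='vocab.pkl', weights_path='best_model.pth'):
    """Rebuild the preprocessor and model from the artifacts saved by main()."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    preprocessor = TextPreprocessor(language='en')
    preprocessor.load_vocab(vocab_path)
    
    model = create_lstm_model(len(preprocessor.vocab), num_classes=3)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    
    return model, preprocessor, device

# model, preprocessor, device = load_for_inference()
# print(predict_single_text(model, "What a wonderful film!", preprocessor, device,
#                           ['negative', 'positive', 'neutral']))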

Summary

Through a complete text classification project, this chapter demonstrated:

  1. Text preprocessing: the full pipeline of cleaning, tokenization, and vocabulary building
  2. Model architectures: LSTM, CNN, and Transformer classifiers for text
  3. Training framework: a complete train/validate/checkpoint workflow
  4. Model evaluation: accuracy, classification reports, and error analysis
  5. Practical application: single-text and batch inference

This project template can be adapted to a wide range of text classification tasks, such as sentiment analysis, topic classification, and spam detection.
