
Image Classification Project

This chapter walks through a complete image classification project to show how to build, train, and deploy a practical deep learning model with TensorFlow. Starting from data preparation, we will work step by step through the entire machine learning pipeline.

Project Overview

We will build an image classifier for the ten CIFAR-10 categories (a mix of animals and vehicles), then extend the same approach to a custom dataset.

Project Goals

  • Build a high-accuracy image classification model
  • Learn data preprocessing and augmentation techniques
  • Master model training and tuning methods
  • Implement model evaluation and visualization
  • Deploy the model for practical use
python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os
import cv2
from pathlib import Path

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPUs available: {tf.config.list_physical_devices('GPU')}")

Data Preparation

Loading the CIFAR-10 Dataset

python
def load_cifar10_data():
    """
    Load the CIFAR-10 dataset.
    """
    # Load the data
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    
    # Class names
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']
    
    # Dataset information
    print(f"Training set shape: {x_train.shape}")
    print(f"Test set shape: {x_test.shape}")
    print(f"Number of classes: {len(class_names)}")
    
    return (x_train, y_train), (x_test, y_test), class_names

def preprocess_data(x_train, y_train, x_test, y_test):
    """
    Preprocess the data.
    """
    # Normalize pixel values to [0, 1]
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    
    # Convert labels to one-hot encoding
    num_classes = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    
    return x_train, y_train, x_test, y_test, num_classes

# Load and preprocess the data
(x_train, y_train), (x_test, y_test), class_names = load_cifar10_data()
x_train, y_train, x_test, y_test, num_classes = preprocess_data(
    x_train, y_train, x_test, y_test
)

Data Visualization

python
def visualize_dataset(x_train, y_train, class_names, num_samples=25):
    """
    Visualize sample images from the dataset.
    """
    plt.figure(figsize=(12, 12))
    
    for i in range(num_samples):
        plt.subplot(5, 5, i + 1)
        plt.imshow(x_train[i])
        plt.title(f'{class_names[np.argmax(y_train[i])]}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

def plot_class_distribution(y_train, class_names):
    """
    Plot the class distribution.
    """
    # With one-hot labels, summing over axis 0 gives per-class counts
    class_counts = np.sum(y_train, axis=0)
    
    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, class_counts)
    plt.title('Training Set Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of samples')
    plt.xticks(rotation=45)
    
    # Add count labels above the bars
    for bar, count in zip(bars, class_counts):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 10,
                f'{int(count)}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()

# Visualize the data
visualize_dataset(x_train, y_train, class_names)
plot_class_distribution(y_train, class_names)

Data Augmentation

python
def create_data_augmentation():
    """
    Create a data augmentation pipeline.
    """
    data_augmentation = keras.Sequential([
        keras.layers.RandomFlip("horizontal"),
        keras.layers.RandomRotation(0.1),
        keras.layers.RandomZoom(0.1),
        keras.layers.RandomContrast(0.1),
        keras.layers.RandomBrightness(0.1),
    ])
    
    return data_augmentation

def visualize_augmentation(x_train, data_augmentation):
    """
    Visualize the effect of data augmentation.
    """
    # Pick one sample
    sample_image = x_train[0:1]
    
    plt.figure(figsize=(15, 5))
    
    # Original image
    plt.subplot(1, 6, 1)
    plt.imshow(sample_image[0])
    plt.title('Original')
    plt.axis('off')
    
    # Augmented versions
    for i in range(5):
        augmented_image = data_augmentation(sample_image, training=True)
        plt.subplot(1, 6, i + 2)
        plt.imshow(augmented_image[0])
        plt.title(f'Augmented {i+1}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

# Create the augmentation pipeline
data_augmentation = create_data_augmentation()
visualize_augmentation(x_train, data_augmentation)

Model Building

Basic CNN Model

python
def create_basic_cnn(input_shape, num_classes):
    """
    Create a basic CNN model.
    """
    model = keras.Sequential([
        # Explicit input layer (the augmentation layers carry no fixed shape)
        keras.layers.Input(shape=input_shape),
        
        # Data augmentation (the pipeline defined above)
        data_augmentation,
        
        # First convolutional block. 'same' padding preserves spatial size,
        # so three pooling stages fit 32x32 inputs (32 -> 16 -> 8 -> 4).
        keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),
        
        # Second convolutional block
        keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),
        
        # Third convolutional block
        keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),
        
        # Fully connected head
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation='softmax')
    ])
    
    return model

# Create the basic model
basic_model = create_basic_cnn((32, 32, 3), num_classes)
basic_model.summary()

ResNet-Style Model

python
def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
    """
    A residual block.
    """
    if conv_shortcut:
        shortcut = keras.layers.Conv2D(filters, 1, strides=stride)(x)
        shortcut = keras.layers.BatchNormalization()(shortcut)
    else:
        shortcut = x
    
    x = keras.layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    x = keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    
    x = keras.layers.Add()([shortcut, x])
    x = keras.layers.ReLU()(x)
    
    return x

def create_resnet_model(input_shape, num_classes):
    """
    Create a ResNet-style model.
    """
    inputs = keras.layers.Input(shape=input_shape)
    
    # Data augmentation
    x = data_augmentation(inputs)
    
    # Stem convolution
    x = keras.layers.Conv2D(64, 7, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.MaxPooling2D(3, strides=2, padding='same')(x)
    
    # Residual block groups
    x = residual_block(x, 64, conv_shortcut=True)
    x = residual_block(x, 64)
    
    x = residual_block(x, 128, stride=2, conv_shortcut=True)
    x = residual_block(x, 128)
    
    x = residual_block(x, 256, stride=2, conv_shortcut=True)
    x = residual_block(x, 256)
    
    # Global average pooling
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.5)(x)
    
    # Classification layer
    outputs = keras.layers.Dense(num_classes, activation='softmax')(x)
    
    model = keras.Model(inputs, outputs)
    return model

# Create the ResNet-style model
resnet_model = create_resnet_model((32, 32, 3), num_classes)
resnet_model.summary()

Using a Pretrained Model

python
def create_transfer_learning_model(input_shape, num_classes, base_model_name='EfficientNetB0'):
    """
    Create a transfer learning model.
    """
    # Load the pretrained backbone. The backbone is built for 224x224
    # inputs; the small input images are resized to match below.
    if base_model_name == 'EfficientNetB0':
        base_model = keras.applications.EfficientNetB0(
            weights='imagenet',
            include_top=False,
            input_shape=(224, 224, 3)
        )
        preprocess_input = keras.applications.efficientnet.preprocess_input
    elif base_model_name == 'ResNet50':
        base_model = keras.applications.ResNet50(
            weights='imagenet',
            include_top=False,
            input_shape=(224, 224, 3)
        )
        preprocess_input = keras.applications.resnet50.preprocess_input
    else:
        raise ValueError(f"Unsupported model: {base_model_name}")
    
    # Freeze the pretrained layers
    base_model.trainable = False
    
    # Build the full model
    inputs = keras.layers.Input(shape=input_shape)
    
    # Data augmentation
    x = data_augmentation(inputs)
    
    # Resize to the backbone's expected input size, then rescale back to
    # [0, 255] (our images were normalized to [0, 1]) before applying the
    # backbone-specific preprocessing.
    x = keras.layers.Resizing(224, 224)(x)
    x = keras.layers.Rescaling(255.0)(x)
    x = preprocess_input(x)
    
    # Pretrained backbone (keep its BatchNorm layers in inference mode)
    x = base_model(x, training=False)
    
    # Custom classification head
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(x)
    
    model = keras.Model(inputs, outputs)
    return model, base_model

# Create the transfer learning model
transfer_model, base_model = create_transfer_learning_model((32, 32, 3), num_classes)
transfer_model.summary()
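
After the frozen-backbone head has converged, a common second stage is to unfreeze the top of the backbone and continue training at a much lower learning rate. The sketch below is illustrative: the number of layers left frozen and the learning rate are assumptions to tune, not values from a tuned run.

python
# Fine-tuning sketch: unfreeze the top of the backbone and retrain slowly.
base_model.trainable = True

# Keep everything except the last ~20 layers frozen (illustrative choice).
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile with a much lower learning rate so the pretrained weights
# are only gently adjusted.
transfer_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# transfer_model.fit(x_train, y_train, epochs=5, validation_split=0.2)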

Model Training

Training Configuration

python
def compile_model(model, learning_rate=0.001):
    """
    Compile the model.
    """
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy']
    )
    
    return model

def create_callbacks(model_name):
    """
    Create the training callbacks.
    """
    callbacks = [
        # Model checkpointing
        keras.callbacks.ModelCheckpoint(
            f'best_{model_name}.h5',
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=False,
            verbose=1
        ),
        
        # Early stopping
        keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        
        # Learning rate scheduling
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=5,
            min_lr=1e-7,
            verbose=1
        ),
        
        # TensorBoard
        keras.callbacks.TensorBoard(
            log_dir=f'logs/{model_name}',
            histogram_freq=1,
            write_graph=True,
            write_images=True
        )
    ]
    
    return callbacks

# Compile the model
basic_model = compile_model(basic_model)
callbacks = create_callbacks('basic_cnn')

Training Process

python
def train_model(model, x_train, y_train, callbacks,
                epochs=100, batch_size=32, validation_split=0.2):
    """
    Train the model, holding out a fraction of the training set for validation.
    """
    history = model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
        verbose=1
    )
    
    return history

def plot_training_history(history):
    """
    Plot the training history.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    axes[0, 0].plot(history.history['loss'], label='Training loss')
    axes[0, 0].plot(history.history['val_loss'], label='Validation loss')
    axes[0, 0].set_title('Model Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    
    # Accuracy
    axes[0, 1].plot(history.history['accuracy'], label='Training accuracy')
    axes[0, 1].plot(history.history['val_accuracy'], label='Validation accuracy')
    axes[0, 1].set_title('Model Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    
    # Top-K accuracy
    axes[1, 0].plot(history.history['top_k_categorical_accuracy'], label='Training top-K accuracy')
    axes[1, 0].plot(history.history['val_top_k_categorical_accuracy'], label='Validation top-K accuracy')
    axes[1, 0].set_title('Top-K Accuracy')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy')
    axes[1, 0].legend()
    
    # Learning rate (recorded when ReduceLROnPlateau runs)
    if 'lr' in history.history:
        axes[1, 1].plot(history.history['lr'])
        axes[1, 1].set_title('Learning Rate')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning rate')
        axes[1, 1].set_yscale('log')
    else:
        axes[1, 1].set_visible(False)
    
    plt.tight_layout()
    plt.show()

# Train the model
print("Training the basic CNN model...")
history = train_model(basic_model, x_train, y_train, callbacks)
plot_training_history(history)

Model Evaluation

Performance Evaluation

python
def evaluate_model(model, x_test, y_test, class_names):
    """
    Evaluate model performance comprehensively.
    """
    # Predict
    y_pred = model.predict(x_test)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)
    
    # Overall metrics
    test_loss, test_accuracy, test_top_k = model.evaluate(x_test, y_test, verbose=0)
    print(f"Test loss: {test_loss:.4f}")
    print(f"Test accuracy: {test_accuracy:.4f}")
    print(f"Top-K accuracy: {test_top_k:.4f}")
    
    # Classification report
    print("\nClassification report:")
    print(classification_report(y_true_classes, y_pred_classes, 
                              target_names=class_names))
    
    return y_pred, y_pred_classes, y_true_classes

def plot_confusion_matrix(y_true, y_pred, class_names):
    """
    Plot the confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted class')
    plt.ylabel('True class')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
    
    # Per-class accuracy (diagonal divided by row sums)
    class_accuracy = cm.diagonal() / cm.sum(axis=1)
    
    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, class_accuracy)
    plt.title('Per-Class Accuracy')
    plt.xlabel('Class')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    
    # Add value labels above the bars
    for bar, acc in zip(bars, class_accuracy):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{acc:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()

# Evaluate the model
y_pred, y_pred_classes, y_true_classes = evaluate_model(
    basic_model, x_test, y_test, class_names
)
plot_confusion_matrix(y_true_classes, y_pred_classes, class_names)

Error Analysis

python
def analyze_errors(x_test, y_true, y_pred, y_pred_classes, class_names, num_examples=20):
    """
    Inspect misclassified samples.
    """
    # Find the misclassified samples
    incorrect_indices = np.where(y_true != y_pred_classes)[0]
    
    # Randomly pick a subset of the errors
    if len(incorrect_indices) > num_examples:
        selected_indices = np.random.choice(incorrect_indices, num_examples, replace=False)
    else:
        selected_indices = incorrect_indices
    
    # Visualize the misclassified samples
    cols = 5
    rows = (len(selected_indices) + cols - 1) // cols
    
    plt.figure(figsize=(15, 3 * rows))
    
    for i, idx in enumerate(selected_indices):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(x_test[idx])
        
        true_label = class_names[y_true[idx]]
        pred_label = class_names[y_pred_classes[idx]]
        confidence = np.max(y_pred[idx])
        
        plt.title(f'True: {true_label}\nPredicted: {pred_label}\nConfidence: {confidence:.3f}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

def plot_prediction_confidence(y_pred, y_true, y_pred_classes):
    """
    Analyze the distribution of prediction confidence.
    """
    # Prediction confidence = highest class probability
    confidences = np.max(y_pred, axis=1)
    
    # Confidence for correct vs. incorrect predictions
    correct_mask = (y_true == y_pred_classes)
    correct_confidences = confidences[correct_mask]
    incorrect_confidences = confidences[~correct_mask]
    
    plt.figure(figsize=(12, 5))
    
    # Confidence distribution
    plt.subplot(1, 2, 1)
    plt.hist(correct_confidences, bins=50, alpha=0.7, label='Correct predictions', color='green')
    plt.hist(incorrect_confidences, bins=50, alpha=0.7, label='Incorrect predictions', color='red')
    plt.xlabel('Prediction confidence')
    plt.ylabel('Count')
    plt.title('Prediction Confidence Distribution')
    plt.legend()
    
    # Confidence vs. accuracy (a simple calibration view)
    plt.subplot(1, 2, 2)
    confidence_bins = np.linspace(0, 1, 11)
    bin_accuracies = []
    bin_counts = []
    
    for i in range(len(confidence_bins) - 1):
        mask = (confidences >= confidence_bins[i]) & (confidences < confidence_bins[i + 1])
        if np.sum(mask) > 0:
            accuracy = np.mean(correct_mask[mask])
            bin_accuracies.append(accuracy)
            bin_counts.append(np.sum(mask))
        else:
            bin_accuracies.append(0)
            bin_counts.append(0)
    
    bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2
    plt.bar(bin_centers, bin_accuracies, width=0.08, alpha=0.7)
    plt.xlabel('Confidence bin')
    plt.ylabel('Accuracy')
    plt.title('Confidence vs. Accuracy')
    plt.ylim(0, 1)
    
    plt.tight_layout()
    plt.show()

# Error analysis
analyze_errors(x_test, y_true_classes, y_pred, y_pred_classes, class_names)
plot_prediction_confidence(y_pred, y_true_classes, y_pred_classes)
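
The right-hand plot is essentially a reliability diagram, and it can be summarized in a single number, the expected calibration error (ECE): the bin-weighted gap between average confidence and accuracy. A minimal sketch reusing the same ten bins as above:

python
# Sketch: expected calibration error (ECE) over 10 confidence bins.
confidences = np.max(y_pred, axis=1)
correct = (y_true_classes == y_pred_classes)

bins = np.linspace(0, 1, 11)
ece = 0.0
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (confidences >= lo) & (confidences < hi)
    if mask.any():
        # Weight each bin's |confidence - accuracy| gap by its sample share
        gap = abs(confidences[mask].mean() - correct[mask].mean())
        ece += mask.mean() * gap

print(f"Expected calibration error: {ece:.4f}")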

Model Optimization

Hyperparameter Tuning

python
def hyperparameter_tuning():
    """
    Hyperparameter tuning example.
    """
    import keras_tuner as kt
    
    def build_model(hp):
        model = keras.Sequential()
        model.add(keras.layers.Input(shape=(32, 32, 3)))
        
        # Data augmentation
        model.add(data_augmentation)
        
        # Tune the number of convolutional blocks and their parameters
        for i in range(hp.Int('num_conv_blocks', 2, 4)):
            model.add(keras.layers.Conv2D(
                filters=hp.Int(f'conv_{i}_filters', 32, 256, step=32),
                kernel_size=hp.Choice(f'conv_{i}_kernel', [3, 5]),
                activation='relu',
                padding='same'
            ))
            model.add(keras.layers.BatchNormalization())
            
            if hp.Boolean(f'conv_{i}_dropout'):
                model.add(keras.layers.Dropout(hp.Float(f'conv_{i}_dropout_rate', 0.1, 0.5)))
            
            model.add(keras.layers.MaxPooling2D(2))
        
        # Fully connected layers
        model.add(keras.layers.Flatten())
        
        for i in range(hp.Int('num_dense_layers', 1, 3)):
            model.add(keras.layers.Dense(
                units=hp.Int(f'dense_{i}_units', 128, 1024, step=128),
                activation='relu'
            ))
            model.add(keras.layers.Dropout(hp.Float(f'dense_{i}_dropout', 0.2, 0.7)))
        
        model.add(keras.layers.Dense(num_classes, activation='softmax'))
        
        # Compile the model
        model.compile(
            optimizer=keras.optimizers.Adam(hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        
        return model
    
    # Create the tuner
    tuner = kt.RandomSearch(
        build_model,
        objective='val_accuracy',
        max_trials=20,
        directory='hyperparameter_tuning',
        project_name='cifar10_classification'
    )
    
    # Search for the best hyperparameters
    tuner.search(x_train, y_train,
                epochs=10,
                validation_split=0.2,
                verbose=1)
    
    # Retrieve the best model and hyperparameters
    best_model = tuner.get_best_models(num_models=1)[0]
    best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]
    
    return best_model, best_hyperparameters

# Note: running this requires keras-tuner
# pip install keras-tuner
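
Once keras-tuner is installed, the search can be launched and the winning configuration inspected. A minimal usage sketch; the hyperparameter names match those registered in build_model above:

python
# Usage sketch: assumes keras-tuner is installed.
best_model, best_hp = hyperparameter_tuning()
print(f"Best learning rate: {best_hp.get('learning_rate'):.2e}")
print(f"Conv blocks: {best_hp.get('num_conv_blocks')}")
print(f"Dense layers: {best_hp.get('num_dense_layers')}")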

Model Ensembling

python
def create_ensemble_model(models, x_test, y_test):
    """
    Combine several trained models into an ensemble.
    """
    predictions = []
    
    for model in models:
        pred = model.predict(x_test)
        predictions.append(pred)
    
    # Averaging ensemble: mean of the predicted probabilities
    ensemble_pred = np.mean(predictions, axis=0)
    ensemble_classes = np.argmax(ensemble_pred, axis=1)
    
    # Voting ensemble: majority vote over the predicted classes
    individual_classes = [np.argmax(pred, axis=1) for pred in predictions]
    voting_pred = np.array([np.bincount(votes).argmax() 
                           for votes in zip(*individual_classes)])
    
    # Evaluate the ensembles
    y_true = np.argmax(y_test, axis=1)
    
    ensemble_accuracy = np.mean(ensemble_classes == y_true)
    voting_accuracy = np.mean(voting_pred == y_true)
    
    print(f"Averaging ensemble accuracy: {ensemble_accuracy:.4f}")
    print(f"Voting ensemble accuracy: {voting_accuracy:.4f}")
    
    return ensemble_pred, voting_pred

# Given several trained models, they can be ensembled like this:
# ensemble_pred, voting_pred = create_ensemble_model([model1, model2, model3], x_test, y_test)

Custom Datasets

Data Loading and Preprocessing

python
def load_custom_dataset(data_dir, img_size=(224, 224), batch_size=32):
    """
    Load a custom dataset with ImageDataGenerator.
    """
    # Create the data generator
    datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        validation_split=0.2
    )
    
    # Training data
    train_generator = datagen.flow_from_directory(
        data_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training'
    )
    
    # Validation data
    validation_generator = datagen.flow_from_directory(
        data_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation'
    )
    
    return train_generator, validation_generator

def create_tf_dataset(data_dir, img_size=(224, 224), batch_size=32):
    """
    Build the dataset with tf.data.
    """
    # Create the datasets
    train_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=img_size,
        batch_size=batch_size
    )
    
    val_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=img_size,
        batch_size=batch_size
    )
    
    # Preprocessing: normalize pixel values to [0, 1]
    normalization_layer = keras.layers.Rescaling(1./255)
    
    train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
    val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
    
    # Performance optimizations: caching, shuffling, and prefetching
    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    
    return train_ds, val_ds

# Example usage
# train_ds, val_ds = create_tf_dataset('path/to/your/dataset')
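
Unlike the ImageDataGenerator version, create_tf_dataset applies no augmentation. If you prefer to run the augmentation pipeline defined earlier inside the input pipeline instead of inside the model, it can be mapped over the training split only; a minimal sketch, mirroring the commented usage above:

python
# Sketch: apply the Keras augmentation pipeline inside tf.data.
# Only the training split is augmented; validation stays untouched.
# train_ds = train_ds.map(
#     lambda x, y: (data_augmentation(x, training=True), y),
#     num_parallel_calls=tf.data.AUTOTUNE
# )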

Model Deployment

Saving and Loading Models

python
def save_model(model, model_path):
    """
    Save the model in several formats.
    """
    # Save the complete model (HDF5 format)
    model.save(f'{model_path}.h5')
    
    # Save in SavedModel format
    model.save(f'{model_path}_savedmodel')
    
    # Save only the weights
    model.save_weights(f'{model_path}_weights.h5')
    
    print(f"Model saved to: {model_path}")

def load_model(model_path):
    """
    Load a saved model.
    """
    model = keras.models.load_model(f'{model_path}.h5')
    return model

def convert_to_tflite(model, model_path):
    """
    Convert the model to TensorFlow Lite format.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    # Optimization options
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # Convert
    tflite_model = converter.convert()
    
    # Save
    with open(f'{model_path}.tflite', 'wb') as f:
        f.write(tflite_model)
    
    print(f"TFLite model saved to: {model_path}.tflite")

# Save the model
save_model(basic_model, 'cifar10_classifier')
convert_to_tflite(basic_model, 'cifar10_classifier')
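
To sanity-check the converted model, it can be run with the TFLite interpreter. A minimal sketch, assuming the conversion above succeeded and using the .tflite path written there:

python
# Run the converted model with the TFLite interpreter (sanity check).
interpreter = tf.lite.Interpreter(model_path='cifar10_classifier.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One normalized test image as a float32 batch of size 1
sample = x_test[0:1].astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()

tflite_pred = interpreter.get_tensor(output_details[0]['index'])
print(f"TFLite prediction: {class_names[np.argmax(tflite_pred[0])]}")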

Inference Functions

python
def create_prediction_function(model, class_names):
    """
    Create a single-image prediction function.
    """
    def predict_image(image_path):
        # Load and preprocess the image
        img = keras.utils.load_img(image_path, target_size=(32, 32))
        img_array = keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0) / 255.0
        
        # Predict
        predictions = model.predict(img_array)
        predicted_class = class_names[np.argmax(predictions[0])]
        confidence = float(np.max(predictions[0]))
        
        # Top-3 predictions
        top_3_indices = np.argsort(predictions[0])[-3:][::-1]
        top_3_predictions = [(class_names[i], float(predictions[0][i])) 
                           for i in top_3_indices]
        
        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'top_3_predictions': top_3_predictions
        }
    
    return predict_image

def batch_predict(model, image_paths, class_names):
    """
    Predict on a batch of image files.
    """
    results = []
    predict_fn = create_prediction_function(model, class_names)
    
    for image_path in image_paths:
        try:
            result = predict_fn(image_path)
            result['image_path'] = image_path
            results.append(result)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
    
    return results

# Create the prediction function
predict_fn = create_prediction_function(basic_model, class_names)

# Example usage
# result = predict_fn('path/to/test/image.jpg')
# print(result)

Web Application Deployment

python
def create_flask_app(model, class_names):
    """
    Create a Flask web application.
    """
    from flask import Flask, request, jsonify, render_template
    from PIL import Image
    
    app = Flask(__name__)
    
    @app.route('/')
    def index():
        return render_template('index.html')
    
    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            # Get the uploaded image
            if 'file' not in request.files:
                return jsonify({'error': 'No file uploaded'})
            
            file = request.files['file']
            if file.filename == '':
                return jsonify({'error': 'No file selected'})
            
            # Preprocess the image (force RGB so grayscale/RGBA uploads work)
            img = Image.open(file.stream).convert('RGB')
            img = img.resize((32, 32))
            img_array = np.array(img, dtype=np.float32) / 255.0
            img_array = np.expand_dims(img_array, 0)
            
            # Predict
            predictions = model.predict(img_array)
            predicted_class = class_names[np.argmax(predictions[0])]
            confidence = float(np.max(predictions[0]))
            
            # Probabilities for every class
            all_predictions = {class_names[i]: float(predictions[0][i]) 
                             for i in range(len(class_names))}
            
            return jsonify({
                'predicted_class': predicted_class,
                'confidence': confidence,
                'all_predictions': all_predictions
            })
            
        except Exception as e:
            return jsonify({'error': str(e)})
    
    return app

# Create the Flask app
# app = create_flask_app(basic_model, class_names)
# app.run(debug=True)
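
With the server running, the /predict endpoint can be exercised from Python. A minimal client sketch; it assumes the app is listening on the default http://127.0.0.1:5000, that a local test.jpg exists, and that the requests package is installed:

python
# Hypothetical client for the /predict endpoint (requires requests).
import requests

with open('test.jpg', 'rb') as f:
    response = requests.post(
        'http://127.0.0.1:5000/predict',
        files={'file': f}
    )

print(response.json())  # e.g. {'predicted_class': ..., 'confidence': ..., ...}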

Summary

This chapter used a complete image classification project to demonstrate the full deep learning workflow:

Key Points:

  1. Data preparation: loading, preprocessing, and visualizing the data
  2. Data augmentation: improving the model's ability to generalize
  3. Model design: from a basic CNN to transfer learning
  4. Training optimization: callbacks and hyperparameter tuning
  5. Model evaluation: analyzing performance from multiple angles
  6. Deployment: saving the model, running inference, and serving a web app

Best Practices:

  • Understand the characteristics of your data thoroughly
  • Design a sensible data augmentation strategy
  • Choose a model architecture that fits the problem
  • Use appropriate evaluation metrics
  • Perform error analysis and iterate on the model
  • Account for real deployment requirements

In the next chapter we will build a text classification project and explore applications in natural language processing.
