
TensorFlow Data Processing

Introduction to the tf.data API

The tf.data API is TensorFlow's core tool for building efficient data input pipelines. It provides a powerful set of utilities for loading, transforming, and batching data, and it is particularly well suited to large-scale datasets.

python
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path

print(f"TensorFlow版本: {tf.__version__}")

Creating Datasets

1. From In-Memory Data

python
# Create a dataset from a tensor
data = tf.constant([1, 2, 3, 4, 5])
dataset = tf.data.Dataset.from_tensor_slices(data)

print("Dataset created from a tensor:")
for element in dataset:
    print(element.numpy())

# Create a dataset from multiple tensors
features = tf.constant([[1, 2], [3, 4], [5, 6]])
labels = tf.constant([0, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

print("\nFeature-label dataset:")
for feature, label in dataset:
    print(f"Feature: {feature.numpy()}, Label: {label.numpy()}")

# Create a dataset from a dictionary
dataset_dict = tf.data.Dataset.from_tensor_slices({
    'features': [[1, 2], [3, 4], [5, 6]],
    'labels': [0, 1, 0]
})

print("\nDictionary-format dataset:")
for element in dataset_dict:
    print(f"Feature: {element['features'].numpy()}, Label: {element['labels'].numpy()}")

2. From a Generator

python
def data_generator():
    """Simple data generator function"""
    for i in range(10):
        yield i, i**2

# Create a dataset from the generator
dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)

print("Dataset created from a generator:")
for x, y in dataset.take(5):
    print(f"x: {x.numpy()}, y: {y.numpy()}")

# A more complex generator
def complex_generator():
    """More involved data generator"""
    for i in range(100):
        # Simulate more complex data-generation logic
        features = np.random.randn(10).astype(np.float32)
        label = np.random.randint(0, 3)
        yield features, label

complex_dataset = tf.data.Dataset.from_generator(
    complex_generator,
    output_signature=(
        tf.TensorSpec(shape=(10,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)

3. From Files

python
# Create a dataset from a text file
def create_text_dataset():
    # Write a small sample text file
    text_data = ["line 1", "line 2", "line 3", "line 4", "line 5"]
    with open('sample.txt', 'w') as f:
        for line in text_data:
            f.write(line + '\n')

    # Build a dataset from the text file
    text_dataset = tf.data.TextLineDataset('sample.txt')

    print("Text-file dataset:")
    for line in text_dataset:
        print(line.numpy().decode('utf-8'))

    return text_dataset

# create_text_dataset()

# Create a dataset from a CSV file
def create_csv_dataset():
    # Generate sample CSV data
    csv_data = pd.DataFrame({
        'feature1': np.random.randn(100),
        'feature2': np.random.randn(100),
        'label': np.random.randint(0, 3, 100)
    })
    csv_data.to_csv('sample.csv', index=False)

    # Build a dataset from the CSV file
    csv_dataset = tf.data.experimental.make_csv_dataset(
        'sample.csv',
        batch_size=5,
        label_name='label',
        num_epochs=1,
        shuffle=False
    )

    print("CSV dataset:")
    for batch in csv_dataset.take(2):
        print("Features:", {k: v.numpy() for k, v in batch[0].items()})
        print("Labels:", batch[1].numpy())
        print()

    return csv_dataset

# create_csv_dataset()

4. From TFRecord Files

python
def create_tfrecord_dataset():
    """Create and read a TFRecord dataset"""

    # Build a serialized tf.train.Example
    def create_example(feature, label):
        """Create a tf.train.Example"""
        feature_dict = {
            'feature': tf.train.Feature(
                float_list=tf.train.FloatList(value=feature)
            ),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label])
            )
        }
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_dict)
        )
        return example.SerializeToString()

    # Write the TFRecord file
    with tf.io.TFRecordWriter('sample.tfrecord') as writer:
        for i in range(100):
            feature = np.random.randn(10).astype(np.float32)
            label = np.random.randint(0, 3)
            example = create_example(feature, label)
            writer.write(example)

    # Parsing function
    def parse_example(example_proto):
        feature_description = {
            'feature': tf.io.FixedLenFeature([10], tf.float32),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    # Build a dataset from the TFRecord file
    tfrecord_dataset = tf.data.TFRecordDataset('sample.tfrecord')
    parsed_dataset = tfrecord_dataset.map(parse_example)

    print("TFRecord dataset:")
    for element in parsed_dataset.take(3):
        print(f"Feature shape: {element['feature'].shape}, Label: {element['label'].numpy()}")

    return parsed_dataset

# create_tfrecord_dataset()

Dataset Transformations

1. Basic Transformations

python
# Create a base dataset
dataset = tf.data.Dataset.range(10)

# map: apply a function to every element
squared_dataset = dataset.map(lambda x: x ** 2)
print("Squared elements:")
for element in squared_dataset:
    print(element.numpy())

# filter: keep only elements matching a predicate
even_dataset = dataset.filter(lambda x: x % 2 == 0)
print("\nEven elements only:")
for element in even_dataset:
    print(element.numpy())

# take: keep the first n elements
first_five = dataset.take(5)
print("\nFirst 5 elements:")
for element in first_five:
    print(element.numpy())

# skip: drop the first n elements
skip_five = dataset.skip(5)
print("\nAfter skipping the first 5 elements:")
for element in skip_five:
    print(element.numpy())

2. Batching and Repeating

python
# Create an example dataset
features = tf.random.normal([100, 10])
labels = tf.random.uniform([100], maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Batching
batched_dataset = dataset.batch(32)
print("Shapes after batching:")
for batch_features, batch_labels in batched_dataset.take(1):
    print(f"Feature batch shape: {batch_features.shape}")
    print(f"Label batch shape: {batch_labels.shape}")

# Repeat the dataset
repeated_dataset = dataset.repeat(3)  # repeat 3 times
print(f"\nDataset size after repeating: {len(list(repeated_dataset))}")

# Repeat indefinitely
infinite_dataset = dataset.repeat()  # repeats forever

# Shuffle the data
shuffled_dataset = dataset.shuffle(buffer_size=100)
print("\nShuffled dataset:")
for features, labels in shuffled_dataset.take(3):
    print(f"Label: {labels.numpy()}")

3. Advanced Transformations

python
# Create a sequence dataset
sequence_dataset = tf.data.Dataset.range(20)

# window: create sliding windows
windowed_dataset = sequence_dataset.window(size=5, shift=2, drop_remainder=True)
windowed_dataset = windowed_dataset.flat_map(lambda window: window.batch(5))

print("Sliding windows:")
for window in windowed_dataset.take(3):
    print(window.numpy())

# flat_map: flatten a nested structure
nested_dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
flattened_dataset = nested_dataset.flat_map(tf.data.Dataset.from_tensor_slices)

print("\nFlattened nested structure:")
for element in flattened_dataset:
    print(element.numpy())

# zip: combine multiple datasets
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5, 10)
zipped_dataset = tf.data.Dataset.zip((dataset1, dataset2))

print("\nZipped datasets:")
for x, y in zipped_dataset:
    print(f"({x.numpy()}, {y.numpy()})")

4. Data Preprocessing

python
def preprocess_image_data():
    """Image preprocessing example"""

    # Simulated dataset of image paths
    image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    labels = [0, 1, 2]

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def load_and_preprocess_image(path, label):
        # In a real application you would load the actual image here:
        # image = tf.io.read_file(path)
        # image = tf.image.decode_image(image, channels=3)

        # Simulated image data
        image = tf.random.normal([224, 224, 3])

        # Image preprocessing
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, [224, 224])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.per_image_standardization(image)

        return image, label

    # Apply the preprocessing
    processed_dataset = dataset.map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return processed_dataset

def preprocess_text_data():
    """Text preprocessing example"""

    texts = ["hello world", "tensorflow is great", "deep learning rocks"]
    labels = [0, 1, 1]

    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Build the vocabulary lookup table
    vocab = ["hello", "world", "tensorflow", "is", "great", "deep", "learning", "rocks"]
    vocab_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=vocab,
            values=tf.range(len(vocab), dtype=tf.int64)
        ),
        default_value=-1
    )

    def preprocess_text(text, label):
        # Tokenize (simplified: whitespace split)
        words = tf.strings.split(text)

        # Map words to vocabulary ids
        word_ids = vocab_table.lookup(words)

        # Truncate first, then pad to a fixed length of 10
        # (padding before truncating would fail on sequences longer than 10)
        word_ids = word_ids[:10]
        word_ids = tf.pad(word_ids, [[0, 10 - tf.shape(word_ids)[0]]])

        return word_ids, label

    processed_dataset = dataset.map(preprocess_text)

    return processed_dataset

# Try out the preprocessing pipelines
# image_dataset = preprocess_image_data()
# text_dataset = preprocess_text_data()

Performance Optimization

1. Parallel Processing

python
# Create a large dataset for a performance comparison
import time

large_dataset = tf.data.Dataset.range(10000)

def slow_function(x):
    """Simulate an expensive per-element operation"""
    # Sleep inside tf.py_function to mimic slow preprocessing work
    tf.py_function(lambda: time.sleep(0.01), [], ())
    return x ** 2

# Sequential processing
serial_dataset = large_dataset.map(slow_function)

# Parallel processing
parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=tf.data.AUTOTUNE  # let tf.data pick the parallelism
)

# Manually chosen parallelism
manual_parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=4  # use 4 parallel calls
)

print("Parallel processing configured")

2. Prefetching and Caching

python
def optimize_dataset_performance(dataset):
    """Dataset performance optimizations"""

    # Cache the dataset (best for data that fits in memory)
    cached_dataset = dataset.cache()

    # Prefetch so the next elements are prepared while the current ones are consumed
    prefetched_dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Combine the optimizations
    optimized_dataset = (dataset
                        .cache()                    # cache
                        .shuffle(1000)              # shuffle
                        .batch(32)                  # batch
                        .prefetch(tf.data.AUTOTUNE) # prefetch
                        )

    return optimized_dataset

# Performance comparison
def benchmark_dataset(dataset, num_epochs=3):
    """Simple dataset benchmark"""
    import time

    start_time = time.time()

    for epoch in range(num_epochs):
        for batch in dataset:
            # Simulate a training step
            pass

    end_time = time.time()
    return end_time - start_time

# Create a test dataset
test_dataset = tf.data.Dataset.range(1000).map(lambda x: x**2)

# Baseline dataset
basic_time = benchmark_dataset(test_dataset.batch(32))

# Optimized dataset
optimized_dataset = optimize_dataset_performance(test_dataset)
optimized_time = benchmark_dataset(optimized_dataset)

print(f"Baseline dataset time: {basic_time:.2f}s")
print(f"Optimized dataset time: {optimized_time:.2f}s")
print(f"Speedup: {basic_time/optimized_time:.2f}x")

3. Memory Optimization

python
def memory_efficient_dataset():
    """Memory-efficient dataset processing"""

    # Use a generator so the data never has to fit in memory all at once
    def data_generator():
        for i in range(100000):  # large dataset
            yield np.random.randn(100).astype(np.float32), i % 10

    dataset = tf.data.Dataset.from_generator(
        data_generator,
        output_signature=(
            tf.TensorSpec(shape=(100,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32)
        )
    )

    # Stream the data to avoid running out of memory
    processed_dataset = (dataset
                        .map(lambda x, y: (tf.nn.l2_normalize(x), y))
                        .batch(32)
                        .prefetch(2)  # prefetch only 2 batches
                        )

    return processed_dataset

# memory_efficient_dataset()

Data Augmentation

1. Image Augmentation

python
def image_augmentation_pipeline():
    """Image augmentation pipeline"""

    def augment_image(image, label):
        # Random flips
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

        # Random rotation in 90-degree steps
        image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))

        # Color jitter (these ops assume values in a typical image range)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.1)

        # Random crop, then resize back to the target size
        image = tf.image.random_crop(image, [200, 200, 3])
        image = tf.image.resize(image, [224, 224])

        # Normalization
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.per_image_standardization(image)

        return image, label

    # Create a simulated image dataset
    images = tf.random.normal([100, 224, 224, 3])
    labels = tf.random.uniform([100], maxval=10, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Apply the augmentation
    augmented_dataset = dataset.map(
        augment_image,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return augmented_dataset

# augmented_dataset = image_augmentation_pipeline()

2. Text Augmentation

python
def text_augmentation_pipeline():
    """Text augmentation pipeline"""

    def augment_text(text, label):
        # Tokenize (simplified: whitespace split)
        words = tf.strings.split(text)

        # Randomly drop words
        num_words = tf.shape(words)[0]
        keep_prob = 0.9
        mask = tf.random.uniform([num_words]) < keep_prob
        filtered_words = tf.boolean_mask(words, mask)

        # Re-join the remaining words into a single string
        augmented_text = tf.strings.reduce_join(filtered_words, separator=' ')

        return augmented_text, label

    # Example text data
    texts = tf.constant([
        "this is a great movie",
        "terrible film not recommended",
        "amazing story and acting"
    ])
    labels = tf.constant([1, 0, 1])

    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
    augmented_dataset = dataset.map(augment_text)

    return augmented_dataset

# text_augmented_dataset = text_augmentation_pipeline()

Data Validation and Debugging

1. Inspecting Datasets

python
def inspect_dataset(dataset, name="Dataset"):
    """Inspect the contents and structure of a dataset"""
    print(f"\n=== Inspecting {name} ===")

    # Dataset structure
    print(f"Element spec: {dataset.element_spec}")

    # Look at the first few samples
    print("First 3 samples:")
    for i, element in enumerate(dataset.take(3)):
        if isinstance(element, tuple):
            print(f"Sample {i}: feature shape={element[0].shape}, label={element[1].numpy()}")
        else:
            print(f"Sample {i}: {element.numpy()}")

    # Size information
    try:
        cardinality = tf.data.experimental.cardinality(dataset)
        if cardinality != tf.data.experimental.UNKNOWN_CARDINALITY:
            print(f"Dataset size: {cardinality.numpy()}")
        else:
            print("Dataset size: unknown")
    except Exception:
        print("Could not determine the dataset size")

# Create a test dataset
test_features = tf.random.normal([50, 10])
test_labels = tf.random.uniform([50], maxval=3, dtype=tf.int32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels))

inspect_dataset(test_dataset, "test dataset")
inspect_dataset(test_dataset.batch(8), "batched dataset")

2. Data Quality Checks

python
def validate_data_quality(dataset):
    """Validate basic data quality as the data flows through the pipeline"""

    def check_data_quality(features, labels):
        # Check for NaN values
        has_nan = tf.reduce_any(tf.math.is_nan(features))

        # Check for infinite values
        has_inf = tf.reduce_any(tf.math.is_inf(features))

        # Check the label range
        valid_labels = tf.logical_and(labels >= 0, labels < 10)

        # Report problems. Dataset.map applies AutoGraph, so these
        # tensor-valued `if` statements become graph-compatible conditionals.
        if has_nan:
            tf.print("Warning: NaN values found")

        if has_inf:
            tf.print("Warning: infinite values found")

        if tf.reduce_any(tf.logical_not(valid_labels)):
            tf.print("Warning: invalid labels found")

        return features, labels

    validated_dataset = dataset.map(check_data_quality)
    return validated_dataset

# Build a dataset that deliberately contains problems
problematic_features = tf.constant([[1.0, float('nan')], [float('inf'), 3.0]])
problematic_labels = tf.constant([1, 15])  # label out of range
problematic_dataset = tf.data.Dataset.from_tensor_slices((problematic_features, problematic_labels))

# validated_dataset = validate_data_quality(problematic_dataset)
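
If you would rather have the pipeline fail fast than only print warnings, tf.debugging.check_numerics raises an error as soon as a NaN or Inf value flows through; a small sketch using the problematic dataset above:

python
# Fail fast on NaN/Inf instead of printing a warning
def assert_finite(features, labels):
    features = tf.debugging.check_numerics(features, "features contain NaN or Inf")
    return features, labels

# strict_dataset = problematic_dataset.map(assert_finite)
# Iterating strict_dataset raises InvalidArgumentError at the first bad element.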

Practical Examples

1. A Complete Image-Classification Data Pipeline

python
def create_image_classification_pipeline(image_dir, batch_size=32, image_size=(224, 224)):
    """Build a complete image-classification data pipeline"""

    # Collect image paths and labels
    def get_image_paths_and_labels(image_dir):
        # A real implementation would scan image_dir here and
        # return the image paths together with their labels
        paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']  # placeholder
        labels = [0, 1, 2]  # placeholder
        return paths, labels

    def load_and_preprocess_image(path, label):
        # Load the image; expand_animations=False guarantees a 3-D tensor
        # with a known rank, which tf.image.resize requires
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.cast(image, tf.float32)

        # Preprocess
        image = tf.image.resize(image, image_size)
        image = tf.image.per_image_standardization(image)

        return image, label

    def augment_for_training(image, label):
        # Training-time augmentation
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        return image, label

    # Gather the data
    paths, labels = get_image_paths_and_labels(image_dir)

    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

    # Build the pipeline
    dataset = (dataset
              .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
              .cache()  # cache the preprocessed images, before augmentation
              .map(augment_for_training, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(1000)
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE)
              )

    return dataset

# Usage
# train_dataset = create_image_classification_pipeline('train_images/')
# val_dataset = create_image_classification_pipeline('val_images/')

2. A Text-Classification Data Pipeline

python
def create_text_classification_pipeline(texts, labels, vocab_size=10000, max_length=100):
    """Build a text-classification data pipeline"""

    # Build the vocabulary
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token='<OOV>'
    )
    tokenizer.fit_on_texts(texts)

    def preprocess_text(text, label):
        # Tokenize and encode (runs eagerly inside tf.py_function)
        sequences = tokenizer.texts_to_sequences([text.numpy().decode('utf-8')])
        sequence = tf.constant(sequences[0], dtype=tf.int32)

        # Truncate first, then pad to the fixed length
        sequence = sequence[:max_length]
        sequence = tf.pad(sequence, [[0, max_length - tf.shape(sequence)[0]]])

        return sequence, label

    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Preprocess
    dataset = dataset.map(
        lambda text, label: tf.py_function(
            preprocess_text, [text, label], [tf.int32, tf.int32]
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return dataset, tokenizer

# Example usage
sample_texts = ["a really good movie", "a terrible experience", "highly recommended"]
sample_labels = [1, 0, 1]

# text_dataset, text_tokenizer = create_text_classification_pipeline(sample_texts, sample_labels)

Summary

The tf.data API is TensorFlow's workhorse for data handling. Its main strengths include:

  1. Diverse data sources: in-memory data, files, generators, and more
  2. Rich transformations: map, filter, batch, shuffle, and others
  3. Performance optimization: parallel processing, prefetching, caching, and similar techniques
  4. Data augmentation: image and text augmentation built from tf.image and tf.strings ops
  5. Easy debugging: tools for inspecting and validating datasets

Mastering the tf.data API will greatly improve both your data-processing efficiency and your model-training performance!
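
To close the loop, a tf.data.Dataset that yields (features, labels) batches can be passed straight to Keras training; a minimal sketch with a made-up model:

python
# Minimal sketch: feeding a tf.data pipeline into Keras training (hypothetical model)
features = tf.random.normal([100, 10])
labels = tf.random.uniform([100], maxval=3, dtype=tf.int32)

train_dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
                 .shuffle(100)
                 .batch(32)
                 .prefetch(tf.data.AUTOTUNE))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# model.fit(train_dataset, epochs=3)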
