TensorFlow Data Processing
Introduction to the tf.data API
The tf.data API is TensorFlow's core tool for building efficient data input pipelines. It provides a powerful set of utilities for loading, transforming, and batching data, and is particularly well suited to large datasets.
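As a quick orientation, here is a minimal sketch of a complete pipeline: take a small in-memory array, transform it, batch it, and iterate (the array values are purely illustrative).
python
import tensorflow as tf

# Build a tiny pipeline: source -> transform -> batch -> iterate.
pipeline = (tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
            .map(lambda x: x * 2)   # element-wise transformation
            .batch(3))              # group elements into batches

for batch in pipeline:
    print(batch.numpy())            # [2 4 6], then [ 8 10 12]
The rest of this guide uses the following imports.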
python
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path
print(f"TensorFlow版本: {tf.__version__}")创建数据集
1. Creating from in-memory data
python
# Create a dataset from a tensor
data = tf.constant([1, 2, 3, 4, 5])
dataset = tf.data.Dataset.from_tensor_slices(data)
print("Dataset created from a tensor:")
for element in dataset:
    print(element.numpy())

# Create from multiple tensors
features = tf.constant([[1, 2], [3, 4], [5, 6]])
labels = tf.constant([0, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
print("\nFeature-label dataset:")
for feature, label in dataset:
    print(f"Feature: {feature.numpy()}, label: {label.numpy()}")

# Create from a dictionary
dataset_dict = tf.data.Dataset.from_tensor_slices({
    'features': [[1, 2], [3, 4], [5, 6]],
    'labels': [0, 1, 0]
})
print("\nDictionary-format dataset:")
for element in dataset_dict:
    print(f"Feature: {element['features'].numpy()}, label: {element['labels'].numpy()}")
2. Creating from a generator
python
def data_generator():
    """A simple data generator."""
    for i in range(10):
        yield i, i**2

# Create a dataset from the generator
dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)
print("Dataset created from a generator:")
for x, y in dataset.take(5):
    print(f"x: {x.numpy()}, y: {y.numpy()}")

# A more complex generator
def complex_generator():
    """A generator with more involved per-example logic."""
    for i in range(100):
        # Simulate a more complex data-generation step
        features = np.random.randn(10).astype(np.float32)
        label = np.random.randint(0, 3)
        yield features, label

complex_dataset = tf.data.Dataset.from_generator(
    complex_generator,
    output_signature=(
        tf.TensorSpec(shape=(10,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)
3. Creating from files
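When the data already lives on disk, a common first step is to build a dataset of file paths. A minimal sketch using tf.data.Dataset.list_files; the glob pattern here is an illustrative assumption, not a path created by this tutorial.
python
import tensorflow as tf

# List files matching a pattern; shuffle=False keeps the order deterministic.
# 'data/*.txt' is a placeholder pattern.
file_paths = tf.data.Dataset.list_files('data/*.txt', shuffle=False)

# Each element is a scalar string tensor holding one file path.
for path in file_paths.take(3):
    print(path.numpy().decode('utf-8'))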
python
# Create from a text file
def create_text_dataset():
    # Create a sample text file
    text_data = ["line 1", "line 2", "line 3", "line 4", "line 5"]
    with open('sample.txt', 'w') as f:
        for line in text_data:
            f.write(line + '\n')

    # Create a dataset from the text file
    text_dataset = tf.data.TextLineDataset('sample.txt')
    print("Text-file dataset:")
    for line in text_dataset:
        print(line.numpy().decode('utf-8'))
    return text_dataset

# create_text_dataset()

# Create from a CSV file
def create_csv_dataset():
    # Create sample CSV data
    csv_data = pd.DataFrame({
        'feature1': np.random.randn(100),
        'feature2': np.random.randn(100),
        'label': np.random.randint(0, 3, 100)
    })
    csv_data.to_csv('sample.csv', index=False)

    # Create a dataset from the CSV file
    csv_dataset = tf.data.experimental.make_csv_dataset(
        'sample.csv',
        batch_size=5,
        label_name='label',
        num_epochs=1,
        shuffle=False
    )
    print("CSV dataset:")
    for batch in csv_dataset.take(2):
        print("Features:", {k: v.numpy() for k, v in batch[0].items()})
        print("Labels:", batch[1].numpy())
        print()
    return csv_dataset

# create_csv_dataset()
4. Creating from TFRecord files
python
def create_tfrecord_dataset():
    """Write and read back a TFRecord dataset."""
    def create_example(feature, label):
        """Build a serialized tf.train.Example."""
        feature_dict = {
            'feature': tf.train.Feature(
                float_list=tf.train.FloatList(value=feature)
            ),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label])
            )
        }
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_dict)
        )
        return example.SerializeToString()

    # Write the TFRecord file
    with tf.io.TFRecordWriter('sample.tfrecord') as writer:
        for i in range(100):
            feature = np.random.randn(10).astype(np.float32)
            label = np.random.randint(0, 3)
            example = create_example(feature, label)
            writer.write(example)

    # Parsing function
    def parse_example(example_proto):
        feature_description = {
            'feature': tf.io.FixedLenFeature([10], tf.float32),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    # Create a dataset from the TFRecord file
    tfrecord_dataset = tf.data.TFRecordDataset('sample.tfrecord')
    parsed_dataset = tfrecord_dataset.map(parse_example)
    print("TFRecord dataset:")
    for element in parsed_dataset.take(3):
        print(f"Feature shape: {element['feature'].shape}, label: {element['label'].numpy()}")
    return parsed_dataset

# create_tfrecord_dataset()
Dataset Transformations
1. Basic transformations
python
# Create a base dataset
dataset = tf.data.Dataset.range(10)

# map: apply a function to every element
squared_dataset = dataset.map(lambda x: x ** 2)
print("Squared elements:")
for element in squared_dataset:
    print(element.numpy())

# filter: keep only elements that satisfy a predicate
even_dataset = dataset.filter(lambda x: x % 2 == 0)
print("\nEven elements only:")
for element in even_dataset:
    print(element.numpy())

# take: keep the first n elements
first_five = dataset.take(5)
print("\nFirst 5 elements:")
for element in first_five:
    print(element.numpy())

# skip: skip the first n elements
skip_five = dataset.skip(5)
print("\nAfter skipping the first 5 elements:")
for element in skip_five:
    print(element.numpy())
2. Batching and repeating
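One detail worth keeping in mind before the examples below: the order of shuffle and batch matters. Shuffling before batching shuffles individual examples, while shuffling after batching only reorders whole batches. A minimal sketch:
python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)

# Shuffle individual elements, then batch (the usual choice for training).
element_shuffled = dataset.shuffle(buffer_size=8).batch(4)

# Batch first, then shuffle: only whole batches are reordered.
batch_shuffled = dataset.batch(4).shuffle(buffer_size=2)

for batch in element_shuffled.take(1):
    print("Element-level shuffle:", batch.numpy())
for batch in batch_shuffled.take(1):
    print("Batch-level shuffle:", batch.numpy())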
python
# Create a sample dataset
features = tf.random.normal([100, 10])
labels = tf.random.uniform([100], maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Batching
batched_dataset = dataset.batch(32)
print("Shapes after batching:")
for batch_features, batch_labels in batched_dataset.take(1):
    print(f"Feature batch shape: {batch_features.shape}")
    print(f"Label batch shape: {batch_labels.shape}")

# Repeat the dataset
repeated_dataset = dataset.repeat(3)  # repeat 3 times
print(f"\nDataset size after repeating: {len(list(repeated_dataset))}")

# Repeat indefinitely
infinite_dataset = dataset.repeat()  # repeats forever

# Shuffle the data
shuffled_dataset = dataset.shuffle(buffer_size=100)
print("\nShuffled dataset:")
for features, labels in shuffled_dataset.take(3):
    print(f"Label: {labels.numpy()}")
3. Complex transformations
python
# Create a sequence dataset
sequence_dataset = tf.data.Dataset.range(20)

# window: create sliding windows
windowed_dataset = sequence_dataset.window(size=5, shift=2, drop_remainder=True)
windowed_dataset = windowed_dataset.flat_map(lambda window: window.batch(5))
print("Sliding windows:")
for window in windowed_dataset.take(3):
    print(window.numpy())

# flat_map: flatten a nested structure
nested_dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
flattened_dataset = nested_dataset.flat_map(tf.data.Dataset.from_tensor_slices)
print("\nFlattened nested structure:")
for element in flattened_dataset:
    print(element.numpy())

# zip: combine multiple datasets
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5, 10)
zipped_dataset = tf.data.Dataset.zip((dataset1, dataset2))
print("\nZipped dataset:")
for x, y in zipped_dataset:
    print(f"({x.numpy()}, {y.numpy()})")
4. Data preprocessing
python
def preprocess_image_data():
    """Image data preprocessing example."""
    # Simulate a dataset of image paths
    image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    labels = [0, 1, 2]
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def load_and_preprocess_image(path, label):
        # In a real application you would load the actual image:
        # image = tf.io.read_file(path)
        # image = tf.image.decode_image(image, channels=3)
        # Here we fake the image data instead
        image = tf.random.normal([224, 224, 3])

        # Image preprocessing
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, [224, 224])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.per_image_standardization(image)
        return image, label

    # Apply the preprocessing
    processed_dataset = dataset.map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    return processed_dataset

def preprocess_text_data():
    """Text data preprocessing example."""
    texts = ["hello world", "tensorflow is great", "deep learning rocks"]
    labels = [0, 1, 1]
    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Build a vocabulary lookup table
    vocab = ["hello", "world", "tensorflow", "is", "great", "deep", "learning", "rocks"]
    vocab_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=vocab,
            values=tf.range(len(vocab), dtype=tf.int64)
        ),
        default_value=-1
    )

    def preprocess_text(text, label):
        # Tokenize (simplified: split on whitespace)
        words = tf.strings.split(text)
        # Map words to vocabulary ids
        word_ids = vocab_table.lookup(words)
        # Pad or truncate to a fixed length of 10
        pad_len = tf.maximum(0, 10 - tf.shape(word_ids)[0])
        word_ids = tf.pad(word_ids, [[0, pad_len]])[:10]
        return word_ids, label

    processed_dataset = dataset.map(preprocess_text)
    return processed_dataset

# Try the preprocessing pipelines
# image_dataset = preprocess_image_data()
# text_dataset = preprocess_text_data()
Performance Optimization
1. Parallel processing
python
# Create a larger dataset for performance experiments
large_dataset = tf.data.Dataset.range(10000)

def slow_function(x):
    """Simulate an expensive per-element transformation."""
    # In a real pipeline this might be image decoding or heavy feature engineering.
    return x ** 2

# Sequential processing
serial_dataset = large_dataset.map(slow_function)

# Parallel processing
parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=tf.data.AUTOTUNE  # let tf.data tune the parallelism
)

# Manually chosen parallelism
manual_parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=4  # use 4 parallel calls
)
print("Parallel processing configured")
2. Prefetching and caching
python
def optimize_dataset_performance(dataset):
    """Apply standard dataset performance optimizations."""
    # Cache the dataset (best for data that fits in memory)
    cached_dataset = dataset.cache()

    # Prefetch so the next batch is prepared while the current one is consumed
    prefetched_dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Combine the optimizations
    optimized_dataset = (dataset
                         .cache()                      # cache
                         .shuffle(1000)                # shuffle
                         .batch(32)                    # batch
                         .prefetch(tf.data.AUTOTUNE))  # prefetch
    return optimized_dataset

# Simple performance comparison
def benchmark_dataset(dataset, num_epochs=3):
    """Time how long it takes to iterate over a dataset."""
    import time
    start_time = time.time()
    for epoch in range(num_epochs):
        for batch in dataset:
            # Simulate a training step
            pass
    end_time = time.time()
    return end_time - start_time

# Create a test dataset
test_dataset = tf.data.Dataset.range(1000).map(lambda x: x**2)

# Baseline dataset
basic_time = benchmark_dataset(test_dataset.batch(32))

# Optimized dataset
optimized_dataset = optimize_dataset_performance(test_dataset)
optimized_time = benchmark_dataset(optimized_dataset)

print(f"Baseline dataset time: {basic_time:.2f}s")
print(f"Optimized dataset time: {optimized_time:.2f}s")
print(f"Speedup: {basic_time/optimized_time:.2f}x")
3. Memory optimization
python
def memory_efficient_dataset():
    """Memory-efficient dataset processing."""
    # Use a generator to avoid loading everything into memory at once
    def data_generator():
        for i in range(100000):  # a large dataset
            yield np.random.randn(100).astype(np.float32), i % 10

    dataset = tf.data.Dataset.from_generator(
        data_generator,
        output_signature=(
            tf.TensorSpec(shape=(100,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32)
        )
    )

    # Stream the data to keep memory usage bounded
    processed_dataset = (dataset
                         .map(lambda x, y: (tf.nn.l2_normalize(x), y))
                         .batch(32)
                         .prefetch(2))  # prefetch only 2 batches
    return processed_dataset

# memory_efficient_dataset()
Data Augmentation
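One practical note before the examples: random augmentation should normally be applied to the training split only, so that validation metrics stay comparable across epochs. A minimal sketch of that split; augment_image here is a stand-in for any augmentation function like the ones below.
python
import tensorflow as tf

def augment_image(image, label):
    # Placeholder augmentation; any of the transforms below would fit here.
    image = tf.image.random_flip_left_right(image)
    return image, label

images = tf.random.normal([10, 32, 32, 3])
labels = tf.random.uniform([10], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Augment only the training split; leave validation untouched.
train_ds = dataset.take(8).map(augment_image).batch(4)
val_ds = dataset.skip(8).batch(4)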
1. Image data augmentation
python
def image_augmentation_pipeline():
    """Image data augmentation pipeline."""
    def augment_image(image, label):
        # Random flips
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

        # Random rotation by a multiple of 90 degrees
        image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))

        # Color adjustments
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.1)

        # Random crop followed by resize
        image = tf.image.random_crop(image, [200, 200, 3])
        image = tf.image.resize(image, [224, 224])

        # Normalization
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.per_image_standardization(image)
        return image, label

    # Create a mock image dataset
    images = tf.random.normal([100, 224, 224, 3])
    labels = tf.random.uniform([100], maxval=10, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Apply the augmentation
    augmented_dataset = dataset.map(
        augment_image,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    return augmented_dataset

# augmented_dataset = image_augmentation_pipeline()
2. Text data augmentation
python
def text_augmentation_pipeline():
    """Text data augmentation pipeline."""
    def augment_text(text, label):
        # Split the text into words (simplified tokenization)
        words = tf.strings.split(text)

        # Randomly drop words
        num_words = tf.shape(words)[0]
        keep_prob = 0.9
        mask = tf.random.uniform([num_words]) < keep_prob
        filtered_words = tf.boolean_mask(words, mask)

        # Join the remaining words back into a single string
        augmented_text = tf.strings.reduce_join(filtered_words, separator=' ')
        return augmented_text, label

    # Sample text data
    texts = tf.constant([
        "this is a great movie",
        "terrible film not recommended",
        "amazing story and acting"
    ])
    labels = tf.constant([1, 0, 1])
    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
    augmented_dataset = dataset.map(augment_text)
    return augmented_dataset

# text_augmented_dataset = text_augmentation_pipeline()
Data Validation and Debugging
1. Inspecting a dataset
python
def inspect_dataset(dataset, name="Dataset"):
    """Inspect the contents and structure of a dataset."""
    print(f"\n=== {name} inspection ===")

    # Inspect the dataset structure
    print(f"Element spec: {dataset.element_spec}")

    # Look at the first few samples
    print("First 3 samples:")
    for i, element in enumerate(dataset.take(3)):
        if isinstance(element, tuple):
            print(f"Sample {i}: feature shape={element[0].shape}, label={element[1].numpy()}")
        else:
            print(f"Sample {i}: {element.numpy()}")

    # Size information
    try:
        cardinality = tf.data.experimental.cardinality(dataset)
        if cardinality != tf.data.experimental.UNKNOWN_CARDINALITY:
            print(f"Dataset size: {cardinality.numpy()}")
        else:
            print("Dataset size: unknown")
    except Exception:
        print("Could not determine the dataset size")

# Create a test dataset
test_features = tf.random.normal([50, 10])
test_labels = tf.random.uniform([50], maxval=3, dtype=tf.int32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels))

inspect_dataset(test_dataset, "Test dataset")
inspect_dataset(test_dataset.batch(8), "Batched dataset")
2. Data quality checks
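The validation function below only prints warnings and lets the data through. If you would rather fail fast, tf.debugging provides hard assertions; a minimal hedged sketch:
python
import tensorflow as tf

def assert_clean(features, labels):
    # Raises InvalidArgumentError as soon as a NaN or Inf value appears.
    features = tf.debugging.assert_all_finite(features, "features contain NaN or Inf")
    return features, labels

features = tf.random.normal([50, 10])
labels = tf.random.uniform([50], maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(assert_clean)

for _ in dataset.take(1):
    pass  # iterating triggers the check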
python
def validate_data_quality(dataset):
    """Validate basic data quality properties."""
    def warn_if(condition, message):
        # Print the warning only when the condition holds; both branches
        # return the same dummy value so tf.cond sees matching structures.
        def _warn():
            tf.print(message)
            return tf.constant(0)
        return tf.cond(condition, _warn, lambda: tf.constant(0))

    def check_data_quality(features, labels):
        # Check for NaN values
        has_nan = tf.reduce_any(tf.math.is_nan(features))
        # Check for infinite values
        has_inf = tf.reduce_any(tf.math.is_inf(features))
        # Check the label range
        valid_labels = tf.logical_and(labels >= 0, labels < 10)

        # Report any problems
        warn_if(has_nan, "Warning: NaN values found")
        warn_if(has_inf, "Warning: infinite values found")
        warn_if(tf.reduce_any(tf.logical_not(valid_labels)), "Warning: invalid labels found")
        return features, labels

    validated_dataset = dataset.map(check_data_quality)
    return validated_dataset

# Build a dataset with known problems to exercise the checks
problematic_features = tf.constant([[1.0, float('nan')], [float('inf'), 3.0]])
problematic_labels = tf.constant([1, 15])  # the second label is out of range
problematic_dataset = tf.data.Dataset.from_tensor_slices((problematic_features, problematic_labels))

# validated_dataset = validate_data_quality(problematic_dataset)
Practical Application Examples
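Both pipelines below yield (features, labels) examples that can be passed straight to Keras. As a hedged sketch of how that hookup typically looks; the model architecture and shapes here are illustrative assumptions, not part of the pipelines that follow.
python
import tensorflow as tf

# A toy dataset of 10-dimensional features with 3 classes.
features = tf.random.normal([256, 10])
labels = tf.random.uniform([256], maxval=3, dtype=tf.int32)
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(256)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

# A minimal classifier; Model.fit consumes the tf.data pipeline directly.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(3, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(dataset, epochs=2)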
1. A complete image classification data pipeline
python
def create_image_classification_pipeline(image_dir, batch_size=32, image_size=(224, 224)):
    """Create a complete image classification data pipeline."""
    # Collect image paths and labels
    def get_image_paths_and_labels(image_dir):
        # A real implementation would scan image_dir here and
        # return the image paths with their corresponding labels
        paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']  # placeholder
        labels = [0, 1, 2]  # placeholder
        return paths, labels

    def load_and_preprocess_image(path, label):
        # Load the image
        image = tf.io.read_file(path)
        # expand_animations=False guarantees a 3-D tensor (no GIF frame axis)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.cast(image, tf.float32)

        # Preprocess
        image = tf.image.resize(image, image_size)
        image = tf.image.per_image_standardization(image)
        return image, label

    def augment_for_training(image, label):
        # Training-time data augmentation
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        return image, label

    # Gather the data
    paths, labels = get_image_paths_and_labels(image_dir)

    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

    # Build the pipeline
    dataset = (dataset
               .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
               .cache()  # cache the preprocessed images
               .map(augment_for_training, num_parallel_calls=tf.data.AUTOTUNE)
               .shuffle(1000)
               .batch(batch_size)
               .prefetch(tf.data.AUTOTUNE))
    return dataset

# Example usage
# train_dataset = create_image_classification_pipeline('train_images/')
# val_dataset = create_image_classification_pipeline('val_images/')
2. Text classification data pipeline
python
def create_text_classification_pipeline(texts, labels, vocab_size=10000, max_length=100):
    """Create a text classification data pipeline."""
    # Build the vocabulary (Tokenizer is the legacy Keras preprocessing API)
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token='<OOV>'
    )
    tokenizer.fit_on_texts(texts)

    def preprocess_text(text, label):
        # Tokenize and encode (runs eagerly inside tf.py_function)
        sequences = tokenizer.texts_to_sequences([text.numpy().decode('utf-8')])
        sequence = tf.constant(sequences[0], dtype=tf.int32)

        # Pad or truncate to a fixed length
        pad_len = tf.maximum(0, max_length - tf.shape(sequence)[0])
        sequence = tf.pad(sequence, [[0, pad_len]])[:max_length]
        return sequence, label

    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Preprocess (tf.py_function is needed because the tokenizer is Python code)
    dataset = dataset.map(
        lambda text, label: tf.py_function(
            preprocess_text, [text, label], [tf.int32, tf.int32]
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    return dataset, tokenizer

# Example usage
sample_texts = ["a really good movie", "a terrible experience", "highly recommended"]
sample_labels = [1, 0, 1]
# text_dataset, text_tokenizer = create_text_classification_pipeline(sample_texts, sample_labels)
Summary
The tf.data API is a powerful tool for data handling in TensorFlow. Its main strengths include:
- Varied data sources: in-memory arrays, files, generators, and more
- A rich set of transformations: map, filter, batch, shuffle, and others
- Performance optimizations: parallel processing, prefetching, and caching (see the consolidated sketch below)
- Data augmentation: image and text augmentation built from TensorFlow ops
- Debuggability: tools for inspecting and validating datasets
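As a recap of the performance points above, a minimal sketch of the pipeline ordering used throughout this guide (map, then cache, then shuffle, batch, and prefetch); the data here is synthetic.
python
import tensorflow as tf

features = tf.random.normal([1000, 10])
labels = tf.random.uniform([1000], maxval=3, dtype=tf.int32)

dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .map(lambda x, y: (tf.nn.l2_normalize(x), y),
                num_parallel_calls=tf.data.AUTOTUNE)  # parallel preprocessing
           .cache()                                   # cache preprocessed data
           .shuffle(1000)                             # shuffle examples
           .batch(32)                                 # batch
           .prefetch(tf.data.AUTOTUNE))               # overlap with training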
Mastering the tf.data API goes a long way toward faster data processing and better model training throughput!