Load, explore, and prepare the famous MNIST dataset for neural network training
MNIST (Modified National Institute of Standards and Technology) is the "Hello World" of computer vision. It's a dataset of 70,000 handwritten digits (0-9) that's perfect for learning neural networks because it's:
- Small and fast: each image is only 28x28 grayscale pixels, so models train quickly even on a CPU
- Clean and well-labeled: every image comes with a correct digit label
- Roughly balanced: each of the 10 digit classes has a similar number of examples
- Already split: 60,000 training images and 10,000 test images
TensorFlow makes it incredibly easy to load MNIST. It's built right into Keras datasets!
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Load the MNIST dataset
print("Loading MNIST dataset...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Print dataset information
print(f"Training data shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Test labels shape: {y_test.shape}")
# Print data types and value ranges
print(f"Image data type: {x_train.dtype}")
print(f"Label data type: {y_train.dtype}")
print(f"Pixel value range: {x_train.min()} to {x_train.max()}")
print(f"Unique labels: {np.unique(y_train)}")
Before training any model, it's crucial to understand your data. Let's visualize some MNIST images and analyze their distribution.
# Function to display MNIST images
def plot_mnist_samples(images, labels, num_samples=10):
    """Display a grid of MNIST images with their labels (the 2x5 grid assumes num_samples <= 10)"""
    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    axes = axes.ravel()
    for i in range(num_samples):
        axes[i].imshow(images[i], cmap='gray')
        axes[i].set_title(f'Label: {labels[i]}')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
# Display first 10 training images
print("Sample training images:")
plot_mnist_samples(x_train, y_train, 10)
# Analyze the distribution of labels
import collections
# Count occurrences of each digit
train_counts = collections.Counter(y_train)
test_counts = collections.Counter(y_test)
print("Training set label distribution:")
for digit in range(10):
print(f"Digit {digit}: {train_counts[digit]} samples")
print("\nTest set label distribution:")
for digit in range(10):
print(f"Digit {digit}: {test_counts[digit]} samples")
# Visualize distribution
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.bar(range(10), [train_counts[d] for d in range(10)])
plt.title('Training Set Distribution')
plt.xlabel('Digit')
plt.ylabel('Count')
plt.subplot(1, 2, 2)
plt.bar(range(10), [test_counts[d] for d in range(10)])
plt.title('Test Set Distribution')
plt.xlabel('Digit')
plt.ylabel('Count')
plt.tight_layout()
plt.show()
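As a quick sanity check that the classes are roughly balanced, you can also print each digit's share of the training set. An optional sketch, reusing the train_counts computed above:
# Each digit's share of the training set, as a percentage
total_train = sum(train_counts.values())
for digit in range(10):
    share = 100 * train_counts[digit] / total_train
    print(f"Digit {digit}: {share:.1f}% of training samples")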
Raw data is rarely ready for machine learning. We need to preprocess it to improve training performance and accuracy.
Neural networks work best with normalized data. We'll scale pixel values from [0, 255] to [0, 1].
# Normalize pixel values to range [0, 1]
print("Before normalization:")
print(f"Min value: {x_train.min()}, Max value: {x_train.max()}")
# Convert to float32 and normalize
x_train_normalized = x_train.astype('float32') / 255.0
x_test_normalized = x_test.astype('float32') / 255.0
print("After normalization:")
print(f"Min value: {x_train_normalized.min()}, Max value: {x_train_normalized.max()}")
print(f"Data type: {x_train_normalized.dtype}")
# Compare original vs normalized image (imshow rescales intensities, so both look identical; only the stored values differ)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.imshow(x_train[0], cmap='gray')
ax1.set_title(f'Original (0-255)\nLabel: {y_train[0]}')
ax1.axis('off')
ax2.imshow(x_train_normalized[0], cmap='gray')
ax2.set_title(f'Normalized (0-1)\nLabel: {y_train[0]}')
ax2.axis('off')
plt.tight_layout()
plt.show()
Normalization helps because:
- Gradients stay in a reasonable range, which keeps optimization numerically stable
- Training typically converges faster when all inputs share a common scale
- Default weight initializations and learning rates assume inputs that are roughly unit-scale
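Scaling to [0, 1] isn't the only option. Another common choice is standardization (zero mean, unit variance). A minimal sketch, assuming the statistics are computed on the training set only and then reused for the test set:
# Alternative: standardize with training-set statistics (mean 0, std 1)
mean = x_train.astype('float32').mean()
std = x_train.astype('float32').std()
x_train_standardized = (x_train.astype('float32') - mean) / std
x_test_standardized = (x_test.astype('float32') - mean) / std  # reuse training statistics
print(f"Standardized mean: {x_train_standardized.mean():.3f}, std: {x_train_standardized.std():.3f}")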
Different neural network architectures expect different input shapes. Let's prepare our data for both dense and convolutional networks.
# Option 1: Flatten for Dense/Fully Connected networks
x_train_flat = x_train_normalized.reshape(x_train_normalized.shape[0], -1)
x_test_flat = x_test_normalized.reshape(x_test_normalized.shape[0], -1)
print("Flattened shapes:")
print(f"Training: {x_train_flat.shape}")
print(f"Test: {x_test_flat.shape}")
# Option 2: Add channel dimension for Convolutional networks
x_train_conv = x_train_normalized.reshape(x_train_normalized.shape[0], 28, 28, 1)
x_test_conv = x_test_normalized.reshape(x_test_normalized.shape[0], 28, 28, 1)
print("\nConvolutional shapes:")
print(f"Training: {x_train_conv.shape}")
print(f"Test: {x_test_conv.shape}")
# Visualize the difference
print("\nShape comparison:")
print(f"Original: {x_train_normalized.shape} -> Flattened: {x_train_flat.shape}")
print(f"784 = 28 x 28 (width x height)")
print(f"Channel dimension: {x_train_conv.shape} (samples, height, width, channels)")
For multi-class classification, we convert integer labels to one-hot encoded vectors.
# Convert labels to one-hot encoding
num_classes = 10
y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes)
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes)
print("Label encoding comparison:")
print(f"Original labels: {y_train[:5]}")
print(f"One-hot encoded:")
for i in range(5):
print(f" {y_train[i]} -> {y_train_onehot[i]}")
print(f"\nShapes:")
print(f"Original: {y_train.shape}")
print(f"One-hot: {y_train_onehot.shape}")
# Visualize one-hot encoding
plt.figure(figsize=(10, 6))
sample_idx = 0
digit_label = y_train[sample_idx]
plt.subplot(1, 2, 1)
plt.imshow(x_train[sample_idx], cmap='gray')
plt.title(f'Digit: {digit_label}')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.bar(range(10), y_train_onehot[sample_idx])
plt.title('One-Hot Encoding')
plt.xlabel('Digit Class')
plt.ylabel('Value')
plt.xticks(range(10))
plt.tight_layout()
plt.show()
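A note on when one-hot encoding is actually needed: Keras' categorical_crossentropy loss expects one-hot targets, while sparse_categorical_crossentropy works with the integer labels directly. A minimal illustrative sketch (the tiny model here is just a placeholder, not the architecture we'll build in the next lesson):
# Placeholder model only to illustrate the loss choice; with sparse_categorical_crossentropy
# you can train on the original integer labels (y_train) and skip one-hot encoding
demo_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
demo_model.compile(optimizer='adam',
                   loss='sparse_categorical_crossentropy',  # integer labels
                   metrics=['accuracy'])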
TensorFlow's tf.data API provides efficient data loading and preprocessing pipelines.
# Create TensorFlow datasets
batch_size = 32
# Create training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train_normalized, y_train))
train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
# Create test dataset
test_dataset = tf.data.Dataset.from_tensor_slices((x_test_normalized, y_test))
test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
print(f"Training batches: {len(train_dataset)}")
print(f"Test batches: {len(test_dataset)}")
# Inspect a batch
for batch_x, batch_y in train_dataset.take(1):
    print(f"Batch shape: {batch_x.shape}")
    print(f"Labels shape: {batch_y.shape}")
    print(f"First few labels: {batch_y[:5].numpy()}")
    break
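For actual training you would normally also shuffle (and optionally cache) the training pipeline. A sketch of what that could look like; the buffer size here is an assumed value, not a tuned one:
# A more training-ready pipeline: cache in memory, shuffle each epoch, then batch and prefetch
train_dataset_shuffled = (
    tf.data.Dataset.from_tensor_slices((x_train_normalized, y_train))
    .cache()
    .shuffle(buffer_size=10000)  # assumed buffer size; larger means better mixing
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)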
Data augmentation creates variations of existing data to improve model generalization.
# Simple data augmentation for MNIST
# A Keras preprocessing layer gives us small random rotations (roughly ±10 degrees);
# such layers can be used directly inside a tf.data pipeline (tf.keras.layers.RandomRotation is available in TF 2.6+)
rotate = tf.keras.layers.RandomRotation(10 / 360)
def augment_data(image, label):
    """Apply random transformations to a batch of images"""
    # Add a channel dimension so the rotation layer sees (batch, 28, 28, 1)
    image = tf.expand_dims(image, -1)
    # Random rotation (±10 degrees); training=True keeps the layer active inside map()
    image = rotate(image, training=True)
    # Random brightness adjustment
    image = tf.image.random_brightness(image, 0.1)
    # Ensure values stay in [0, 1] range
    image = tf.clip_by_value(image, 0.0, 1.0)
    # Drop the channel dimension again so the plotting code below still sees (28, 28) images
    return tf.squeeze(image, -1), label
# Apply augmentation to training data
train_dataset_aug = train_dataset.map(augment_data)
# Visualize augmented samples
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
original_batch = next(iter(train_dataset))
augmented_batch = next(iter(train_dataset_aug))
for i in range(5):
    # Original
    axes[0, i].imshow(original_batch[0][i], cmap='gray')
    axes[0, i].set_title(f'Original: {original_batch[1][i].numpy()}')
    axes[0, i].axis('off')
    # Augmented
    axes[1, i].imshow(augmented_batch[0][i], cmap='gray')
    axes[1, i].set_title(f'Augmented: {augmented_batch[1][i].numpy()}')
    axes[1, i].axis('off')
plt.tight_layout()
plt.show()
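An alternative to augmenting inside the tf.data pipeline is to attach Keras preprocessing layers to the model itself, so augmentation runs only during training and is automatically disabled at inference. A sketch; the specific layers and factors here are illustrative assumptions, not tuned choices:
# Augmentation as model layers: active during fit(), disabled when predicting
augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(10 / 360),     # roughly ±10 degrees
    tf.keras.layers.RandomTranslation(0.1, 0.1),  # shift up to 10% vertically and horizontally
])
# Note: these layers expect a channel dimension, e.g. inputs of shape (28, 28, 1)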
Now it's your turn! Complete this data preprocessing pipeline:
# Your solution here
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# 1. Load MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 2. Distribution function
def plot_digit_distribution(labels, title="Digit Distribution"):
    """Plot the distribution of digits in the dataset"""
    # Your code here
    pass
# 3. Normalize and reshape
def preprocess_data(x_train, x_test, y_train, y_test):
    """Normalize and reshape the data"""
    # Your code here
    return x_train_processed, x_test_processed, y_train_processed, y_test_processed
# 4. Train/validation split
def create_train_val_split(x_train, y_train, val_split=0.2):
    """Split training data into train and validation sets"""
    # Your code here
    return x_train_split, x_val_split, y_train_split, y_val_split
# 5. Create TensorFlow datasets
def create_datasets(x_train, y_train, x_val, y_val, x_test, y_test, batch_size=32):
    """Create TensorFlow datasets"""
    # Your code here
    return train_ds, val_ds, test_ds
# Test your functions
plot_digit_distribution(y_train, "Training Set Distribution")
# ... rest of your implementation
In this lesson, you've learned the essential data preprocessing steps:
- Loading MNIST directly from tf.keras.datasets
- Exploring the data by visualizing samples and checking the label distribution
- Normalizing pixel values from [0, 255] to [0, 1]
- Reshaping images for dense networks (flattened) and convolutional networks (channel dimension)
- One-hot encoding the labels
- Building efficient input pipelines with the tf.data API
- Applying simple data augmentation
Next, we'll use this preprocessed MNIST data to build our first neural network from scratch!