2025-03-26 Using Subsets in PyTorch

Summary: This notebook demonstrates how to create and use subsets of datasets in PyTorch. Subsets are particularly useful for working with specific portions of your data (e.g., validation sets, or smaller samples for quick experimentation). We will use the MNIST dataset as an example.

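Using subsets in PyTorch lets you work with specific portions of your dataset efficiently. This is particularly useful for validation, testing, or quick experiments on large datasets. Note that a Subset only selects indices: whether the data itself sits in memory is decided by the parent dataset (torchvision's MNIST, for instance, loads all its images into memory when it is created).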

Key Points:

  • Efficiency: A Subset stores only a reference to the parent dataset plus a list of indices, so no data is copied, which saves memory and processing time.
  • Flexibility: You can create multiple subsets from the same dataset for different tasks (e.g., validation, testing).
  • Integration: Subsets work seamlessly with PyTorch’s DataLoader, making it easy to integrate into training loops.

By leveraging subsets in PyTorch, you can efficiently manage and experiment with portions of your data without compromising on performance or memory usage.
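To make the efficiency point concrete, here is a simplified, dependency-free sketch of the idea behind torch.utils.data.Subset. The class name MiniSubset is hypothetical, and the real class handles a few more cases (such as a list of indices passed to __getitem__), but the core mapping from a position in the subset to a position in the parent dataset is the same:

```python
class MiniSubset:
    """Hypothetical, simplified stand-in for torch.utils.data.Subset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset        # a reference to the parent dataset, not a copy
        self.indices = list(indices)  # only the selected positions are stored

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Position idx in the subset maps to position indices[idx] in the parent
        return self.dataset[self.indices[idx]]


parent = [f"sample_{i}" for i in range(10)]
sub = MiniSubset(parent, [2, 5, 7])
print(len(sub))               # 3
print(sub[1])                 # sample_5
print(sub.dataset is parent)  # True: the data is shared, not duplicated
```

Any object with __len__ and __getitem__ can play the role of the parent here, which is also why a Subset can wrap MNIST or any other map-style dataset.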

Let's first initialize the notebook:

In [1]:
import torch
# https://pytorch.org/docs/stable/nn.html
from torchvision import datasets, transforms

# Set random seed for reproducibility (good practice)
torch.manual_seed(42);
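Seeding the global generator makes random draws repeatable. The sketch below checks this, and also shows the generator argument of DataLoader, which pins the shuffling order of a single loader independently of the global seed. It uses a toy TensorDataset of 10 integers instead of MNIST (an assumption so that it runs without downloading anything), and first_batch is a hypothetical helper:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Re-seeding the global RNG replays exactly the same random draws
torch.manual_seed(42)
a = torch.randperm(10)
torch.manual_seed(42)
b = torch.randperm(10)
print(torch.equal(a, b))  # True

# Toy dataset of 10 integers (stand-in for MNIST, an assumption for this sketch)
dataset = TensorDataset(torch.arange(10))


def first_batch(seed):
    """Hypothetical helper: first shuffled batch of a loader driven by its own generator."""
    loader = DataLoader(dataset, batch_size=5, shuffle=True,
                        generator=torch.Generator().manual_seed(seed))
    return next(iter(loader))[0]


print(torch.equal(first_batch(0), first_batch(0)))  # True
```

A dedicated generator is convenient when other code in the notebook also consumes the global random number generator between runs.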

Step 1: Load the Full Dataset

We start by loading the full MNIST dataset. The MNIST class from torchvision.datasets downloads the data if needed and applies the given transform (here, conversion to a tensor) each time a sample is accessed.

In [2]:
# Define the dataset parameters
root = '/tmp/data'  # Root directory where data will be stored
train = True      # Use training data (set to False for test data)
transform = transforms.ToTensor()  # Convert images to tensors

# Load the full dataset
full_dataset = datasets.MNIST(root=root, train=train, transform=transform, download=True)

# Print basic information about the dataset
print(f"Full dataset size: {len(full_dataset)} samples")
Full dataset size: 60000 samples

Step 2: Create a Subset of the Dataset

Use the Subset class to create smaller portions of your dataset. Here, we create a subset containing the first 100 samples.

In [3]:
from torch.utils.data import DataLoader, Subset

# Define the indices for the subset (first 100 samples)
subset_indices = list(range(100))  # You can modify this range as needed

# Create the subset dataset
subset_dataset = Subset(full_dataset, subset_indices)

# Print information about the subset
print(f"Subset size: {len(subset_dataset)} samples");
Subset size: 100 samples
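The first 100 samples are a convenient choice, but any list of indices works. Below is a sketch of two other common selections, a random subset and a single-class subset. It runs on a synthetic TensorDataset standing in for MNIST (an assumption so that the example runs offline); with the real data, the labels are available as full_dataset.targets:

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Synthetic stand-in for MNIST (assumption): 1000 random 28x28 "images" with labels 0-9
g = torch.Generator().manual_seed(0)
images = torch.rand(1000, 1, 28, 28, generator=g)
labels = torch.randint(0, 10, (1000,), generator=g)
dataset = TensorDataset(images, labels)

# 1) Random subset: shuffle all indices, then keep the first 100
random_indices = torch.randperm(len(dataset), generator=g)[:100].tolist()
random_subset = Subset(dataset, random_indices)

# 2) Single-class subset: keep only the samples whose label is 3
digit3_indices = (labels == 3).nonzero(as_tuple=True)[0].tolist()
digit3_subset = Subset(dataset, digit3_indices)

print(len(random_subset))  # 100
print(all(int(digit3_subset[i][1]) == 3 for i in range(len(digit3_subset))))  # True
```

Both subsets still point at the same underlying tensors; only the index lists differ.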

Step 3: Load the Subset Using DataLoader

Load the subset into batches using a DataLoader. This is useful for training loops or data processing.

In [4]:
# Define parameters for loading the data
batch_size = 32
shuffle = True  # Shuffle the data in each epoch

# Create the DataLoader for the subset
subset_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=shuffle)

# Print information about the DataLoader
print(f"Subset DataLoader: {len(subset_loader)} batches per epoch");
Subset DataLoader: 4 batches per epoch
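The count of 4 batches comes from rounding up: 100 / 32 = 3.125, so the DataLoader yields three full batches of 32 samples and a last batch of 4. Passing drop_last=True discards that incomplete batch. A quick check on 100 dummy samples (standing in for the MNIST subset):

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))  # 100 dummy samples
batch_size = 32

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
sizes = [batch[0].shape[0] for batch in loader]
print(len(loader), math.ceil(100 / batch_size), sizes)  # 4 4 [32, 32, 32, 4]

# drop_last=True skips the final, incomplete batch
loader_drop = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
print(len(loader_drop), 100 // batch_size)  # 3 3
```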

Step 4: Iterate Over the Subset

Use a simple loop to iterate over the subset and retrieve batches of data. As an example, we train a linear classifier on these batches for five epochs.

In [5]:
# Initialize a model (example: simple linear classifier)
input_size = 28 * 28  # MNIST images are 28x28
output_size = 10      # There are 10 classes in MNIST

model = torch.nn.Linear(input_size, output_size)

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop (example: five epochs)
for epoch in range(5):
    print(f'\nEpoch {epoch + 1}')
    for batch_idx, (data, labels) in enumerate(subset_loader):
        # Flatten the images
        data = data.view(-1, input_size)
        
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Log every 10th batch (with only 4 batches per epoch, this prints batch 0)
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}: Loss: {loss.item():.4f}')

print("\nTraining complete!");
Epoch 1
Batch 0: Loss: 2.2630

Epoch 2
Batch 0: Loss: 2.2093

Epoch 3
Batch 0: Loss: 2.2010

Epoch 4
Batch 0: Loss: 2.0449

Epoch 5
Batch 0: Loss: 2.0448

Training complete!
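With only 4 batches per epoch, the test batch_idx % 10 == 0 logs just the first batch, so the printed values are noisy snapshots. The sketch below instead reports the mean loss over each epoch, weighting each batch by its size. It uses random tensors in place of the MNIST subset (an assumption), so the loss values themselves carry no meaning:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Synthetic stand-in for the 100-sample MNIST subset (assumption: random images and labels)
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

model = torch.nn.Linear(28 * 28, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epoch_losses = []
for epoch in range(5):
    running_loss, n_seen = 0.0, 0
    for data, target in loader:
        data = torch.flatten(data, start_dim=1)  # (B, 1, 28, 28) -> (B, 784)
        loss = criterion(model(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The loss is a mean over the batch, so weight it by the batch size
        running_loss += loss.item() * data.shape[0]
        n_seen += data.shape[0]
    epoch_losses.append(running_loss / n_seen)
    print(f"Epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```

Weighting by batch size matters here because the last batch holds only 4 samples.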

Step 5 (Optional): Create Multiple Subsets for Different Purposes

You can create multiple subsets for different purposes (e.g., training, validation). Using disjoint index ranges ensures that no sample is used for both.

In [6]:
# Example: Creating a validation subset
val_indices = list(range(100, 200))  # Next 100 samples as validation set
val_dataset = Subset(full_dataset, val_indices)

print(f"Validation subset size: {len(val_dataset)} samples");
Validation subset size: 100 samples
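A validation subset is typically read without shuffling and without gradient tracking. A minimal evaluation sketch, again on synthetic tensors standing in for full_dataset (an assumption) and with an untrained placeholder model, so the accuracy it prints is not meaningful:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)
# Synthetic stand-in for full_dataset (assumption: 300 random samples)
images = torch.rand(300, 1, 28, 28)
labels = torch.randint(0, 10, (300,))
full = TensorDataset(images, labels)

train_idx = list(range(100))
val_idx = list(range(100, 200))
assert set(train_idx).isdisjoint(val_idx)  # no sample appears in both subsets

val_loader = DataLoader(Subset(full, val_idx), batch_size=32, shuffle=False)
model = torch.nn.Linear(28 * 28, 10)  # untrained placeholder model

model.eval()  # no effect on a Linear layer, but good practice (dropout, batch norm)
correct, total = 0, 0
with torch.no_grad():  # no gradients are needed for evaluation
    for data, target in val_loader:
        preds = model(torch.flatten(data, start_dim=1)).argmax(dim=1)
        correct += (preds == target).sum().item()
        total += target.shape[0]
accuracy = correct / total
print(f"Validation accuracy: {accuracy:.2%} on {total} samples")
```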

Step 6: Conclusion

You’ve now seen how to create and use subsets of a PyTorch dataset. Subsets are super handy for:

  • Playing around with smaller chunks of data.
  • Setting up validation sets or custom splits.
  • Keeping things light on memory when dealing with big datasets.

An alternative: random_split

To obtain a random train/test split in a single call (here, holding out 20% of the data), one may also simply use torch.utils.data.random_split:

In [7]:
rho = .8  # fraction of the samples used for training
len_dataset = len(full_dataset)
len_train = int(rho * len_dataset)

# Passing full_dataset returns two Subset views of it (80% / 20%)
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset,
                                                           [len_train, len_dataset - len_train],
                                                           generator=torch.Generator().manual_seed(42),
                                                           )
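random_split returns Subset objects, so their .indices attribute records which samples landed in each split. Recent PyTorch versions also accept fractions instead of absolute lengths. A sketch on 1,000 dummy samples (an assumption, to keep it self-contained) checking that the two splits are disjoint and together cover the whole dataset:

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

dataset = TensorDataset(torch.arange(1000))  # 1000 dummy samples
train_set, test_set = random_split(dataset, [0.8, 0.2],
                                   generator=torch.Generator().manual_seed(42))

print(type(train_set).__name__, len(train_set), len(test_set))  # Subset 800 200
# The two index lists do not overlap...
print(set(train_set.indices).isdisjoint(test_set.indices))  # True
# ...and together they cover every sample exactly once
print(sorted(train_set.indices + test_set.indices) == list(range(1000)))  # True
```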

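More things to explore

Instead of wrapping the dataset in a Subset, one can also pass a sampler to the DataLoader to control which indices it visits. Here, a RandomSampler draws 10,000 indices with replacement from the full training set at each epoch:

In [8]:
# Dedicated generator so that the drawn indices are reproducible
generator = torch.Generator().manual_seed(42)
# Draw 10,000 indices (with replacement) out of the 60,000 samples at each epoch
sampler = torch.utils.data.RandomSampler(full_dataset, replacement=True, num_samples=10000, generator=generator)
# A custom sampler replaces shuffle=True (the two options are mutually exclusive)
train_loader = torch.utils.data.DataLoader(full_dataset, batch_size=32,
                                           sampler=sampler)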
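Because this sampler yields 10,000 indices per epoch, the loader produces ceil(10000 / 32) = 313 batches per epoch, whatever the size of the underlying dataset. A related tool is SubsetRandomSampler, which reshuffles a fixed list of indices at every epoch without wrapping the dataset in a Subset. A sketch on a dummy TensorDataset with the size of the MNIST training set (an assumption, to avoid the download):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(60000))  # dummy stand-in with 60,000 samples
generator = torch.Generator().manual_seed(42)

# Same setup as above: 10,000 indices drawn with replacement per epoch
sampler = RandomSampler(dataset, replacement=True, num_samples=10000, generator=generator)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
print(len(sampler), len(loader))  # 10000 313

# Alternative: visit only the first 100 indices, in a new random order each epoch
sub_sampler = SubsetRandomSampler(list(range(100)), generator=generator)
sub_loader = DataLoader(dataset, batch_size=32, sampler=sub_sampler)
seen = sorted(int(x) for (batch,) in sub_loader for x in batch)
print(len(sub_loader), seen == list(range(100)))  # 4 True
```

Unlike the RandomSampler above, SubsetRandomSampler draws without replacement, so each selected index is visited exactly once per epoch.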

Some bookkeeping for the notebook:

In [9]:
%load_ext watermark
%watermark -i -h -m -v -p torch,torchvision  -r -g -b
Python implementation: CPython
Python version       : 3.13.2
IPython version      : 8.30.0

torch      : 2.6.0
torchvision: 0.21.0

Compiler    : Clang 16.0.0 (clang-1600.0.26.6)
OS          : Darwin
Release     : 24.3.0
Machine     : arm64
Processor   : arm
CPU cores   : 10
Architecture: 64bit

Hostname: obiwan.local

Git hash: 83fd93189577c8e6ac256a705efe435477eebe18

Git repo: https://github.com/laurentperrinet/sciblog

Git branch: master