2025-03-26 Using Subsets in PyTorch
Summary: This notebook demonstrates how to create and use subsets of datasets in PyTorch. Subsets are particularly useful for working with specific portions of your data (e.g., validation sets or smaller batches for experimentation). We will use the MNIST dataset as an example.
Using subsets in PyTorch allows you to work with specific portions of your dataset efficiently. This is particularly useful for tasks like validation, testing, or when working with large datasets where loading the entire dataset into memory is impractical.
Key Points:
- Efficiency: Using Subset avoids creating copies of your data, which saves memory and processing time.
- Flexibility: You can create multiple subsets from the same dataset for different tasks (e.g., validation, testing).
- Integration: Subsets work seamlessly with PyTorch’s DataLoader, making it easy to integrate into training loops.
By leveraging subsets in PyTorch, you can efficiently manage and experiment with portions of your data without compromising on performance or memory usage.
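To see what "no copies" means concretely, here is a minimal self-contained sketch (using a toy TensorDataset instead of MNIST): a Subset only records indices into the original dataset.
import torch
from torch.utils.data import TensorDataset, Subset
base = TensorDataset(torch.arange(10).float())  # a toy dataset of 10 samples
view = Subset(base, [0, 2, 4])
print(len(view))             # 3
print(view.dataset is base)  # True: the subset points at the original data
print(view.indices)          # [0, 2, 4]: only the indices are stored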
Let's first initialize the notebook:
import torch
# https://pytorch.org/vision/stable/datasets.html
from torchvision import datasets, transforms
# Set random seed for reproducibility (good practice)
torch.manual_seed(42);
Step 1: Load the Full Dataset¶
We start by loading the full MNIST dataset. The MNIST class from torchvision.datasets handles downloading and preprocessing the data.
# Define the dataset parameters
root = '/tmp/data' # Root directory where data will be stored
train = True # Use training data (set to False for test data)
transform = transforms.ToTensor() # Convert images to tensors
# Load the full dataset
full_dataset = datasets.MNIST(root=root, train=train, transform=transform, download=True)
# Print basic information about the dataset
print(f"Full dataset size: {len(full_dataset)} samples")
Step 2: Create a Subset of the Dataset¶
Use the Subset class to create smaller portions of your dataset. Here, we create a subset containing the first 100 samples.
from torch.utils.data import DataLoader, Subset
# Define the indices for the subset (first 100 samples)
subset_indices = list(range(100)) # You can modify this range as needed
# Create the subset dataset
subset_dataset = Subset(full_dataset, subset_indices)
# Print information about the subset
print(f"Subset size: {len(subset_dataset)} samples");
Step 3: Load the Subset Using DataLoader¶
Load the subset into batches using a DataLoader. This is useful for training loops or data processing.
# Define parameters for loading the data
batch_size = 32
shuffle = True # Shuffle the data in each epoch
# Create the DataLoader for the subset
subset_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=shuffle)
# Print information about the DataLoader
print(f"Subset DataLoader: {len(subset_loader)} batches per epoch");
Step 4: Iterate Over the Subset¶
Use a simple training loop to iterate over the subset and retrieve batches of data; as an example, we fit a linear classifier on the 100-sample subset.
# Initialize a model (example: simple linear classifier)
input_size = 28 * 28 # MNIST images are 28x28
output_size = 10 # There are 10 classes in MNIST
model = torch.nn.Linear(input_size, output_size)
# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Training loop (example: five epochs)
for epoch in range(5):
    print(f'\nEpoch {epoch + 1}')
    for batch_idx, (data, labels) in enumerate(subset_loader):
        # Flatten the 28x28 images into vectors
        data = data.view(-1, input_size)
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}: Loss: {loss.item():.4f}')
print("\nTraining complete!");
Step 5: Optional - Create Multiple Subsets for Different Purposes¶
You can create multiple subsets for different purposes (e.g., training, validation).
# Example: Creating a validation subset
val_indices = list(range(100, 200)) # Next 100 samples as validation set
val_dataset = Subset(full_dataset, val_indices)
print(f"Validation subset size: {len(val_dataset)} samples");
Step 6: Conclusion¶
You’ve now seen how to create and use subsets of a PyTorch dataset. Subsets are super handy for:
- Playing around with smaller chunks of data.
- Setting up validation sets or custom splits.
- Keeping things light on memory when dealing with big datasets.
An alternative: random_split¶
To draw a random train/test split (for instance when setting up cross-validation), one may also simply use:
rho = .8  # fraction of the samples used for training
len_dataset = len(full_dataset)
len_train = int(rho * len_dataset)
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,  # split the dataset itself, not just a range of indices
    [len_train, len_dataset - len_train],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)
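Both returned objects are plain Subset instances, so they plug into a DataLoader exactly as above. A quick check of the resulting sizes (MNIST's training split has 60000 samples, so rho = .8 gives 48000/12000):
print(len(train_dataset), len(test_dataset))  # 48000 12000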
More things to explore¶
A RandomSampler can also draw samples with replacement directly from the full dataset, which lets you fix the number of samples per epoch (here 10000) independently of the dataset size:
# Draw 10000 samples per epoch, with replacement, in a reproducible way
generator = torch.Generator().manual_seed(42)
sampler = torch.utils.data.RandomSampler(full_dataset, replacement=True,
                                         num_samples=10000, generator=generator)
train_loader = torch.utils.data.DataLoader(full_dataset, batch_size=32,
                                           sampler=sampler)
Some bookkeeping for the notebook¶
%load_ext watermark
%watermark -i -h -m -v -p torch,torchvision -r -g -b