Introduction to TorchData: The Best Way to Load Data in PyTorch
TLDR
This post is an introduction to TorchData: a library to build better data loading pipelines in PyTorch. It’s intended to replace the usual Dataset + DataLoader approach. I will illustrate this introduction with a concrete notebook example hosted on GitHub: we will prepare and load data from a real dataset called Intel Image Classification using TorchData.
What is data loading in machine learning?
We can see it as a pipeline: the sequence of operations that takes data from where it is stored to a form that a model can consume for training or inference.
The data can live in many places: on a local disk, in cloud buckets, or behind an API. After reading it, an arbitrarily complex process transforms it into a model-ready format such as tensors.
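As a rough illustration of this idea, here is a minimal plain-Python sketch of a three-stage pipeline built from generators. The in-memory `RAW_LINES` storage, the JSON record layout, and the function names are all assumptions made up for this example, and plain lists of floats stand in for tensors.

```python
import json
from typing import Iterable, Iterator

# Hypothetical "storage": JSON lines that could equally come from disk, a bucket, or an API.
RAW_LINES = [
    '{"pixels": [0, 128, 255], "label": "sea"}',
    '{"pixels": [64, 64, 64], "label": "forest"}',
]


def read(lines: Iterable[str]) -> Iterator[str]:
    """Stage 1: pull raw records out of storage."""
    yield from lines


def decode(records: Iterable[str]) -> Iterator[dict]:
    """Stage 2: parse each raw record into a Python object."""
    for record in records:
        yield json.loads(record)


def to_model_input(samples: Iterable[dict]) -> Iterator[tuple]:
    """Stage 3: scale pixel values to [0, 1] and pair them with the label."""
    for sample in samples:
        yield [p / 255 for p in sample["pixels"]], sample["label"]


# Chain the three stages; nothing runs until the result is consumed.
pipeline_output = list(to_model_input(decode(read(RAW_LINES))))
print(pipeline_output[0])
```

Each stage only knows about its own input and output, which is the property the rest of this post relies on.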
The problem with DataLoader
The DataLoader is monolithic. It bundles too many features in the same class, which makes composition and reuse difficult. There are hundreds of ways to store a given dataset, and each one calls for a highly customized DataLoader: there are many primitive formats (tensor .pt files, pickle, JSON, ...), the data can be grouped by folder, filename, regex, or file header, the files can be compressed in different formats, and so on.
To support these cases, we need a custom or highly configured DataLoader. We would much prefer to avoid rewriting the same code again and again to handle each of these operations in a different context.
What is TorchData?
TorchData gives us an elegant solution: it provides single-purpose, composable data loading primitives. They are assembled into pipelines that match arbitrarily complex data loading schemes, which keeps the amount of custom code to a minimum. The core concept of the library is the DataPipe: a renamed and repurposed implementation of the Dataset, designed for composition.
There are two kinds of DataPipe.
First, the IterDataPipe represents an updated version of the IterableDataset. IterDataPipes implement the __iter__ method: you can iterate over them, but you can’t access their items individually by index. They are well suited to streamed datasets where random reads are expensive.
As a simple example, we build an IterDataPipe that starts from a range of integers and groups them into 2 batches of even and odd numbers. To do so, it uses the functional groupby DataPipe with a key function that returns each number modulo 2. Then we iterate over the pipe and print the result.
# Import the datapipes module (the alias dp is used throughout this post)
import torchdata.datapipes as dp

# IterDataPipe of the first 10 integers grouped in 2 batches: even and odd
pipe = (
    # Wrap the range into an IterDataPipe
    dp.iter.IterableWrapper(range(10))
    # Group by parity: one batch for even and one batch for odd numbers
    .groupby(lambda x: x % 2)
)
# We can iterate over the items
print("Complete iteration:", list(pipe))
# But we can't access them individually based on an index
# pipe[0] would raise an Exception
# Outputs:
# >> Complete iteration: [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
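For intuition about what the grouping above produces, here is a plain-Python sketch that yields the same result for this small input. It is not how TorchData implements groupby (the real DataPipe streams its input and works with a bounded buffer); `group_by_key` is a hypothetical helper written only for this illustration.

```python
from collections import defaultdict


def group_by_key(items, key_fn):
    """Collect items into lists that share the same key, in first-seen key order."""
    groups = defaultdict(list)
    for item in items:
        groups[key_fn(item)].append(item)
    return list(groups.values())


# Same key function as in the DataPipe example: the number modulo 2.
grouped = group_by_key(range(10), lambda x: x % 2)
print(grouped)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```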
Second, the MapDataPipe represents an updated version of the map-style Dataset. MapDataPipes are well suited to key-value datasets where random reads are cheap. You can iterate over them and access their items individually by index.
As a simple example, we build a MapDataPipe that starts from a range of integers, multiplies every element by 2 using the functional map DataPipe, and shuffles the output. Then we iterate, print the result, and show that we can access items by index.
# MapDataPipe of the first 10 integers multiplied by 2 and shuffled
pipe = (
    # Wrap the range into a MapDataPipe
    dp.map.SequenceWrapper(range(10))
    # Multiply every number by 2
    .map(lambda x: x * 2)
    # Shuffle
    .shuffle()
)
# We can iterate over the values
print("Complete iteration:", list(pipe))
# We can also access items individually based on their index
# (this relies on shuffle() returning a MapDataPipe, as in the release used here;
# if your version returns an IterDataPipe instead, index the pipe before shuffling)
print("Index based access:", pipe[0], pipe[9])
# Outputs
# >> Complete iteration: [4, 8, 18, 10, 14, 12, 16, 6, 0, 2]
# >> Index based access: 4 2
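The difference between the two kinds of DataPipe mirrors two access patterns that already exist in plain Python. The sketch below is only an analogy, with hypothetical names (`Squares`, `square_stream`); it does not use TorchData at all.

```python
class Squares:
    """Map-style container: supports len() and access by index."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx


def square_stream(n: int):
    """Iterable-style source: values can only be consumed in order."""
    for i in range(n):
        yield i * i


map_style = Squares(5)
iter_style = square_stream(5)

print(map_style[3])      # random access works: 9
print(list(iter_style))  # sequential consumption: [0, 1, 4, 9, 16]

try:
    square_stream(5)[3]
except TypeError as err:
    # Generators do not support indexing
    print("No index access on a stream:", err)
```

The map-style object pays for random access by needing to know how to produce any item on demand, while the stream only ever needs to know how to produce the next one.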
Concrete example: Intel Image Classification
Let's see a real example: loading the Intel Image Classification dataset with the usual DataLoader first, and then using the TorchData library. You can read the entire code from the repository.
Intel Image Classification dataset samples.
We first install the tools that we need for this project. Many are already present in the Google Colab setup, but we need to add kaggle, torchvision and torchdata.
!pip install kaggle torchdata torchvision
Since we use a Kaggle dataset, we upload the API key generated from the Kaggle website. The key is required by Kaggle's Python client, which makes downloading the dataset easy.
from google.colab import files
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
Then we download the dataset with a single command and uncompress it into a folder named data.
!kaggle datasets download puneet6060/intel-image-classification -p data --unzip
If everything worked correctly, you should see 6 folders, one for each class, in the train and test folders under data.
!ls data/seg_test/seg_test
# Output
# >> buildings forest glacier mountain sea street
Setup
We start by importing some libraries that we will use later for the project: glob, itertools, pathlib, torch and torchvision.
import glob
import itertools as it
from pathlib import Path
import torch
import torchvision
Then we create some utilities for later in the project.
We have 2 possible splits, train and test, so we create a dict that maps each split to its path on disk.
The dataset has 6 classes, so we create a second dict that maps each class name to an integer label.
The images don’t all have the same dimensions, so we need a transform that resizes them to a fixed size of 150 by 150. Finally, we write a function that returns the class of an image from its path.
The first parent folder of an image is named after its class, so we just take the stem of the first parent. The stem is the last component of the path without any file extension; for a folder it is simply the folder name.
# Convert split name to folder path
split_to_path = {
    "train": "data/seg_train/seg_train",
    "test": "data/seg_test/seg_test"
}
# Convert class name to int label
name_to_label = {
    "buildings": 0,
    "forest": 1,
    "glacier": 2,
    "mountain": 3,
    "sea": 4,
    "street": 5,
}
# Image transformations to get all images at size 150 x 150
transforms = torch.nn.Sequential(
    torchvision.transforms.Resize((150, 150))
)
def img_path_to_label(path: str):
    """Function to get the class from the file path"""
    name = Path(path).parents[0].stem
    return name_to_label[name]
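To check how the path lookup behaves, here is a small standalone sketch using only pathlib. The file name `12345.jpg` is a hypothetical example that follows the folder layout shown earlier; no file needs to exist on disk for this to run.

```python
from pathlib import Path

# Hypothetical file path following the <split>/<class>/<file>.jpg layout.
example_path = Path("data/seg_train/seg_train/forest/12345.jpg")

# parents[0] is the folder that directly contains the file.
class_folder = example_path.parents[0]

print(class_folder.stem)  # forest
print(class_folder.name)  # forest (identical here: a folder has no extension)
print(example_path.stem)  # 12345 (the .jpg extension is stripped)
```

For class folders without a dot in their name, `stem` and `name` return the same string, so either would work in the utility function.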
Usual data loading: Dataset and DataLoader
We have everything we need to get started. We begin with the traditional data loading implementation, built on a Dataset and a DataLoader.
We first import the Dataset and DataLoader classes from torch.utils.data.
from torch.utils.data import Dataset, DataLoader
The first step is to create our IntelDataset. In the __init__ method we just look up the path of the split in the dict defined before.
We define a method called _list_files, which simply lists all the images at the given path.
The Dataset interface requires us to implement the __len__ method, which returns the size of the dataset. Here it’s simply the number of images.
The last step is to implement __getitem__, which returns a tuple of image and label for a given index: we get the file path at that index, load the image with torchvision, get the label with our utility function, and finally return the resized image with its label.
class IntelDataset(Dataset):
    """Class to represent the Intel Image Classification as a Dataset"""
    def __init__(self, split: str):
        # Get the split path (train or test) from the split name.
        self.path = split_to_path[split]
    def _list_files(self):
        """List all images"""
        return list(glob.glob(f"{self.path}/**/*.jpg"))
    def __len__(self):
        """Get the length of the dataset"""
        return len(self._list_files())
    def __getitem__(self, idx: int):
        """Method to access a tuple (input, label) per index"""
        # Get all file paths
        files = self._list_files()
        # Get the file path at the received index
        file_path = files[idx]
        # Load the image
        image = torchvision.io.read_image(file_path)
        # Get the label from the image path
        label = img_path_to_label(file_path)
        # Return the transformed image with its label
        return transforms(image), label
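Note that the class above lists the folder again on every call to __len__ and __getitem__. One possible variation, sketched below with the standard library only, lists the files once and sorts them so that an index always refers to the same file. `CachedFileIndex` is a hypothetical helper for illustration and not part of the original notebook; the temporary folder tree stands in for the real dataset.

```python
import glob
import tempfile
from pathlib import Path


class CachedFileIndex:
    """Hypothetical helper: lists image files once and serves them by index."""

    def __init__(self, root: str):
        # Sort so that a given index always refers to the same file.
        self.files = sorted(glob.glob(f"{root}/**/*.jpg"))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> str:
        return self.files[idx]


# Build a tiny throwaway folder tree to exercise the helper.
with tempfile.TemporaryDirectory() as root:
    for class_name in ("forest", "sea"):
        class_dir = Path(root) / class_name
        class_dir.mkdir()
        (class_dir / "0.jpg").touch()
    index = CachedFileIndex(root)
    # Recover the class folder name of every indexed file.
    listed = [Path(index[i]).parent.name for i in range(len(index))]

print(listed)  # ['forest', 'sea']
```

The trade-off is that files added after construction are not picked up, which is usually acceptable for a fixed training set.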
Now that we have our Dataset implementation, we can create the DataLoader with shuffling and a batch size of 10 items. To check that everything works as expected, we iterate over the first 5 batches and print their size as well as the labels. We can see that we get the expected batch size and a shuffled dataset.
# Create the Dataset for the train split
ds = IntelDataset("train")
# Create the DataLoader with shuffling and batching
dl = DataLoader(ds, batch_size=10, shuffle=True)
# Iterate over the 5 first batches
for X, y in it.islice(dl, 5):
    print(f"X batch length: {len(X)}, y batch length: {len(y)}, labels: {y}")
# Outputs:
# >> X batch length: 10, y batch length: 10, labels: tensor([1, 1, 3, 1, 2, 3, 5, 4, 4, 2])
# >> X batch length: 10, y batch length: 10, labels: tensor([4, 4, 4, 2, 4, 4, 5, 5, 2, 1])
# >> X batch length: 10, y batch length: 10, labels: tensor([4, 3, 4, 3, 3, 2, 4, 1, 1, 4])
# >> X batch length: 10, y batch length: 10, labels: tensor([1, 4, 0, 4, 4, 5, 1, 3, 5, 0])
# >> X batch length: 10, y batch length: 10, labels: tensor([3, 0, 5, 5, 1, 1, 2, 2, 1, 0])
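Conceptually, shuffling and batching in a map-style setup come down to drawing a random permutation of the indices and cutting it into chunks; each index is then passed to __getitem__. The sketch below is a simplified plain-Python model of that idea, not the DataLoader's actual implementation. `shuffled_batches` and its `seed` parameter are assumptions made for this example.

```python
import random


def shuffled_batches(num_items: int, batch_size: int, seed: int = 0):
    """Yield lists of indices: shuffle once, then cut into consecutive batches."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    for start in range(0, num_items, batch_size):
        yield indices[start:start + batch_size]


batches = list(shuffled_batches(num_items=25, batch_size=10))
print([len(b) for b in batches])  # [10, 10, 5]
```

The last batch is smaller when the dataset size is not a multiple of the batch size, which matches the DataLoader's default of keeping the final incomplete batch.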
Loading data like a pro: TorchData
Analyzing the code above, we can identify some primitives that would ideally be implemented in a library: listing the files in a folder, or mapping a function over the outputs of a dataset. With the Dataset and DataLoader approach, all of this logic ends up inside the __getitem__ method.
Let’s now look at the elegant solution proposed by TorchData: combining modular primitives with a functional API.
We start by importing the datapipes module, along with the default_collate function that we will use at the end of the pipeline.
import torchdata.datapipes as dp
from torch.utils.data import default_collate
We create a function called build_datapipes which returns a MapDataPipe given a split name. We begin by looking up the split folder path in our split_to_path dict.
We start the pipeline by recursively listing all the files at the given path. This gives us an IterDataPipe.
Then we map a function that returns tuples of image path and corresponding label.
We would prefer a MapDataPipe for this dataset to enable index-based access. To get one, we need a key for each element. We obtain it with the enumerate pipe, which lazily numbers every item in order. We can then call to_map_datapipe to convert our IterDataPipe to a MapDataPipe.
We can now read each image into a tensor with torchvision. Remember that we need to resize the images to a fixed dimension.
At this point we shuffle the data by chaining a shuffler DataPipe.
After batching by 10 items, we get a pipeline yielding lists of 10 tuples, each tuple being an image with its label. Ideally, though, we want only two tensors: one containing the batch of images and one containing the batch of labels. To get them, we apply the default collate function from torch.utils.data, which transforms a list of tensor tuples into a tuple of tensors.
def build_datapipes(split: str):
    """Function to return the DataPipe based on the split name"""
    # Get the split path (train or test) from the split name.
    path = split_to_path[split]
    return (
        # Iterate over all file paths
        dp.iter.FileLister(path, recursive=True)
        # Transform path to tuples (path, label)
        .map(lambda x: (x, img_path_to_label(x)))
        # We need a key to transform our IterDataPipe into a MapDataPipe
        # Enumerate will yield: (index, (path, label))
        .enumerate()
        # Get a MapDataPipe, it's like a dictionary with key based access
        .to_map_datapipe()
        # Read the image and yield (image tensor, label)
        .map(lambda x: (torchvision.io.read_image(x[0]), x[1]))
        # Resize the image using our transform (transformed image, label)
        .map(lambda x: (transforms(x[0]), x[1]))
        # Shuffle the DataPipe
        .shuffle()
        # Get batches of 10
        .batch(10)
        # Collate the batches. Transforms [(image, label)] to
        # (images, labels)
        .map(default_collate)
    )
And that’s it: we have our DataPipe ready to load our dataset. We iterate over the first 5 batches to check that it works and, as expected, we get batches of 10 shuffled items.
pipe = build_datapipes("train")
# Iterate over the 5 first batches
for X, y in it.islice(pipe, 5):
    print(f"X batch length: {len(X)}, y batch length: {len(y)}, labels: {y}")
# Outputs:
# >> X batch length: 10, y batch length: 10, labels: tensor([5, 4, 0, 3, 5, 2, 1, 3, 1, 5])
# >> X batch length: 10, y batch length: 10, labels: tensor([3, 4, 3, 0, 5, 0, 0, 0, 3, 5])
# >> X batch length: 10, y batch length: 10, labels: tensor([4, 4, 1, 4, 2, 5, 2, 1, 3, 5])
# >> X batch length: 10, y batch length: 10, labels: tensor([3, 5, 2, 0, 4, 1, 3, 0, 5, 3])
# >> X batch length: 10, y batch length: 10, labels: tensor([3, 5, 4, 3, 5, 1, 2, 4, 5, 4])
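To make the enumerate, batch, and collate steps more concrete, here is a plain-Python sketch of what each one does to the shape of the data. It is only an analogy: strings stand in for image tensors, a batch size of 2 is used for brevity, and the real default_collate additionally stacks the tensors of a batch into a single tensor.

```python
# Stand-ins for (image, label) samples; real samples would hold tensors.
samples = [(f"image_{i}", i % 6) for i in range(5)]

# enumerate + to_map_datapipe: give every sample an integer key.
indexed = dict(enumerate(samples))
print(indexed[3])  # ('image_3', 3)

# batch(2): group consecutive samples into lists of at most 2 items.
values = list(indexed.values())
batches = [values[i:i + 2] for i in range(0, len(values), 2)]

# collate: turn a list of (image, label) pairs into (images, labels).
images, labels = zip(*batches[0])
print(images, labels)  # ('image_0', 'image_1') (0, 1)
```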
The TorchData library allows reusability while preserving flexibility since you can assemble primitives into complex pipelines.
Conclusion
In a few words, data loading is an important part of any machine learning project. However, the usual approach in PyTorch with the Dataset and DataLoader is flawed: the DataLoader is monolithic and makes reusability and composability difficult.
TorchData is a great solution that brings modular and composable data loading primitives: we end up writing less custom code, since the library abstracts a lot of common use cases into DataPipes.
Try TorchData in your next project and never go back.