Custom dataset in PyTorch

PyTorch provides a dataset class (torch.utils.data.Dataset) and a dataloader class (torch.utils.data.DataLoader). You can write a full-fledged, commercial-grade application without using either of them. So why do we need them?

You can find the answer by looking at the comparison below. Both versions do the same thing, with and without PyTorch’s dataset/dataloader classes. You don’t need to understand the code; I just need you to agree with me that the one on the right looks much cleaner.

In short, we use PyTorch’s dataset and dataloader mainly to get rid of the complexity of keeping track of batch and epoch variables, along with those nasty index calculations. There are of course other advantages (transforms, data splitting, shuffling on demand, etc.).
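To make the index bookkeeping concrete, here is a minimal sketch of the same batching done both ways. The toy tensors, the batch size, and the use of TensorDataset as a stand-in dataset are my assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumed for illustration): 10 Celsius values and their Fahrenheit labels.
celsius = torch.arange(10, dtype=torch.float32)
fahrenheit = celsius * 9 / 5 + 32
batch_size = 4

# Without DataLoader: compute the start/end index of every batch by hand.
manual_batches = []
for start in range(0, len(celsius), batch_size):
    end = start + batch_size  # slicing past the end of a tensor is safe
    manual_batches.append((celsius[start:end], fahrenheit[start:end]))

# With DataLoader: the index calculations disappear.
loader = DataLoader(TensorDataset(celsius, fahrenheit), batch_size=batch_size, shuffle=False)
loader_batches = [(x, y) for x, y in loader]
```

With shuffle=False both approaches yield the same batches in the same order; the DataLoader version simply has no indices to get wrong.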

But it comes with a price, a small one. To use these cool features, you have to define a mechanism so that PyTorch can read your dataset in its own format. In simple words: you write a class that extends torch.utils.data.Dataset.

Writing your Dataset class

PyTorch has an extensive tutorial here covering almost all aspects. But I find it a little hard for beginners, especially since it involves file/image processing which obscures the dataset implementation a bit.

So I decided to write a tutorial with an extremely easy data format, so that we can concentrate only on the dataset features. And we will use my favorite case: “Celsius to Fahrenheit”.

Here is the code:

Forensic Analysis:

  1. In line 5, we extend (inherit from) the Dataset class, which is an abstract class, so we have to implement __len__ and __getitem__.
  2. In lines 8 and 9, we generate our dummy dataset. For more complicated cases, for example images, where loading all the data at once is not realistic, you might want to declare a data structure here that points to your actual data. For the image example, you could use a list containing all the image file names.
  3. In lines 10–14, we split our dataset, holding out 10% of the data for testing and using the rest for training.
  4. __len__: It should return the number of samples in your dataset.
  5. __getitem__(n): It should return the n-th sample from your dataset.
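The points above can be sketched roughly as follows. This is a hypothetical reconstruction, not the original listing, so the line numbers cited above do not apply to it. The class name, the constructor parameters, and the choice to hold out the last 10% of samples are my assumptions.

```python
import torch
from torch.utils.data import Dataset


class CelsiusToFahrenheit(Dataset):
    """Hypothetical sketch of a map-style dataset (names are assumed)."""

    def __init__(self, train=True, n_samples=100, test_fraction=0.1):
        # Dummy data: Celsius inputs and their Fahrenheit labels.
        celsius = torch.arange(n_samples, dtype=torch.float32)
        fahrenheit = celsius * 9 / 5 + 32

        # Hold out the last 10% for testing; the rest is for training.
        n_test = int(n_samples * test_fraction)
        n_train = n_samples - n_test
        if train:
            self.x, self.y = celsius[:n_train], fahrenheit[:n_train]
        else:
            self.x, self.y = celsius[n_train:], fahrenheit[n_train:]

    def __len__(self):
        # Number of samples in this split.
        return len(self.x)

    def __getitem__(self, n):
        # The n-th (input, label) pair.
        return self.x[n], self.y[n]


train_set = CelsiusToFahrenheit(train=True)
test_set = CelsiusToFahrenheit(train=False)
```

For data that does not fit in memory, the same shape would still work: __init__ would store only file names, and __getitem__ would load the n-th file on demand.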

Let’s Test Our Dataset

Try changing the parameters here and observe the output. See the power of PyTorch’s Dataset and DataLoader!
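As a starting point for experimenting, here is a small self-contained sketch. The tiny dataset class, the seed, and the parameter values are my assumptions, chosen only so the output stays short; batch_size, shuffle, and drop_last are real DataLoader arguments worth playing with.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyCelsiusDataset(Dataset):
    """Hypothetical 10-sample dataset, kept small so the output is readable."""

    def __init__(self, n_samples=10):
        self.x = torch.arange(n_samples, dtype=torch.float32)
        self.y = self.x * 9 / 5 + 32

    def __len__(self):
        return len(self.x)

    def __getitem__(self, n):
        return self.x[n], self.y[n]


torch.manual_seed(0)  # only to make the shuffled order repeatable

# Try changing batch_size, shuffle, and drop_last, then rerun.
loader = DataLoader(TinyCelsiusDataset(), batch_size=4, shuffle=True, drop_last=False)

batch_sizes = []
for epoch in range(2):
    for batch_idx, (celsius, fahrenheit) in enumerate(loader):
        batch_sizes.append(len(celsius))
        print(f"epoch {epoch} batch {batch_idx}: {celsius.tolist()} -> {fahrenheit.tolist()}")
```

With shuffle=True the samples arrive in a new order every epoch, and with drop_last=True the final short batch is skipped.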