Yi-Hsuan Tseng

Thoughts in Progress

Loading an image dataset in PyTorch

02 Mar 2023 » Pytorch, ML

Imports

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

1. Transform

Common transforms include image resizing, data augmentation, value standardization, and normalization.

1.1 Data Augmentation

  • Examples: RandomCrop, RandomRotation, RandomHorizontalFlip…
  • In general, augmentation is applied only to the training dataset.
  • However, any transform that resizes or crops the image to a specific size must also be applied to the testing data.
  • Example: if transforms.CenterCrop(224) is applied during training, it must also be applied during testing so that both produce images of the same size: 224 x 224.

1.2 Convert to torch tensor

transforms.ToTensor() converts a PIL image or NumPy ndarray (H x W x C) to a tensor (C x H x W) and, for 8-bit images, scales the values from [0, 255] to [0.0, 1.0].

  • Necessary for both the training and testing datasets

1.3 Normalization

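For more stable training, normalize each channel from [0.0, 1.0] to [-1.0, 1.0]. This is done by transforms.Normalize(mean, std), which computes (x - mean) / std for each channel. With mean = 0.5 and std = 0.5, a value of 0.0 maps to -1.0 and a value of 1.0 maps to 1.0.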

# For RGB images (3 channels): one mean and one std per channel
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
  • Necessary for both the training and testing datasets

1.4 Example

train_T = transforms.Compose([
    transforms.Resize(255),
    transforms.RandomRotation(30),
    transforms.CenterCrop(224),                     # crop the central 224 x 224 region (no resizing)
    transforms.ToTensor(),                                 # convert image to tensor and scale values to [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5])                  # normalize each channel value to [-1,1]
])

val_T = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),                                 # convert image to tensor and scale values to [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5])                  # normalize each channel value to [-1,1]
])

224 is chosen here because many pretrained models expect inputs of this size, so resizing images to 224 x 224 allows reusing pretrained models for transfer learning afterwards.

2. ImageFolder

ds = datasets.ImageFolder('path/to/data', transform=train_T)

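ImageFolder expects the images of each class to be stored in a separate subfolder of the root directory; the subfolder names are used as the class names. The directory layout should look like: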

root/dog/dog_0.png
root/dog/dog_1.png
...

root/cat/cat_0.png
root/cat/cat_1.png
...

3. DataLoader

dataloader = DataLoader(ds, batch_size=32, shuffle=True)

Each iteration over the DataLoader yields one batch of images together with their labels:

for images, labels in dataloader:
    ...
