Train an image classification model using fastai

Can you believe it? In the very first lesson of the fast.ai course Practical Deep Learning for Coders 2022, you use deep learning to do image classification. Amazing! So how is it done?

This is done with the fastai library, which is built on top of PyTorch and other libraries, so that beginners like us can use it to build a model and “play the whole game”. As we learn more, we’ll dig into the details.

I followed Professor Jeremy Howard’s example notebook (Is it a bird? Creating a model from your own data | Kaggle) to train a model that tells whether the fruit in a picture is an apple or a peach. Here’s my notebook: Which fruit is it, apple or peach? | Kaggle. Feel free to take a look.

The whole process consists of four steps, as shown in the figure below:

Figure 1. The pipeline for training a deep learning model

Step 0: Collect data

If you don’t already have a dataset, start here. The fast.ai course example code uses the DuckDuckGo or Bing search API to find and download images.

Step 1: Prepare the data

fastai provides the DataBlock class as a template for describing your data. Its dataloaders method then packs the data into batches (bs is the batch size), so that the GPU can process many images in parallel during training.

dls = DataBlock(...).dataloaders(path, bs=32)  # "..." stands for the DataBlock parameters

We’ll talk more about DataBlock later.
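To make the idea of batches concrete, here is a plain-Python sketch that splits a list of items into batches of size bs. It is only an illustration, not fastai's implementation: fastai's DataLoader also shuffles the training set, may drop the last incomplete batch during training, and stacks the images into tensors. The helper name make_batches and the file names are hypothetical.

```python
def make_batches(items, bs):
    """Split a list into consecutive batches of at most `bs` items (illustration only)."""
    return [items[i:i + bs] for i in range(0, len(items), bs)]

# 70 hypothetical image file names, batched with bs=32 as in the dataloaders call
files = [f"img_{i}.jpg" for i in range(70)]
batches = make_batches(files, bs=32)
print([len(b) for b in batches])  # [32, 32, 6]
```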

Step 2: Train the model

fastai handles the training loop for you; you only need to specify the data (dls), the model architecture (resnet18) and the metrics to report. It takes just two lines of code:

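from fastai.vision.all import *
learn = vision_learner(dls, resnet18, metrics=error_rate)  # pre-trained ResNet-18 with a new head for our classes
learn.fine_tune(5)  # 5 epochs of fine-tuning (after one epoch training only the new head, by default)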
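metrics=error_rate only controls what is reported after each epoch; it is not the loss the model optimizes. error_rate is the fraction of validation images that are misclassified, i.e. 1 minus the accuracy. Here is a plain-Python sketch of the same arithmetic; the function name error_rate_sketch and the sample labels are made up for illustration.

```python
def error_rate_sketch(predictions, targets):
    """Fraction of predictions that differ from the targets (i.e. 1 - accuracy)."""
    assert len(predictions) == len(targets), "need one prediction per target"
    wrong = sum(p != t for p, t in zip(predictions, targets))
    return wrong / len(targets)

preds   = ["apple", "peach", "apple", "apple", "peach"]
targets = ["apple", "peach", "peach", "apple", "peach"]
print(error_rate_sketch(preds, targets))  # 0.2  (1 wrong out of 5)
```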

Within about a minute, the model is trained.

To speed up training, we start from a pre-trained model and fine-tune it. You may wonder: what is a pre-trained model?

Training is the search for a good set of parameters (weights) for the model. A pre-trained model has already been trained on a large dataset (the resnet18 weights fastai uses were trained on ImageNet), so its parameters are a good starting point. We don’t have to search from scratch, which makes training much faster.
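As a toy analogy (not how fastai trains a network), the sketch below runs plain gradient descent on the one-parameter function f(w) = (w - 3)^2. Starting close to the optimum, which plays the role of pre-trained weights, reaches the same tolerance in fewer steps than starting from zero, which plays the role of training from scratch. The function, learning rate and tolerance are arbitrary choices for illustration.

```python
def steps_to_converge(w, lr=0.1, tol=1e-3, target=3.0, max_steps=10_000):
    """Run gradient descent on f(w) = (w - target)**2 and count the steps needed."""
    steps = 0
    while abs(w - target) > tol and steps < max_steps:
        grad = 2 * (w - target)  # derivative of (w - target)**2
        w -= lr * grad
        steps += 1
    return steps

from_scratch = steps_to_converge(w=0.0)      # start far from the optimum
from_pretrained = steps_to_converge(w=2.99)  # start close to a good solution
print(from_scratch, from_pretrained)
```

Each step shrinks the distance to the optimum by the same factor, so the closer the starting point, the fewer steps are needed.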

Step 3: Use the model to predict

Now it’s time to test your model on a new image.

# learn.predict returns the predicted label, its index, and the probabilities of all classes
categ, idx, probs = learn.predict(PILImage.create('test-photo.jpg'))
print(f"This is a photo of {categ}, with a probability of {probs[idx]:.4f}")
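For a single-label classifier like this one, the predicted category is simply the class with the highest probability, and idx is its position in the DataLoaders' vocabulary (dls.vocab, which CategoryBlock sorts alphabetically). Here is a plain-Python sketch of that last step, with made-up probabilities standing in for what learn.predict would return:

```python
vocab = ["apple", "peach"]  # assumed dls.vocab for the apple/peach dataset
probs = [0.0312, 0.9688]    # hypothetical output probabilities, one per class

idx = max(range(len(probs)), key=lambda i: probs[i])  # index of the largest probability
categ = vocab[idx]
print(f"This is a photo of {categ}, with a probability of {probs[idx]:.4f}")
# This is a photo of peach, with a probability of 0.9688
```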

Dig into DataBlock

You might expect to spend most of your time on the training step: choosing the model architecture, the learning rate and so on. However, Jeremy says that DataBlock is the key thing to get familiar with as a beginner, because “the main thing you’re going to be trying to figure out is how do I get this data into my model?”

That’s true. Whenever I tried a data competition or a small project, I found data cleaning and preparation tricky, and I sometimes got stuck at this stage.

So let’s see how to define a DataBlock. For computer vision problems, we need to answer the following questions:

  • What kinds of data are the inputs (data) and outputs (labels)? Defined by blocks.
  • Where is the input data? Defined by get_items.
  • Do we need to apply something to the input? Defined by get_x.
  • How do we get the label? Defined by get_y.
  • How do we split the data into training and validation sets? Defined by splitter.
  • Do we need to transform each individual item? Defined by item_tfms; it is often used to resize all images to the same size.
  • Do we need to transform each batch? Defined by batch_tfms; it is often used for data augmentation.
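To see how these questions fit together, here is a toy, plain-Python imitation of the order in which a DataBlock uses the answers: get_items collects the items, splitter divides them into training and validation indices, and get_x/get_y turn each item into an (input, label) pair. toy_datablock and random_splitter are hypothetical helpers written for illustration, not fastai functions; the splitter mirrors the idea of RandomSplitter (shuffle with a seed, cut off valid_pct for validation), and blocks and the item/batch transforms are ignored.

```python
import random

def random_splitter(items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and put valid_pct of them in the validation set."""
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * len(items))
    return idxs[cut:], idxs[:cut]  # (train indices, valid indices)

def toy_datablock(source, get_items, get_y, splitter, get_x=lambda item: item):
    items = get_items(source)               # 1. where is the data?
    train_idx, valid_idx = splitter(items)  # 2. train/validation split
    to_pair = lambda i: (get_x(items[i]), get_y(items[i]))  # 3. input and label
    return [to_pair(i) for i in train_idx], [to_pair(i) for i in valid_idx]

# Hypothetical file names of the form "<label>/<name>"
source = [f"apple/{i}.jpg" for i in range(5)] + [f"peach/{i}.jpg" for i in range(5)]
train, valid = toy_datablock(source,
                             get_items=list,
                             get_y=lambda f: f.split("/")[0],
                             splitter=random_splitter)
print(len(train), len(valid))  # 8 2
```

Fixing the seed makes the split reproducible, which is why the fastai examples pass seed=42.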

For example, in the fastai example code, the training images of each category are stored in their own folder, and the DataLoaders are constructed with the following code:

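from fastai.vision.all import *

dls = DataBlock(blocks=(ImageBlock, CategoryBlock),  # input: image, output: category
                get_items=get_image_files,            # all image files under path
                get_y=parent_label,                   # label = name of the parent folder
                splitter=RandomSplitter(valid_pct=0.2, seed=42),  # 20% for validation
                item_tfms=Resize(192)                 # resize every image to 192x192
               ).dataloaders(path, bs=32)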

Here (ImageBlock, CategoryBlock) means that the input is an image and the output (label) is a categorical variable. Because images of different categories are stored in different folders, get_y is set to parent_label, which uses the name of the folder containing each file as its label.
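Here is a sketch of what get_image_files and parent_label do for such a folder layout, using only pathlib. fastai's parent_label returns the name of a file's parent folder; list_images and parent_label_sketch below are hypothetical stand-ins (the real get_image_files recognises more extensions and can restrict which folders it searches).

```python
import tempfile
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # simplified; fastai checks more extensions

def list_images(path):
    """Recursively collect image files under `path` (simplified get_image_files)."""
    return sorted(p for p in Path(path).rglob("*") if p.suffix.lower() in IMAGE_EXTS)

def parent_label_sketch(fname):
    """Label = name of the folder that contains the file."""
    return Path(fname).parent.name

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical dataset layout: one folder per category
    for label in ("apple", "peach"):
        (Path(tmp) / label).mkdir()
        for i in range(2):
            (Path(tmp) / label / f"{i}.jpg").touch()
    files = list_images(tmp)
    labels = [parent_label_sketch(f) for f in files]

print(labels)  # ['apple', 'apple', 'peach', 'peach']
```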

blocks, get_items, get_y, splitter and item_tfms are the parameters you will set almost every time you use DataBlock.

We can also add batch transforms for the same dataset, as follows:

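dls = DataBlock(blocks=(ImageBlock, CategoryBlock),
                get_items=get_image_files,
                get_y=parent_label,
                splitter=RandomSplitter(valid_pct=0.2, seed=42),
                item_tfms=RandomResizedCrop(192, min_scale=0.5),  # random crop of at least 50% of the image, resized to 192
                batch_tfms=aug_transforms()  # default augmentations (flip, rotate, zoom, lighting, warp) applied per batch
               ).dataloaders(path, bs=32)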
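The idea behind RandomResizedCrop(192, min_scale=0.5) is that, during training, a random region covering between 50% and 100% of the image area is cropped and then resized to 192x192, so the model sees slightly different views of the same picture in each epoch. The sketch below only picks such a crop box in plain Python. It is a simplification under stated assumptions: random_crop_box is a hypothetical helper, it always uses a square crop (fastai also samples the aspect ratio), and it does not do the resizing.

```python
import math
import random

def random_crop_box(width, height, min_scale=0.5, rng=random):
    """Pick a square crop covering between min_scale and 100% of the image area
    (capped by the shorter side). Returns (left, top, side)."""
    area = rng.uniform(min_scale, 1.0) * width * height
    side = min(math.ceil(math.sqrt(area)), width, height)  # the square must fit in the image
    left = rng.randint(0, width - side)
    top = rng.randint(0, height - side)
    return left, top, side

rng = random.Random(0)  # seeded so the example is repeatable
for _ in range(3):
    print(random_crop_box(256, 256, min_scale=0.5, rng=rng))
```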

Moreover, how the label is obtained can differ between datasets. For example, in the PETS dataset the label is encoded in the file name: cat images have file names starting with an uppercase letter, and dog images have file names starting with a lowercase letter.

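from fastai.vision.all import *

path = untar_data(URLs.PETS)/"images"

def label_func(fname):
    # Cat file names start with an uppercase letter, dog file names with a lowercase one
    return "cat" if fname.name[0].isupper() else "dog"

dls = DataBlock(blocks=(ImageBlock, CategoryBlock),
                get_items=get_image_files,
                get_y=label_func,  # label from the file name instead of the folder
                splitter=RandomSplitter(valid_pct=0.2, seed=42),
                item_tfms=Resize(192)
               ).dataloaders(path)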
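Because label_func only looks at the file name, you can check it on a few example names before building the DataLoaders. The names below are made up but follow the PETS naming pattern (breed_number.jpg, with Abyssinian and Bengal being cat breeds and beagle and pug dog breeds); the check is plain Python and does not need fastai.

```python
from pathlib import Path

def label_func(fname):
    return "cat" if fname.name[0].isupper() else "dog"

# Example file names in the PETS style: cat breeds are capitalised, dog breeds are not
for name in ["Abyssinian_1.jpg", "Bengal_10.jpg", "beagle_32.jpg", "pug_5.jpg"]:
    print(name, "->", label_func(Path(name)))
```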

If you look at the fastai example code, you may notice that it uses dls=ImageDataLoaders.from_name_func() to create the DataLoaders. This is a factory method: a shortcut for some common dataset layouts. When I first learned fastai, I was a little confused about which of these functions to use, so my advice is to learn DataBlock: it works for all kinds of datasets and saves you time.

For more details, see the fastai documentation: Data block tutorial.