Active learning made simple using Flash and BaaL

Presenting Bayesian Active Learning (BaaL) with Lightning Flash to train faster with fewer samples.

PyTorch Lightning team
PyTorch Lightning Developer Blog
6 min read · Dec 21, 2021


Lightning Flash is a PyTorch AI Factory built on top of PyTorch Lightning. Flash helps you quickly develop strong baselines on your data across multiple tasks and data modalities.

BaaL is a Bayesian active learning library developed at ElementAI.

In this blog post, you will learn how Active Learning works and how to use BaaL's Active Learning components with Lightning Flash to train faster or with fewer samples.

What is Active Learning and how does it work?

Active learning is a method that reduces labeling effort by using a machine learning model to decide which inputs are worth sending to an annotator. This is traditionally done as a repeating cycle of 4 steps:

  1. Train the model on the currently labeled data.
  2. Run the model on the unlabeled pool and estimate how informative each sample is.
  3. Select the most informative samples and have them labeled.
  4. Add the newly labeled samples to the training set and start over.

It is quite common for this cycle to be repeated thousands of times, or even to run continuously as new samples are labeled.

Computing how informative a new unlabelled sample would be is the key to Active Learning, and it is where BaaL comes in. BaaL provides Bayesian heuristics that estimate how informative a given sample is from the model predictions.
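
As a sketch of the idea, a BaaL heuristic takes a stack of stochastic predictions and returns the sample indices ranked from most to least informative. The shapes and values below are illustrative, not from the post:

```python
import numpy as np
from baal.active.heuristics import BALD

# Stochastic predictions for 100 unlabelled samples: 10 classes and
# 20 MC-Dropout iterations -> shape [samples, classes, iterations].
predictions = np.random.rand(100, 10, 20)
predictions /= predictions.sum(axis=1, keepdims=True)  # normalize to probabilities

heuristic = BALD()
ranks = heuristic(predictions)  # sample indices, most informative first
```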

Putting together all the components might seem challenging, but with Lightning Flash and its BaaL integration, you will see how simple it all becomes.

A 3-Step Experiment — Image Classifier on CIFAR10

The CIFAR-10 dataset consists of 60k 32x32 color images in 10 classes, with 6k images per class. The dataset is divided into 50k training images and 10k test images.

CIFAR-10 examples. Source: https://www.cs.toronto.edu/~kriz/cifar.html

In this experiment, we aim to reproduce the results from the paper “Bayesian active learning for production, a systematic study and a reusable library” by Parmida Atighehchian, Frederic Branchaud-Charron, and Alexandre Lacoste.

For this experiment, we won’t request the data to be labeled by an annotator; instead, we use the ground-truth labels of the training dataset, i.e., we mask the labels and unmask them when the heuristic determines that the associated unlabelled sample should be labeled. This is a common trick used by Active Learning researchers to test their ideas. In a real-world scenario, the data would be labeled by a human.

The informativeness (or uncertainty) of a sample in this experiment is estimated either randomly or with the BALD heuristic from BaaL. A heuristic is a method that derives the informativeness of a given unlabelled sample from the model predictions.

Step 1: Prepare your data

For this experiment, we will create a CIFAR10 dataset using torchvision.

Code snippet to generate the CIFAR datasets.
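
A minimal sketch of that snippet, assuming torchvision's CIFAR10 class and a local ./data directory:

```python
from torchvision.datasets import CIFAR10

# Download the 50k-image training split and the 10k-image test split.
train_dataset = CIFAR10("./data", train=True, download=True)
test_dataset = CIFAR10("./data", train=False, download=True)
```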

After loading the data, we apply minimal augmentation to the training dataset: a random horizontal flip and a rotation of up to 30 degrees, defined as follows:
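
A sketch of those transforms, attached to the datasets created above:

```python
from torchvision import transforms as T

# Random horizontal flip plus a rotation of up to 30 degrees, applied
# to the training split only; the test split is just converted to tensors.
train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomRotation(30),
    T.ToTensor(),
])

train_dataset.transform = train_transform
test_dataset.transform = T.ToTensor()
```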

Finally, we use ImageClassificationData to load the datasets we defined above.
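
A sketch of the wiring, with a batch size that is our assumption rather than a value from the post:

```python
from flash.image import ImageClassificationData

# Build the Flash datamodule directly from the torchvision datasets.
datamodule = ImageClassificationData.from_datasets(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    batch_size=64,
)
```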

Step 2: Prepare your model

In the paper, the classification head below is created as a sequence of linear layers. The final layer has an output dimension of 10, equal to the number of classes in the CIFAR10 dataset.
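
A sketch of such a head; the 512-dimensional input and hidden width are assumptions tied to the vgg16 backbone used below:

```python
import torch

# Small MLP head; the Dropout layer matters later, since MC-Dropout
# relies on dropout being present in the model.
head = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(512, 10),  # 10 outputs, one per CIFAR-10 class
)
```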

Using Lightning Flash, we can easily create an ImageClassifier with a pretrained vgg16 backbone and an SGD optimizer.
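
A sketch of the model definition; the keyword names and learning rate are assumptions that may vary across Flash versions:

```python
from flash.image import ImageClassifier

# Pretrained VGG-16 backbone with the custom head and an SGD optimizer.
model = ImageClassifier(
    num_classes=10,
    backbone="vgg16",
    head=head,
    pretrained=True,
    optimizer="sgd",
    learning_rate=1e-3,
)
```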

Step 3: Create the Active Learning Components

Lightning Flash provides components to utilize your data and model to experiment with Active Learning.

These components will take care of:

  1. Initial labeling: masking labels from the training dataset.
  2. Training, validation, testing, and prediction loops: implementing an Active Learning cycle.
  3. Enabling uncertainty estimation with MC-Dropout (a sketch follows this list).
  4. Emulating labeling of the top-K most uncertain examples.
  5. Resetting the model parameters after each cycle.
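
For intuition on item 3, BaaL keeps dropout active at inference time so that repeated forward passes yield a distribution of predictions. A manual sketch using BaaL's patch_module (the integration does this for you):

```python
from baal.bayesian.dropout import patch_module

# Replace nn.Dropout layers with variants that stay active in eval mode,
# turning repeated forward passes into Monte Carlo samples.
bayesian_head = patch_module(head)
```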

And best of all, you get all of this with only a few lines of code.

By using an ActiveLearningDataModule, you can wrap your existing datamodule and emulate the active learning scenario described above.

For this experiment, we start training with only 1024 labeled samples out of the 50k and request 100 new samples to be labeled at every cycle.
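
A sketch of that wrapping; the keyword names follow the Flash/BaaL integration at the time of writing and may differ in later versions:

```python
from baal.active.heuristics import BALD
from flash.image.classification.integrations.baal import ActiveLearningDataModule

# Wrap the Flash datamodule: start from 1024 labelled samples and
# request 100 new labels at every cycle, ranked by BALD.
active_datamodule = ActiveLearningDataModule(
    datamodule,
    heuristic=BALD(),
    initial_num_labels=1024,
    query_size=100,
)
```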

Finally, we create a Flash Trainer alongside an ActiveLearningLoop, and use the loop to replace the Trainer's base fit_loop. For this experiment, we perform 2500 cycles, retraining the model from scratch for 20 epochs each time.
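
A sketch of the loop swap, with values taken from the text where stated (20 epochs per cycle) and hedged otherwise:

```python
import flash
from flash.image.classification.integrations.baal import ActiveLearningLoop

trainer = flash.Trainer(max_epochs=2500)

# Retrain for 20 epochs in each cycle before querying new labels,
# then replace the Trainer's default fit loop with the active learning loop.
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=20)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

trainer.fit(model, datamodule=active_datamodule)
```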

Active Learning Results

For this experiment, we chose the BALD heuristic, compared against random sampling, as it provides a good tradeoff between efficiency and performance. More heuristics are available within BaaL, each with a short description in the BaaL documentation.

You can find more advanced benchmarks from the BaaL team in their papers.

Using the ImageClassifier with BaaL's BALD heuristic, we observe that it takes 3.3 times less data to reach the same loss, which means the model's predictions can be used to estimate the uncertainty of each sample and to identify the harder ones.

Find the full example code here.

Addendum: Active learning with mislabelled data

To show how robust active learning is, we shuffled λ% of the labels and ran the same experiment. Even with 10% of the labels corrupted, BALD is still stronger than random sampling with no noise! These kinds of experiments are incredibly easy to run with the Lightning Flash integration, as we only changed the dataset composition.

Conclusion

In this tutorial, we used the Lightning Flash Active Learning integration with BaaL to run an experiment on CIFAR10.

Using only ~50 lines of code, we created a complete experiment and demonstrated that BALD-based sample selection yields better results than uniform sampling.

What’s Next?

The BaaL and Lightning Flash teams are working closely together to provide seamless integration for more data modalities and tasks. Stay tuned!

Built by the PyTorch Lightning creators, Grid.ai is our platform for scaling your model training without worrying about infrastructure, just as Lightning automates the training.

You can get started with Grid.ai for free with just a GitHub or Google Account.


We are the team of core contributors developing PyTorch Lightning — the deep learning research framework to run complex models without the boilerplate.