Active learning made simple using Flash and BaaL
Presenting Bayesian Active Learning (BaaL) with Lightning Flash to train faster with fewer samples.
Lightning Flash is a PyTorch AI Factory built on top of PyTorch Lightning. Flash helps you quickly develop strong baselines on your data across multiple tasks and data modalities.
BaaL is a Bayesian active learning library developed at ElementAI.
In this blog post, you will learn how active learning works and how to combine BaaL's active learning components with Lightning Flash to train faster with fewer labeled samples.
What is Active learning and how does it work?
Active learning is a method that reduces labeling effort by using a machine learning model to decide which inputs to ask the user to label. It is traditionally done in 4 repetitive steps:
- Train the model on the currently labeled samples.
- Use the model to make predictions on the unlabeled pool.
- Score how informative each unlabeled sample would be if it were labeled.
- Send the highest-scoring samples to an annotator, add the new labels to the training set, and repeat.
Computing how informative a new unlabelled sample can be is the key to active learning, and it is where BaaL comes in: BaaL provides Bayesian heuristics that determine how informative a given sample is from the model predictions.
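The loop described above can be sketched in a few lines of Python. This is an illustrative skeleton, not BaaL's API: `train`, `predict_proba`, and the entropy score are stand-ins for your own model and heuristic.

```python
import numpy as np

rng = np.random.default_rng(0)

n_pool, n_classes = 1000, 10
labeled = np.zeros(n_pool, dtype=bool)
labeled[rng.choice(n_pool, 100, replace=False)] = True  # initial labeled set

def train(indices):
    """Stand-in for fitting a model on the labeled samples."""
    pass

def predict_proba(indices):
    """Stand-in for model predictions over the unlabeled pool."""
    return rng.dirichlet(np.ones(n_classes), size=len(indices))

query_size = 25
for cycle in range(4):
    train(np.flatnonzero(labeled))                  # 1. train on labeled data
    pool = np.flatnonzero(~labeled)
    probs = predict_proba(pool)                     # 2. predict on the pool
    entropy = -(probs * np.log(probs)).sum(axis=1)  # 3. score informativeness
    query = pool[np.argsort(-entropy)[:query_size]] # 4. pick top-k to label
    labeled[query] = True                           #    "annotate" and repeat
```

Each cycle grows the labeled set by `query_size`, which is exactly what the Flash integration automates below.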
Putting together all the components might seem challenging, but with Lightning Flash and its BaaL integration, you will see how simple it all becomes.
3 Steps Experiment — Image Classifier on CIFAR10
The CIFAR-10 dataset consists of 60k 32x32 color images in 10 classes, with 6k images per class. The dataset is divided into 50k training images and 10k test images.
Here is an image showing some examples.
CIFAR-10 examples. Source: https://www.cs.toronto.edu/~kriz/cifar.html
In this experiment, we aim to reproduce the results from the paper "Bayesian active learning for production, a systematic study and a reusable library" by Parmida Atighehchian, Frederic Branchaud-Charron, and Alexandre Lacoste.
For this experiment, we won't ask an annotator to label the data; instead, we use the ground-truth labels of the training dataset, i.e., we mask the labels and unmask them when the heuristic determines that the associated unlabelled sample should be labeled. This is a common trick used by active learning researchers to test out their ideas. In a real-world scenario, the data would be labeled by a human.
In this experiment, informativeness (uncertainty) is estimated either randomly or with the BALD heuristic from BaaL. A heuristic is a method that derives the informativeness of a given unlabelled sample from the model predictions.
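To build intuition for what BALD computes, here is a NumPy sketch of the underlying mutual-information score (not BaaL's implementation): a sample scores high when the stochastic forward passes disagree with each other, not merely when the averaged prediction is uncertain.

```python
import numpy as np

def bald_scores(probs):
    """BALD mutual information from Monte Carlo predictions.

    probs: array of shape (n_samples, n_classes, n_mc_samples),
    e.g. softmax outputs from several stochastic (MC-Dropout) passes.
    """
    eps = 1e-12
    mean_p = probs.mean(axis=-1)                                  # E[p] over MC passes
    entropy_of_mean = -(mean_p * np.log(mean_p + eps)).sum(axis=-1)   # H(E[p])
    mean_entropy = -(probs * np.log(probs + eps)).sum(axis=1).mean(axis=-1)  # E[H(p)]
    return entropy_of_mean - mean_entropy  # high = passes disagree = informative

# Five MC passes that all agree vs. five passes that disagree:
agree = np.stack([[0.9, 0.1]] * 5, axis=-1)[None]                       # (1, 2, 5)
disagree = np.stack([[0.9, 0.1], [0.1, 0.9]] * 2 + [[0.5, 0.5]], axis=-1)[None]
```

A confidently consistent sample (`agree`) gets a score near zero, while a sample the model flip-flops on (`disagree`) gets a large score and is queried first.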
Step 1: Prepare your data
For this experiment, we will create a CIFAR10 dataset using torchvision.
After loading the data, we will apply minimal augmentation: a random horizontal flip and a rotation of up to 30 degrees for the training dataset, defined as follows:
Finally, we use ImageClassificationData to load the datasets we defined above.
Step 2: Prepare your model
In the paper, the head classifier is a sequence of linear layers whose final layer has an output dimension of 10, the number of classes in the CIFAR10 dataset.
Using Lightning Flash, we can easily create an ImageClassifier with a pretrained vgg16 backbone and an SGD optimizer.
Step 3: Create the Active Learning Components
Lightning Flash provides components to utilize your data and model to experiment with Active Learning.
These components will take care of:
- Initial labeling — Mask labels from the training dataset
- Training, validation, testing, and prediction loops.
- Implementing an Active Learning cycle.
- Enabling the uncertainty estimation with MC-Dropout.
- Emulating labeling of the top-K most uncertain examples.
- Resetting the model parameters after each cycle.
And best of all, you get all of this with only a few lines of code.
By using an ActiveLearningDataModule, you can wrap your existing data module and emulate the active learning scenario described above.
For this experiment, we will start training with only 1024 labeled data samples out of the 50k and request 100 new samples to be labeled at every cycle.
Finally, we will create a Flash Trainer together with an ActiveLearningLoop, and use the latter to replace the Trainer's base fit_loop. For this experiment, we will perform 2500 cycles, training the model from scratch for 20 epochs each time.
Active Learning Results
For this experiment, we chose the BALD and random heuristics, as BALD provides a good tradeoff between efficiency and performance. Here is a table listing more of the heuristics available within BaaL, alongside short descriptions.
You can find more advanced benchmarks from the BaaL Team within their papers.
Using the ImageClassifier with BaaL's BALD heuristic, we observe that it takes 3.3 times less data to achieve the same loss, which means the model's predictions can be used to estimate the uncertainty of each sample and identify the harder ones.
Find the full example code here.
Addendum: Active learning with mislabelled data
To show how effective active learning is, we shuffled λ% of the labels and ran the same experiment. In the figure below, we show that even with 10% of the labels corrupted, BALD is still stronger than random sampling with no noise! These kinds of experiments are incredibly easy to run using the Lightning Flash integration, as we only changed the dataset composition.
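Corrupting a fraction of the labels for this kind of robustness check amounts to permuting a random subset of them. A generic NumPy sketch (not the exact code behind the figure):

```python
import numpy as np

def corrupt_labels(labels, fraction, seed=0):
    """Return a copy of `labels` with `fraction` of entries shuffled
    among themselves, preserving the overall class distribution."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).copy()
    n_corrupt = int(len(labels) * fraction)
    idx = rng.choice(len(labels), n_corrupt, replace=False)
    labels[idx] = labels[rng.permutation(idx)]  # permute the chosen subset
    return labels

clean = np.arange(1000) % 10          # toy labels, 10 balanced classes
noisy = corrupt_labels(clean, fraction=0.1)
```

Because the subset is permuted rather than resampled, class frequencies stay balanced and only the sample-to-label assignment is corrupted.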
Conclusion
In this tutorial, we used the Lightning Flash active learning integration with BaaL to run an experiment on CIFAR10.
Using only ~50 lines of code, we built a complete experiment and demonstrated that BALD-based sample selection yields better results than uniform sampling.
What’s Next?
The BaaL and Lightning Flash teams are working closely together to provide seamless integration for more data modalities and tasks. Stay tuned!
Built by the PyTorch Lightning creators, Grid.ai is our platform for scaling your model training without worrying about infrastructure, just as Lightning automates the training itself.
You can get started with Grid.ai for free with just a GitHub or Google Account.