
Train anything with Lightning custom Loops

With the new Lightning Loop API in v1.5, you can write your own training loops for any kind of research, from active learning to recommendation systems.


With PyTorch Lightning v1.5, we’re thrilled to introduce our new Loop API allowing users to customize Lightning to fit any kind of research, from sequential learning to recommendation models.

This is part of our effort to make Lightning the simplest and most flexible framework for taking any kind of deep learning research to production. Continue reading to learn how to implement cross-validation, active learning, and more.

Find the full documentation here.

Lightning Loops Under The Hood

PyTorch Lightning was created to do the hard work for you. The Lightning Trainer automates all the mechanics of the training, validation, and test routines. To create your model, all you need to do is define the architecture and the training, validation, and test steps; Lightning makes sure the right thing is called at the right time.

Internally, the Lightning Trainer relies on a series of nested loops to properly conduct the gradient descent optimization that applies to 90%+ of machine learning use cases. Even though Lightning provides hundreds of features, behind the scenes, it looks like this:

https://gist.github.com/tchaton/e3e9da65a8d21cf07af69cbf3bcf022b
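In rough pseudocode, that structure is a pair of nested loops over epochs and batches (a simplified sketch; the real Trainer adds hooks, validation, checkpointing, and much more):

```python
# simplified sketch of the Trainer's default fit routine
for epoch in range(max_epochs):                           # FitLoop
    for batch_idx, batch in enumerate(train_dataloader):  # TrainingEpochLoop
        loss = lightning_module.training_step(batch, batch_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```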

However, some new research use cases such as meta-learning, active learning, cross-validation, recommendation systems, etc., require a different loop structure.

To resolve this, the Lightning Team implemented a general while-loop as a Python class: the Lightning Loop. Here is its pseudocode; the full implementation can be found here.

https://gist.github.com/tchaton/f8856c0d8a66de2bddfba74bf7dd8a74
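In essence, a Loop bundles a stopping condition with one unit of work, and run() ties them together in a while-loop. A simplified sketch (omitting the skip/restart logic and the start/end hooks of the real class):

```python
class Loop:
    @property
    def done(self) -> bool:
        """Return True to stop the loop."""
        raise NotImplementedError

    def reset(self) -> None:
        """Reset any internal state so the loop can run again."""

    def advance(self, *args, **kwargs) -> None:
        """Run one iteration of the loop body."""
        raise NotImplementedError

    def run(self, *args, **kwargs):
        # the generalized while-loop: reset state, then advance until done
        self.reset()
        while not self.done:
            self.advance(*args, **kwargs)
```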

Using Loops has several advantages:

  • You can replace, subclass, or wrap any loops within Lightning to customize their inner workings to your needs. This makes it possible to express any type of research with Lightning.
  • The Loops are standardized, and each loop can be isolated from its parent and children. For a simple loop this structure might mean more code, but when dealing with hundreds of features it is the key to scaling while preserving a high level of flexibility.
  • A Loop can track its state and save it within the model checkpoint. This is used by fault-tolerant training to enable automatic restarts.

Using Customized Loops

The community has already started using customized loops for a variety of use cases:

Cross-Validation

KFold cross-validation is a machine learning practice in which the training dataset is partitioned into several complementary subsets, so-called folds. Each cross-validation round performs a fit where one fold is held out for validation and the remaining folds are used for training.

The KFoldLoop contains the logic to perform cross-validation and can easily be attached to the trainer as follows:

https://gist.github.com/tchaton/88117dea360ad975671b661ceefe7877
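Roughly, attaching it looks like the sketch below, modeled on the KFold example in the Lightning repository (the model, the fold-aware datamodule, and the KFoldLoop constructor arguments are assumptions; see the linked example for the exact code):

```python
from pytorch_lightning import Trainer

model = LitModel()                # any LightningModule (assumed defined)
datamodule = KFoldDataModule()    # a datamodule that can switch folds (assumed defined)

trainer = Trainer(max_epochs=10)
internal_fit_loop = trainer.fit_loop         # keep the default FitLoop around
trainer.fit_loop = KFoldLoop(num_folds=5, export_path="./")
trainer.fit_loop.connect(internal_fit_loop)  # the default loop runs once per fold
trainer.fit(model, datamodule)
```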

Find the full example here.

Active Learning

Active learning is a machine learning practice in which the user interacts with the learner to provide new labels when required.

Credit to the Baal Team.

In Lightning Flash, you can find an ActiveLearningLoop, an implementation that you can use together with an ActiveLearningDataModule to label new data on the fly. To run the following demo, install Flash and BaaL first:

https://github.com/PyTorchLightning/lightning-flash/blob/78faff0e70170fb9fac8296954d656d5a5c632f3/flash_examples/integrations/baal/image_classification_active_learning.py
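In outline, the example wires the loop up roughly as follows (a sketch based on the linked script; exact import paths, arguments, and data paths depend on your Flash and BaaL versions):

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import (
    ActiveLearningDataModule,
    ActiveLearningLoop,
)

# expose only a small labelled pool at first; the rest stays unlabelled
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/train", batch_size=2),
    initial_num_labels=64,
)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=3)
# each loop iteration trains, estimates uncertainty on the unlabelled pool,
# and requests labels for the most informative samples
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
```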

Find the full example here.

Yield In LightningModule Training Step

The YieldLoop enables you to write the training_step() hook as a Python generator, i.e., you can yield loss values instead of returning them. This allows for more elegant and expressive implementations when using multiple optimizers. Here is a GAN example where both the generator and discriminator losses are yielded:

https://gist.github.com/tchaton/9555da4687cef28ffca35cd8fc9e0aba
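A sketch of what such a generator-style training_step can look like (the networks and the WGAN-style losses here are illustrative placeholders, not the code from the linked example):

```python
import torch
import pytorch_lightning as pl

class TinyGAN(pl.LightningModule):
    # only the pieces relevant to the yielding training_step are shown;
    # a real GAN would define proper generator/discriminator networks
    def __init__(self, latent_dim: int = 32, data_dim: int = 64):
        super().__init__()
        self.latent_dim = latent_dim
        self.generator = torch.nn.Linear(latent_dim, data_dim)
        self.discriminator = torch.nn.Linear(data_dim, 1)

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
        return g_opt, d_opt

    def training_step(self, batch, batch_idx):
        # with the YieldLoop installed, this hook is a generator: each
        # yield hands one loss to the corresponding optimizer
        real = batch
        z = torch.randn(real.size(0), self.latent_dim, device=self.device)
        fake = self.generator(z)

        # generator loss: make the discriminator score fakes as real
        g_loss = -self.discriminator(fake).mean()
        yield g_loss  # consumed by the first optimizer (generator)

        # discriminator loss: separate real from (detached) fake samples
        d_loss = self.discriminator(fake.detach()).mean() - self.discriminator(real).mean()
        yield d_loss  # consumed by the second optimizer (discriminator)
```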

Find the full example here.

Implement Your Own Loop

If you have an innovative use case that requires a different way to iterate over data, you can implement custom loops from scratch. Below is a generic training routine written with two nested loops. We will use it as a showcase to illustrate how to implement your own loops.

Training routine
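For concreteness, assume a plain-PyTorch routine along these lines (model, criterion, optimizer, num_epochs, and train_dataloader are assumed to be defined):

```python
# a generic training routine: an outer loop over epochs,
# an inner loop over batches
for epoch in range(num_epochs):
    for batch_idx, (x, y) in enumerate(train_dataloader):
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```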

To standardize this code, we will need to create two Loops. To implement a loop, you need to implement the done property and override the advance method.

  1. We will start with the inner loop, TrainingEpochLoop. We need to implement the following functionality:
  • done: Return True when batch_idx is greater than or equal to the length of the dataloader.
  • reset: Create a new dataloader iterator.
  • advance: Perform a single optimization step on the model.
https://gist.github.com/tchaton/8be2cb47232b14d7d650e1f989f58f8e
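For reference, here is a minimal sketch of such a loop (names, constructor arguments, and the plain-PyTorch optimization step are illustrative; the gist above contains the original version):

```python
from pytorch_lightning.loops import Loop

class TrainingEpochLoop(Loop):
    """Runs one epoch: a single pass over the dataloader."""

    def __init__(self, model, optimizer, dataloader):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.batch_idx = 0

    @property
    def done(self) -> bool:
        # stop once every batch has been consumed
        return self.batch_idx >= len(self.dataloader)

    def reset(self) -> None:
        # fresh iterator (and batch counter) at the start of every epoch
        self.batch_idx = 0
        self.dataloader_iter = iter(self.dataloader)

    def advance(self) -> None:
        # a single optimization step on the next batch
        batch = next(self.dataloader_iter)
        loss = self.model.training_step(batch, self.batch_idx)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.batch_idx += 1
```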

We can use the TrainingEpochLoop we just defined to replace the inner loop of our original snippet as follows:
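For example (reusing the hypothetical names from the sketches above):

```python
# the hand-written inner loop is replaced by a single run() call
epoch_loop = TrainingEpochLoop(model, optimizer, train_dataloader)
for epoch in range(num_epochs):
    epoch_loop.run()
```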

2. Next, we implement the outer FitLoop:

  • done: Return True when the requested number of epochs is reached.
  • advance: Run the TrainingEpochLoop and increment the epoch counter.
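A matching sketch, again with illustrative names:

```python
class FitLoop(Loop):
    """Outer loop over epochs, driving a TrainingEpochLoop."""

    def __init__(self, epoch_loop: TrainingEpochLoop, max_epochs: int):
        super().__init__()
        self.epoch_loop = epoch_loop
        self.max_epochs = max_epochs
        self.epoch = 0

    @property
    def done(self) -> bool:
        # stop once the requested number of epochs is reached
        return self.epoch >= self.max_epochs

    def reset(self) -> None:
        pass  # nothing to reset between runs

    def advance(self) -> None:
        # run one full epoch, then move on to the next
        self.epoch_loop.run()
        self.epoch += 1
```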

3. Finally, with both the FitLoop and TrainingEpochLoop implemented, you can use the FitLoop to train your model on your dataloader:
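For example:

```python
# hypothetical names carried over from the sketches above
epoch_loop = TrainingEpochLoop(model, optimizer, train_dataloader)
fit_loop = FitLoop(epoch_loop, max_epochs=10)
fit_loop.run()
```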

Find a more complete example with MNIST and LightningLite here.

Learn More

To learn more about these examples and the new API, check out the documentation!

We are very excited to bring this new level of flexibility to the community, and we can't wait to see what you will build with it! If you need help or guidance, have questions, or are looking for ways to contribute, find us on Slack!
