
andrearosasco/ContinualFlame



ContinualFlame

A small, lightweight package for continual learning in PyTorch.

Installation

For now the package is hosted on TestPyPI. To install it from there, run:

pip install -i https://test.pypi.org/simple/ continual-flame

Usage

To use the package, import it in your project:

import contflame as cf

At the moment the package contains only the dataset module.

Dataset

This module contains datasets commonly used in continual learning. The main ones are:

  • SplitMNIST - MNIST split by class. Different subtasks are created by including custom subsets of the digit classes.
  • PermutedMNIST - MNIST with a task-specific permutation applied to the images. The shape of the permuted tiles can be chosen.
  • SplitCIFAR100
  • PermutedCIFAR100

Examples

SplitMNIST

In the following example the training tasks are five binary classification tasks on consecutive pairs of digits (i.e. task 1: (0, 1), task 2: (2, 3), ...):

from contflame.dataset import SplitMNIST

epochs = 10  # number of training epochs per task (example value)
valid = []
for i in range(0, 10, 2):
    # Task made of digits i and i + 1: (0, 1), (2, 3), ..., (8, 9)
    train_dataset = SplitMNIST(classes=[i, i + 1], dset='train', valid=0.2)
    valid.append(SplitMNIST(classes=[i, i + 1], dset='valid', valid=0.2))

    for e in range(epochs):
        # train the model on train_dataset
        ...

    for v in valid:
        # test the model on the current and the previous tasks
        ...
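The task construction above can be sketched without the package, on synthetic labels instead of MNIST. The helper `split_indices` is hypothetical and not part of contflame; it assumes that `valid=0.2` means a fraction of each task's samples held out for validation.

```python
import random


def split_indices(labels, classes, valid=0.2, seed=0):
    """Hypothetical helper (not the contflame API): pick the samples of `classes`
    and hold out a `valid` fraction of them for validation."""
    wanted = set(classes)
    # Keep only the indices of samples belonging to this task's classes
    idx = [i for i, y in enumerate(labels) if y in wanted]
    # A fixed seed keeps the train/valid split reproducible across calls
    random.Random(seed).shuffle(idx)
    n_valid = int(len(idx) * valid)
    return idx[n_valid:], idx[:n_valid]  # (train indices, valid indices)


# Synthetic labels standing in for MNIST: 10 samples per digit
labels = [d for d in range(10) for _ in range(10)]

# Five binary tasks on consecutive digit pairs: (0, 1), (2, 3), ..., (8, 9)
tasks = [split_indices(labels, classes=[i, i + 1], valid=0.2) for i in range(0, 10, 2)]

train_idx, valid_idx = tasks[0]
print(len(train_idx), len(valid_idx))  # 16 4
```

The fixed seed is a design choice of this sketch: the train and valid sets of a task are built by two separate calls, and they only stay disjoint if both calls produce the same shuffle.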

PermutedMNIST

To get a fully random pixel permutation, set the tile to (1, 1). The permutation is selected by the task id, and the same permutation is applied to every image of that task.

PermutedMNIST(tile=(1, 1), task=1)
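A minimal pure-Python sketch of a (1, 1) tile, assuming the task id is used as the seed of the random generator (the README only says the permutation is selected by the task id). The functions `task_permutation` and `permute_pixels` are hypothetical, not part of contflame, and images are plain lists of rows.

```python
import random


def task_permutation(n_pixels, task):
    """Permutation of pixel positions derived from the task id (assumed seeding)."""
    perm = list(range(n_pixels))
    random.Random(task).shuffle(perm)  # same task id -> same permutation
    return perm


def permute_pixels(image, task):
    """Apply the task's pixel permutation to a 2D image given as a list of rows."""
    h, w = len(image), len(image[0])
    flat = [p for row in image for p in row]
    perm = task_permutation(h * w, task)
    out = [flat[k] for k in perm]
    # Reshape back to h rows of w pixels
    return [out[r * w:(r + 1) * w] for r in range(h)]


img = [[r * 28 + c for c in range(28)] for r in range(28)]
a = permute_pixels(img, task=1)
b = permute_pixels(img, task=1)
print(a == b)  # True: same task, same permutation
```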

You can also permute whole rows (or columns) by setting the corresponding tile dimension equal to the image size. For example, a tile of (1, 28) permutes the rows:
PermutedMNIST(tile=(1, 28), task=1)
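With a tile of (1, 28) every tile is a full row, so permuting the tiles reorders the rows while each row keeps its pixels in order. A short sketch, again assuming the task id seeds the generator; `permute_rows` is a hypothetical name:

```python
import random


def permute_rows(image, task):
    """Shuffle whole rows of the image, keeping each row's pixels in order.
    Equivalent to tiles of shape (1, width); seeding by task id is assumed."""
    order = list(range(len(image)))
    random.Random(task).shuffle(order)
    return [list(image[r]) for r in order]


img = [[r * 28 + c for c in range(28)] for r in range(28)]
out = permute_rows(img, task=1)
# Every output row is an unchanged row of the input
print(all(row in img for row in out))  # True
```

The column-wise case, a tile of (28, 1), follows the same idea applied to columns instead of rows.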

To preserve more of the high-level spatial features, use a larger tile:
PermutedMNIST(tile=(8, 8), task=1)

To get the images without any permutation, set the tile to (28, 28), which is the default value.
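All the tile settings above can be read as one operation: cut the image into non-overlapping tiles and shuffle their positions. The sketch below, with the hypothetical function `tile_permute`, assumes the task id seeds the generator and that the tile dimensions divide the image evenly, which is why it uses (7, 7) instead of (8, 8).

```python
import random


def tile_permute(image, tile, task):
    """Split `image` into non-overlapping tiles of shape `tile` and shuffle
    their positions with a generator seeded by `task` (assumed behavior)."""
    h, w = len(image), len(image[0])
    th, tw = tile
    assert h % th == 0 and w % tw == 0, "this sketch requires tiles that divide the image"
    rows, cols = h // th, w // tw
    # Extract the tiles in row-major order; each tile is a list of th rows of tw pixels
    tiles = [
        [row[j * tw:(j + 1) * tw] for row in image[i * th:(i + 1) * th]]
        for i in range(rows)
        for j in range(cols)
    ]
    order = list(range(len(tiles)))
    random.Random(task).shuffle(order)
    # Reassemble: output tile position k receives input tile order[k]
    out = [[None] * w for _ in range(h)]
    for k, src in enumerate(order):
        i, j = divmod(k, cols)
        for dr, tile_row in enumerate(tiles[src]):
            out[i * th + dr][j * tw:(j + 1) * tw] = tile_row
    return out


img = [[r * 28 + c for c in range(28)] for r in range(28)]
print(tile_permute(img, (28, 28), task=1) == img)  # True: a single tile, nothing moves
blocks = tile_permute(img, (7, 7), task=1)  # 16 tiles of 7x7 pixels, shuffled
```

With (1, 1) this reduces to a pixel permutation and with (1, 28) to a row permutation, so a single parameter moves smoothly between destroying and preserving spatial structure.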
