apple/ml-capsules-inverted-attention-routing

Python 3.6

Capsules with Inverted Dot-Product Attention Routing

PyTorch implementation of Capsules with Inverted Dot-Product Attention Routing.

Paper

Capsules with Inverted Dot-Product Attention Routing
Yao-Hung Hubert Tsai, Nitish Srivastava, Hanlin Goh, and Ruslan Salakhutdinov
International Conference on Learning Representations (ICLR), 2020.

Please cite our paper if you find our work useful for your research:

@inproceedings{tsai2020Capsules,
  title={Capsules with Inverted Dot-Product Attention Routing},
  author={Tsai, Yao-Hung Hubert and Srivastava, Nitish and Goh, Hanlin and Salakhutdinov, Ruslan},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020},
}

Overview

Overall Architecture

An example of our proposed architecture is shown above. The backbone is a standard feed-forward convolutional neural network. The features extracted from this network are fed through another convolutional layer. At each spatial location, the channels are split into groups of 16, and each group forms one capsule (we assume a 16-dimensional pose per capsule). LayerNorm is then applied across the 16 channels of each group to obtain the primary capsules. This is followed by two convolutional capsule layers, and then by two fully-connected capsule layers. In the last capsule layer, each capsule corresponds to a class. These capsules are then used to compute logits that feed into a softmax to compute the classification probabilities. Inference in this network requires a feed-forward pass up to the primary capsules. After this, our proposed routing mechanism (discussed later) takes over.
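The grouping step described above can be sketched in a few lines. This is a minimal illustration in plain Python rather than the repository's PyTorch code: the helper names `layer_norm` and `primary_capsules` are hypothetical, the normalization omits any learned scale and shift, and the input is the channel vector at a single spatial location.

```python
import math

def layer_norm(vec, eps=1e-5):
    """Normalize a vector to zero mean and (almost) unit variance.

    Assumption: no learned scale/shift parameters are applied here.
    """
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def primary_capsules(channels, pose_dim=16):
    """Split one location's channel vector into pose_dim-sized capsules
    and apply layer normalization to each capsule separately."""
    if len(channels) % pose_dim != 0:
        raise ValueError("channel count must be a multiple of pose_dim")
    return [layer_norm(channels[i:i + pose_dim])
            for i in range(0, len(channels), pose_dim)]

# 32 channels at one spatial location give 2 capsules with 16-dimensional poses.
channels = [float(c) for c in range(32)]
capsules = primary_capsules(channels)
```

With 32 input channels and a 16-dimensional pose, `capsules` holds two normalized pose vectors. A convolutional feature map would repeat this at every spatial location.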

Inverted Dot-Product Attention Routing

In our method, the routing procedure resembles an inverted attention mechanism, where dot products are used to measure agreement. Specifically, the higher-level (parent) units compete for the attention of the lower-level (child) units, rather than the other way around, as is common in attention models. Hence, the routing probability depends directly on the agreement between the parent’s pose (from the previous iteration step) and the child’s vote for the parent’s pose (in the current iteration step). We (1) use Layer Normalization (Ba et al., 2016) as normalization, and (2) perform inference of the latent capsule states and routing probabilities jointly across multiple capsule layers (instead of doing it layer-wise).
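One routing iteration between a single pair of layers, as described above, might look like the following sketch. It is written in plain Python for readability and is not the repository's implementation. The function name `routing_step` and the nested-list layout `votes[i][j]` (child `i`'s vote for parent `j`'s pose) are assumptions made for illustration, and the layer normalization again omits learned parameters.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def layer_norm(vec, eps=1e-5):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def routing_step(votes, parent_poses):
    """One inverted dot-product attention routing iteration.

    votes[i][j]     : child i's vote for parent j's pose (current step)
    parent_poses[j] : parent j's pose from the previous step
    """
    num_children, num_parents = len(votes), len(parent_poses)
    routing = []
    for i in range(num_children):
        # Agreement is the dot product between the parent's previous pose
        # and the child's vote for that parent.
        agreement = [dot(parent_poses[j], votes[i][j]) for j in range(num_parents)]
        # The softmax runs over parents: parents compete for child i.
        routing.append(softmax(agreement))
    new_poses = []
    for j in range(num_parents):
        dim = len(parent_poses[j])
        pooled = [sum(routing[i][j] * votes[i][j][d] for i in range(num_children))
                  for d in range(dim)]
        new_poses.append(layer_norm(pooled))
    return routing, new_poses

# Toy example: 2 children, 2 parents, 4-dimensional poses.
votes = [
    [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, -1.0]],
    [[1.0, 0.5, -1.0, -0.5], [-1.0, 0.0, 1.0, 0.0]],
]
parent_poses = [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, -1.0]]
routing, new_poses = routing_step(votes, parent_poses)
```

In this toy example, child 0 agrees equally with both parents and splits its routing evenly. Child 1's vote matches parent 0's previous pose more closely, so it sends most of its routing probability to parent 0.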

Concurrent Routing

Concurrent routing is a parallel-in-time routing procedure for all capsule layers.
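The difference between a layer-by-layer pass and a parallel-in-time pass can be illustrated with a toy sketch. It reflects one reading of "parallel-in-time": within an iteration, every layer is updated from the state of the previous iteration, so no layer waits for the layer below it to finish. The functions `sequential_pass` and `concurrent_pass` are hypothetical names, poses are reduced to single numbers, and `average` is a stand-in for a real per-layer routing step.

```python
def sequential_pass(poses, updates):
    """Update layers one after another; each layer sees the freshly
    updated pose of the layer below it."""
    poses = list(poses)
    for layer, update in enumerate(updates):
        poses[layer + 1] = update(poses[layer], poses[layer + 1])
    return poses

def concurrent_pass(poses, updates):
    """Update all layers from the same snapshot of the previous iteration,
    so the per-layer updates are independent and could run in parallel."""
    old = list(poses)
    return [old[0]] + [update(old[layer], old[layer + 1])
                       for layer, update in enumerate(updates)]

def average(child, parent):
    # Stand-in for a routing step that mixes child and parent state.
    return 0.5 * (child + parent)

poses = [8.0, 0.0, 0.0]          # primary capsules, then two higher layers
updates = [average, average]

seq = sequential_pass(poses, updates)         # [8.0, 4.0, 2.0]
con_1 = concurrent_pass(poses, updates)       # [8.0, 4.0, 0.0]
con_2 = concurrent_pass(con_1, updates)       # [8.0, 6.0, 2.0]
```

In the concurrent version, information from the primary capsules reaches the top layer only after two iterations. In exchange, each iteration's layer updates have no dependency on one another.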

Usage

Prerequisites

Datasets

We use CIFAR10 and CIFAR100.

Run the Code

Arguments

| Args | Value | Help |
| --- | --- | --- |
| debug | - | Enter debug mode, in which no models or results are saved. True or False. |
| num_routing | 1 | The number of routing iterations. Must be at least 1. |
| dataset | CIFAR10 | Choice of the dataset. CIFAR10 or CIFAR100. |
| backbone | resnet | Choice of the backbone. simple or resnet. |
| config_path | ./configs/resnet_backbone_CIFAR10.json | Configurations for capsule layers. |

Running CIFAR-10

python main_capsule.py --num_routing 2 --dataset CIFAR10 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR10.json 

With num_routing set to 1, we obtained an average accuracy of 94.73%.

With num_routing set to 2, we obtained an average accuracy of 94.85%, and the best model reached 95.14%.

Running CIFAR-100

python main_capsule.py --num_routing 2 --dataset CIFAR100 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR100.json 

With num_routing set to 1, we obtained an average accuracy of 76.02%.

With num_routing set to 2, we obtained an average accuracy of 76.27%, and the best model reached 78.02%.
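To reproduce all four reported settings, the commands can be generated rather than typed by hand. The helper below is a hypothetical convenience and is not part of the repository. It only assembles the command lines shown above, using the flags and the two config paths that appear in this README, and it does not run anything itself.

```python
# Config paths exactly as they appear in the commands above.
CONFIGS = {
    "CIFAR10": "./configs/resnet_backbone_CIFAR10.json",
    "CIFAR100": "./configs/resnet_backbone_CIFAR100.json",
}

def build_command(dataset, num_routing):
    """Return the argument list for one training run with the resnet backbone."""
    if dataset not in CONFIGS:
        raise ValueError("dataset must be CIFAR10 or CIFAR100")
    if num_routing < 1:
        raise ValueError("num_routing must be at least 1")
    return [
        "python", "main_capsule.py",
        "--num_routing", str(num_routing),
        "--dataset", dataset,
        "--backbone", "resnet",
        "--config_path", CONFIGS[dataset],
    ]

# One command per (dataset, num_routing) pair reported above.
commands = [build_command(dataset, n) for dataset in CONFIGS for n in (1, 2)]
for cmd in commands:
    print(" ".join(cmd))
```

Each entry in `commands` is an argument list, so it can be passed directly to `subprocess.run` from the repository root if the runs should be launched from Python.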

License

This code is released under the LICENSE terms.
