A summary of Transformer-based architectures for CV tasks, including image classification, object detection, segmentation, and few-shot learning. Updated frequently.

quanghuy0497/Transformers4Vision

About

This repository summarizes Transformer-based architectures for Computer Vision, from basic tasks (classification) to complex ones (object detection, segmentation, few-shot learning).

The main purpose of this list is to review and recap only the main approach/pipeline/architecture of these papers to give an overview of transformers for vision, so other parts of the papers (e.g. experimental performance, comparison results) are not presented. For better intuition, please read the original articles and code linked in each recap section. Of course, there might be some mistakes in these reviews, so if something is wrong or inaccurate, please feel free to tell me.

The paper summarization list will be updated frequently.

Standard Transformer

This section introduces the original transformer architecture in NLP as well as its versions in Computer Vision, including ViT for image classification, VTN for video classification, and ViTGAN for generative adversarial networks. Finally, a deep comparison between ViT and ResNet is introduced, to see whether the attention-based model is similar to the Conv-based model.

Basic Transformer

  • Paper: https://arxiv.org/pdf/1706.03762.pdf
  • Input:
    • Sequence embedding (e.g. word embeddings of a sentence)
    • Positional Encoding (PE) => encodes the position of each word embedding within the sentence, added to the input of the Encoder/Decoder block
      • Read here for detailed explanation of PE
  • Encoder:
    • Embedding words => Skip[MHSA => Norm] => Skip[FFN => Norm] => Patch encoding
      • MHSA: Multi-Head Self Attention
        • For a general explanation of MHSA and the Transformer, you can look here (notes from the Deep Learning Specialization courses, taught by Andrew Ng)
          • You can also look here (English) or here (Vietnamese) for a more detailed explanation of how Multi-Head Self Attention works.
      • FFN: FeedForward Neural Network
    • Repeat N times (N usually 6)
  • Decoder:
    • Decoder input:
      • Learned output of the decoder (initial token at the beginning, learned sentence throughout the process), shifted right
      • Patch encoding (put in the middle of the Decoder)
    • (Input + Positional Encoding) => Skip[MHSA + Norm] => Skip[(+ Patch encoding) => MHSA => Norm] => Skip[FFN + Norm] => Linear => Softmax => Decoder Output
    • Use the decoder output as the input for the next round, repeat N times (N usually 6)
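  • Computational complexity per layer of Self-Attention, Conv, and RNN (n: sequence length, d: representation dimension, k: kernel size): Self-Attention O(n^2.d); Conv O(k.n.d^2); RNN O(n.d^2)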
  • Codes: https://github.com/SamLynnEvans/Transformer
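The input and attention steps above can be sketched in a few lines of NumPy. This is a minimal single-head illustration, not the paper's full implementation: it omits the learned Q/K/V projections, multiple heads, skip connections, and normalization, and the toy sizes (5 tokens, 8 dimensions) are assumptions chosen only for the example.

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    """Sinusoidal PE of shape [seq_len, d_model]; d_model is assumed even."""
    pos = np.arange(seq_len)[:, None]                  # [seq_len, 1]
    two_i = np.arange(0, d_model, 2)[None, :]          # even dimension indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)  # [seq_len, d_model / 2]
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)                       # even dimensions
    pe[:, 1::2] = np.cos(angles)                       # odd dimensions
    return pe

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)            # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q, k: [n, d_k]; v: [n, d_v]. Returns the output [n, d_v] and weights [n, n]."""
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]), axis=-1)
    return weights @ v, weights

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 8))                   # toy "sentence" of 5 tokens
x = embeddings + positional_encoding(5, 8)             # sequence embedding + PE
out, weights = scaled_dot_product_attention(x, x, x)   # self-attention: q = k = v = x
```

Each row of `weights` sums to 1, so every output token is a convex combination of the value vectors.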

ViT (Vision Transformer)

  • Paper: https://arxiv.org/pdf/2010.11929.pdf
  • Input:
    • Image [H, W, C] => non-overlapped patches (conventionally 16x16 patch size) => flatten into sequence => linear projection (vectorized + Linear) => patch embeddings
    • Positional encoding added to the patch embeddings to provide location information for the patch sequence
    • Extra learnable [Cls] token (embedding) + position 0 => attached at the head of the embedding sequence (denoted as Z0)
  • Architecture: (Image Patches + Position Embedding) => Transformer Encoder => MLP Head for Classification
    • Transformer Encoder: Skip[Norm => MHSA] => Skip[Norm + MLP(Linear, GELU, Linear)] => output
    • MLP Head for classification: C0 (output of Z0 after passing through the Transformer Encoder) => MLP Head (Linear + Softmax) => classified label
  • Good video explanation: https://www.youtube.com/watch?v=HZ4j_U3FC94
  • Code: https://github.com/lucidrains/vit-pytorch
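The ViT input pipeline described above can be illustrated with a small NumPy sketch. The embedding width `d_model = 64`, the random weights, and the zero-initialized class token are assumptions made for the example; in a real model all of these are learned parameters.

```python
import numpy as np

def patchify(image, patch):
    """Split an [H, W, C] image into flattened non-overlapping patches [N, patch*patch*C]."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    x = image.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)                   # [H/p, W/p, p, p, C]
    return x.reshape(-1, patch * patch * c)          # [N, p*p*C]

rng = np.random.default_rng(0)
image = rng.normal(size=(224, 224, 3))
d_model = 64                                         # hypothetical embedding width

patches = patchify(image, 16)                        # [196, 768]
w_proj = rng.normal(size=(patches.shape[1], d_model)) * 0.02   # linear projection
cls_token = np.zeros((1, d_model))                   # stand-in for the learnable [Cls] token
pos_embed = rng.normal(size=(patches.shape[0] + 1, d_model)) * 0.02

# [Cls] token at the head of the sequence, then position embeddings added to every token.
tokens = np.concatenate([cls_token, patches @ w_proj], axis=0) + pos_embed
```

A 224x224 image with 16x16 patches gives 14 x 14 = 196 patch tokens, plus one class token.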

VTN (Video Transformer Network)

  • Paper: https://arxiv.org/pdf/2102.00719.pdf
  • Based on Longformer, a transformer-based model that can process long sequences of thousands of tokens
  • Pipeline: Quite similar to the standard ViT
    • The 2D spatial backbone f(x) can be replaced with any given backbone for 2D images => feature extraction
    • The temporal attention-based encoder can stack up more layers, more heads, or can be set to a different Transformers model that can process long sequences.
      • Note that a special classification token [CLS] is added in front of the feature sequence => final representation of the video => classification head for video classification
    • The classification head can be modified to facilitate different video-based tasks, i.e. temporal action localization
  • Code: https://github.com/bomri/SlowFast/tree/master/projects/vtn

ViTGAN (Vision Transformer GAN)

  • Paper: https://arxiv.org/pdf/2107.04589.pdf
  • Both the generator and the discriminator are designed based on the standard ViT, but with modifications
  • Architecture:
    • Generator: Input latent z => Mapping Network (MLP) => latent vector w => Affine transform A => Transformer Encoder => Fourier Encoding (sinusoidal/sine activation) => E_fou => 2-layer MLP => Generated Patches
      • Transformer Encoder:
        • Embedding position => Skip[SLN => MHSA] => Skip[SLN =>MLP] => output
        • SLN is called Self-modulated LayerNorm (as the modulation depends on no external information)
          • The SLN formula is described in the paper (Equation 14)
        • Embedding E as initial input; A as middle inputs for the norm layers
    • Discriminator: the pipeline architecture is similar to the standard ViT model, but with several changes:
      • Adopt overlapping patches at the beginning (rather than non-overlapping ones)
      • Replace the dot product between Q and K with Euclidean (L2) distance in the Attention formula
      • Apply spectral normalization (read paper for more information)
  • Code: To be updated

Conv Block Attention Module

  • Paper: https://arxiv.org/pdf/1807.06521v2.pdf
  • Conv Block Attention Module (CBAM): a lightweight and general attention-based module that can be integrated into feed-forward CNNs
  • Architecture:
    • Intermediate feature map => infer 1D channel attention map and 2D spatial attention map => multiplied with the input feature map => adaptive feature refinement
    • Channel attention module: exploiting the inter-channel relationship of features
      • Feature map F => AvgPool || MaxPool => Shared MLP [3 layers, ReLU] => element-wise summation => sigmoid => Channel attention Mc
    • Spatial attention module: utilizing the inter-spatial relationship of features
      • Channel-refined feature F' => [AvgPool, MaxPool] => Conv7x7 => sigmoid => Spatial attention Ms
  • CBAM can be combined with ResBlock:
    • Conv => Skip_connection(CBAM) => Next Conv
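As a rough illustration of the channel attention branch only (the spatial branch with its 7x7 convolution is left out), here is a NumPy sketch. The channel count, the reduction ratio of 4, and the random weights `w1` and `w2` are assumptions for the example; the shared MLP is written as input, one hidden layer with ReLU, and output.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention(feat, w1, w2):
    """feat: [C, H, W]; w1: [C, C // r]; w2: [C // r, C]. Returns Mc with shape [C]."""
    def shared_mlp(v):
        # The same weights are applied to both pooled descriptors.
        return np.maximum(v @ w1, 0.0) @ w2

    avg_pooled = feat.mean(axis=(1, 2))              # global average pooling -> [C]
    max_pooled = feat.max(axis=(1, 2))               # global max pooling -> [C]
    return sigmoid(shared_mlp(avg_pooled) + shared_mlp(max_pooled))

rng = np.random.default_rng(0)
channels, reduction = 16, 4                          # hypothetical sizes
feat = rng.normal(size=(channels, 8, 8))
w1 = rng.normal(size=(channels, channels // reduction)) * 0.1
w2 = rng.normal(size=(channels // reduction, channels)) * 0.1

mc = channel_attention(feat, w1, w2)                 # channel attention map Mc
refined = mc[:, None, None] * feat                   # channel-refined feature F'
```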

Do Vision Transformers see like CNNs?

  • Paper: https://arxiv.org/pdf/2108.08810.pdf
  • Representation Structure:
    • ViTs have highly similar representations throughout the model, while the ResNet models show much lower similarity between lower and higher layers.
      • ViT lower layers compute representations in a different way to lower layers in the ResNet.
      • ViT also more strongly propagates representations between lower and higher layers.
      • The highest layers of ViT have quite different representations to ResNet.
  • Local and Global Information in layer Representations:
    • ViTs have access to more global information than CNNs in their lower layers, leading to quantitatively different features than (computed by the local receptive fields in the lower layers of the) ResNet.
      • Even in the lowest layers of ViT, self-attention layers have a mix of local heads (small distances) and global heads (large distances) => in contrast to CNNs, which are hardcoded to attend only locally in the lower layers.
      • At higher layers, all self-attention heads are global.
    • Lower layer effective receptive fields for ViT are larger than in ResNets, and while ResNet effective receptive fields grow gradually, ViT receptive fields become much more global midway through the network.
  • Representation Propagation through Skip connections:
    • Skip connections in ViT are even more influential than in ResNet => strong effects on performance and representation similarity
      • Skip connections play a key role in the representational structure of ViT.
      • Skip connections play a key role in propagating the representations throughout the ViT => uniform structure in the lower and higher layers.
  • Spatial Information and Localization:
    • Higher layers of ViT maintain spatial location information more faithfully than ResNets.
      • ViTs with CLS tokens show strong preservation of spatial information — promising for future uses in object detection.
      • When trained with global average pooling (GAP) instead of a CLS token, ViTs show less clear localization.
      • ResNet50 and ViT with GAP model tokens perform well at higher layers, while in the standard ViT trained with a CLS token the spatial tokens do poorly – likely because their representations remain spatially localized at higher layers, which makes global classification challenging.
  • Effects of Scale on Transfer Learning:
    • For larger models, the larger dataset is especially helpful in learning high-quality intermediate representations.
    • While lower layer representations have high similarity even with 10% of the data, higher layers and larger models require significantly more data to learn similar representations.
    • Larger ViT models develop significantly stronger intermediate representations through larger pre-training datasets than the ResNets.

Transformer Optimization

This section introduces techniques of training vision transformer-based model effectively with optimization methods (data, augmentation, regularization,...). As the Scaled Dot-Product Attention comes with quadratic complexity O(N^2), several approaches (Efficient Attention, Linformer) are introduced to reduce the computational complexity down to linear O(N).

How to train ViT?

  • Paper: https://arxiv.org/pdf/2106.10270.pdf
  • Experimental hyperparameters:
    • Pre-training:
      • Adam optimization with b1 = 0.9 and b2 = 0.999
      • Batch size 4096
      • Cosine learning rate schedule with a 10k-step linear warmup
      • Gradient clipping at global norm 1
    • Fine-tune:
      • SGD optimization with momentum 0.9
      • Batch size of 512
      • Cosine decay learning rate schedule with a linear warmup
      • Gradient clipping at global norm 1
  • Regularization & augmentation:
    • With a judicious amount of regularization and image augmentation, one can (pre-)train a model to an accuracy similar to that obtained by increasing the dataset size by about an order of magnitude.
    • Validation accuracy can deteriorate when various amounts of augmentation (RandAugment, Mixup) and regularization (Dropout, StochasticDepth) are applied.
      • Generally speaking, there are significantly more cases where adding augmentation helps than where adding regularization helps
      • For a relatively small amount of data, almost everything helps. But at large data scale, almost everything hurts; only when compute is also increased does augmentation help again
  • Transfer:
    • No matter how much training time is spent, it does not seem possible to train ViT models from scratch to reach accuracy anywhere near that of the transferred model. => transfer is the better option
    • Furthermore, since pre-trained models are freely available to download, the pre-training cost for practitioners is effectively zero
    • Adapting only the best pre-trained model works equally to adapting all pre-trained models (and then selecting the best)
      • Then selecting a single pre-trained model based on the upstream score is a cost-effective practical strategy
  • Data:
    • More data yields more generic models => recommend that the design choice is using more data with a fixed compute budget
  • Patch-size:
    • Prefer increasing the patch size to shrinking the model size
    • Using a larger patch-size (/32) significantly outperforms making the model thinner (/16)

Efficient Attention

  • Paper: https://arxiv.org/pdf/1812.01243v9.pdf
  • Efficient Attention:
    • Linear memory and computational complexity O(n)
    • Possesses the same representational power as the conventional dot-product attention
    • In practice, it comes with better performance than the conventional attention
  • Method:
    • Initially, feature X => 3 linear layers => Q [n, k]; K [n, k]; V [n, d], where k and d are the dimensionalities of the keys and of the input representation (or embedding).
    • The Dot-product Attention is calculated by: D(Q,K,V) = softmax(Q*K^T / sqrt(k))*V
      • Q*K^T (denoted Pairwise similarity S) has the shape [n, n] => S*V has the shape [n, d] => O(n^2.d)
    • The Efficient Attention is calculated by: E(Q,K,V) = p_q(Q)*(p_k(K)^T*V)
      • p_q and p_k are the normalizations: softmax along each row of Q and along each column of K, or scaling both by 1/sqrt(n)
      • K^T*V (denoted Global Context Vectors G) has the shape [k, d], where k & d are constants that can be chosen => its size is O(1) with respect to n
      • Then, Q*G has the shape [n, d] => O(k.d.n) or O(n)
  • Dot-product Attention and Efficient Attention are mathematically equivalent under scaling normalization and approximately equivalent under softmax normalization (the proof is given in the paper)
  • Explanation from the author: https://cmsflash.github.io/ai/2019/12/02/efficient-attention.html
  • Code:
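A small numerical check can make the reordering argument concrete. The sketch below assumes the scaling normalization (dividing the similarity matrix by n for dot-product attention, and dividing Q and K each by sqrt(n) for efficient attention), under which the two forms agree exactly; the softmax variant is included only to show the shapes. The sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention_scaled(q, k, v):
    """Builds the [n, n] similarity matrix first: O(n^2) memory."""
    n = q.shape[0]
    similarity = q @ k.T / n                 # [n, n]
    return similarity @ v                    # [n, d]

def efficient_attention_scaled(q, k, v):
    """Builds the [k, d] global context first: memory independent of n."""
    n = q.shape[0]
    context = (k / np.sqrt(n)).T @ v         # [k, d]
    return (q / np.sqrt(n)) @ context        # [n, d]

def efficient_attention_softmax(q, k, v):
    """Softmax variant: normalize K over positions and Q over key channels."""
    context = softmax(k, axis=0).T @ v       # [k, d]
    return softmax(q, axis=1) @ context      # [n, d]

rng = np.random.default_rng(0)
n, key_dim, d = 64, 4, 8                     # hypothetical sizes
q = rng.normal(size=(n, key_dim))
k = rng.normal(size=(n, key_dim))
v = rng.normal(size=(n, d))

out_dot = dot_product_attention_scaled(q, k, v)
out_eff = efficient_attention_scaled(q, k, v)
out_soft = efficient_attention_softmax(q, k, v)
```

Here `out_dot` and `out_eff` match up to floating-point error, because matrix multiplication is associative and the scaling factors multiply to 1/n in both cases.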

Linformer

  • Paper: https://arxiv.org/pdf/2006.04768.pdf
  • The conventional Scaled Dot-Product Attention is decomposed into multiple smaller attentions through linear projections, such that the combination of these operations forms a low-rank factorization of the original attention. This reduces the complexity to O(n) in time and space
  • Method:
    • Add two linear projection matrices Ei and Fi [n, k] when computing K & V
      • From K, V with shape [n, d] => Ei*K, Fi*V with shape [k, d]
    • Then, calculate the Scaled Dot-Product Attention as usual. The operation only requires O(n.k) time and space complexity.
      • If the projected dimension k << n, then the complexity is O(n)
  • Code:
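The projection step can be sketched as follows. This is a single-head illustration with random matrices; the projections are stored with shape [k_proj, n] here so that multiplying them with K and V yields [k_proj, d], which is an assumption about orientation made for the example. The sizes and the name `k_proj` are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(q, k, v, e_proj, f_proj):
    """q, k, v: [n, d]; e_proj, f_proj: [k_proj, n]. Returns output [n, d] and weights [n, k_proj]."""
    k_small = e_proj @ k                                   # [k_proj, d]
    v_small = f_proj @ v                                   # [k_proj, d]
    weights = softmax(q @ k_small.T / np.sqrt(q.shape[-1]), axis=-1)
    return weights @ v_small, weights

rng = np.random.default_rng(0)
n, d, k_proj = 64, 16, 8                                   # hypothetical sizes, k_proj smaller than n
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
e_proj = rng.normal(size=(k_proj, n)) / np.sqrt(n)
f_proj = rng.normal(size=(k_proj, n)) / np.sqrt(n)

out, weights = linformer_attention(q, k, v, e_proj, f_proj)
```

The attention map is [n, k_proj] rather than [n, n], so its cost grows linearly with the sequence length when `k_proj` is held fixed.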

Longformer

  • Paper: https://arxiv.org/pdf/2004.05150.pdf
  • Longformer scales linearly, O(n), with sequence length for long-sequence processing in NLP, but it can also be applied to video processing
  • Method:
    • Sliding Window:
      • With an arbitrary window size w, each token in the sequence will only attend to some w tokens (mostly w/2 on each side) => the computation complexity is O(n.w)
      • With l layers of the transformer, the receptive field of the sliding window attention is [l x w]
    • Dilated Sliding Window:
      • To further increase the receptive field without increasing computation, the sliding window can be “dilated”, similar to the dilated CNNs
      • With the number of gaps between each token in the window d, the dilated attention has the dilation size of d.
      • Then, the receptive field of dilated sliding window attention is [l x d x w]
    • Global Attention (full self-attention):
      • The windowed and dilated attention are not flexible enough to learn task-specific representation
      • Global attention is added to a few pre-selected input locations to tackle the problem.
    • Linear Projection for Global Attention:
      • 2 separate sets {Qs, Ks, Vs} and {Qg, Kg, Vg} are used for sliding window and global attention, respectively
      • This provides flexibility to model the different types of attention patterns
  • Code: https://github.com/allenai/longformer
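The sliding-window and global patterns can be visualized as a boolean attention mask. This is only a sketch of the pattern: a real implementation avoids materializing the full [n, n] mask, and the sizes below are arbitrary.

```python
import numpy as np

def longformer_mask(n, window, global_idx=()):
    """Boolean [n, n] mask: True where query i may attend to key j."""
    idx = np.arange(n)
    # Sliding window: each token sees window // 2 neighbors on each side, plus itself.
    mask = np.abs(idx[:, None] - idx[None, :]) <= window // 2
    for g in global_idx:
        mask[g, :] = True     # the global token attends to every position
        mask[:, g] = True     # every position attends to the global token
    return mask

n, window = 16, 4                                       # hypothetical sizes
local_mask = longformer_mask(n, window)
mixed_mask = longformer_mask(n, window, global_idx=(0,))

allowed_local = int(local_mask.sum())                   # grows like n * (window + 1)
allowed_full = n * n                                    # full self-attention
```

With a fixed window, the number of allowed query-key pairs is bounded by n * (window + 1), which is the O(n.w) behavior described above.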

Discussion

  • I wonder whether applying the row/column multiplication methods (read here for more details) might reduce the computational complexity of matrix multiplication?
    • If A and B are [N, N] matrices, then the normal matrix multiplication has O(N^3) complexity
    • However, I believe that with the row/column multiplication, the computational complexity might reduce to O(N^2):
      • Just a thought, maybe I'm wrong. Need to verify
    • Then again, the multiplication inside the Scaled Dot-Product Attention is between two embeddings [c, N], which has the complexity O(N^2), so I do not think we can reduce the computational complexity with this simple row/column multiplication.
  • Another option is applying FFT (Fast Fourier Transform) to reduce the computation time
    • In fact, for structured products such as convolutions, it reduces the complexity from O(N^2) down to O(N.logN)
    • But does it generalize well to the input embeddings of the Scaled Dot-Product Attention? Of course, the input embeddings have to be normalized beforehand, but what if we want to work with different input shapes, e.g. high-resolution images?
  • Furthermore, there are more techniques to reduce Scaled Dot-Product Attention complexity that are presented in detail in the sections below, including:
    • Matmul(K.T, V) instead of Matmul(Q, K.T): FANet
    • Subsample Q, K, or V: SegFormer, UTNet
    • Multiply the horizontal and vertical separately, then combine later: AnchorDETR, SOTR
    • More techniques with their corresponding complexities can be found here
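One way to check the first question in this discussion is to count scalar multiplications directly. The sketch below assumes that "row/column multiplication" means writing A*B as a sum of column-times-row outer products; the linked explanation may describe something different. Under that assumption the count stays at N^3, which suggests that this reformulation alone does not change the complexity.

```python
import numpy as np

def matmul_outer_products(a, b):
    """Compute a @ b as a sum of (column of a) x (row of b) outer products, counting multiplications."""
    rows, inner = a.shape
    cols = b.shape[1]
    out = np.zeros((rows, cols))
    multiplications = 0
    for i in range(inner):
        out += np.outer(a[:, i], b[i, :])    # one rank-1 update of size rows x cols
        multiplications += rows * cols
    return out, multiplications

rng = np.random.default_rng(0)
size = 8                                     # hypothetical N
a = rng.normal(size=(size, size))
b = rng.normal(size=(size, size))

product, count = matmul_outer_products(a, b)
```

There are N outer products and each one costs N^2 multiplications, so the total is the same as the usual row-times-column formulation.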

Classification Transformer

This section introduces transformer-based models for image classification and its sub-tasks (image pairing or multi-label classification). Of course, the list of reviewed papers will be updated frequently.

Instance-level Image Retrieval using Reranking Transformers

  • Paper: https://arxiv.org/pdf/2103.12236.pdf
  • Reranking Transformers (RRTs): lightweight and effective models that learn to predict the similarity of an image pair directly from global & local descriptors
  • Pipeline:
    • 2 Images X, X' => Global/Local feature description => preparation => RRTs => z[cls] => Binary Classifier => Do X and X' represent the same object/scene?
    • Preparation:
      • Attach 2 special tokens at the head of X and X':
        • [CLS]: summarizes the signal from both images
        • [SEP]: an extra separator token (distinguishes X and X')
      • Positional encoding
    • Global/local representation: ResNet50 backbone; extra linear projector to reduce global descriptor dimension; L2 normalization to unit norm
    • RRTs:
      • Same as the standard transformer layer: Skip[Input => MHSA] => Norm => MLP => Norm => ReLU
      • 6 layers with 4 MHSA heads
    • Binary Classifier:
      • Feature vector Z[Cls] from the last transformer layer as input
      • 1 Linear layer with sigmoid
      • Training with binary cross entropy
  • Code: https://github.com/uvavision/RerankingTransformer

General Multi-label Image Classification with Transformers

  • Paper: https://arxiv.org/pdf/2011.14027.pdf
  • C-Tran (Classification Transformer): Transformer-based model for multi-label image classification that exploits dependencies among a target set of labels
    • Training: Image & Randomly Masked Labels => C-Tran => predict masked labels
      • Learn to reconstruct a partial set of labels given randomly masked input label embeddings
    • Inference: Image & Mask Everything => C-Tran => predict all labels
      • Predict a set of target labels by masking all the input labels as unknown
  • Architecture:
    • Image => ResNet 101 => feature embedding Z
    • Image => label embedding L (represents the l possible labels) => add state embedding s (with 3 states: unknown U, negative N, positive P)
    • Z & L + s => C-Tran => L' (predicted label embedding) => FFN => Y_hat (predicted labels with probabilities)
      • C-Tran:
        • Skip[MHSA => Norm] => Linear => ReLU => Linear
        • 3 transformer layers with 4 MHSA heads
      • FFN: label inference classifier with single linear layer & sigmoid activation
  • Label Mask Training:
    • During training:
      • Randomly mask a certain number of labels
        • Given L possible labels => number of "unknown" (masked) labels 0.25L <= n <= L
      • Using the ground truth of the other labels (via state embedding) => predict the masked labels (with cross entropy loss)
  • Code: https://github.com/QData/C-Tran/
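The label mask training step can be sketched as a small sampling routine. The integer codes for the three states and the toy label vector are assumptions for illustration; the actual model maps each state to a learned state embedding that is added to the label embedding.

```python
import numpy as np

UNKNOWN, NEGATIVE, POSITIVE = 0, 1, 2        # hypothetical integer codes for the 3 states

def mask_labels(labels, rng):
    """labels: [L] array of 0/1 ground truth. Returns the state per label and the masked indices."""
    num_labels = labels.shape[0]
    low = int(np.ceil(0.25 * num_labels))
    n_unknown = rng.integers(low, num_labels + 1)          # 0.25L <= n <= L (upper bound exclusive)
    unknown_idx = rng.choice(num_labels, size=n_unknown, replace=False)
    states = np.where(labels == 1, POSITIVE, NEGATIVE)     # known labels expose their ground truth
    states[unknown_idx] = UNKNOWN                          # masked labels must be predicted
    return states, unknown_idx

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=20)                       # toy ground truth for 20 labels

train_states, masked = mask_labels(labels, rng)            # training: partial mask
inference_states = np.full(labels.shape, UNKNOWN)          # inference: mask everything
```

During training the loss would be computed only on the `masked` positions, while at inference every label is treated as unknown.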

Object Detection/Segmentation Transformer

This section introduces several attention-based architectures for object detection (DETR, AnchorDETR,...) and segmentation tasks (MaskFormer, TransUNet...), and even for a specific task such as hand detection (HandsFormer). Most of these architectures combine a Transformer with a CNN backbone for these sophisticated tasks, but a few are based solely on the Transformer architecture (FTN).

DETR (Detection Transformer)

  • Paper: https://arxiv.org/pdf/2005.12872.pdf
  • Transformer Encoder & Decoder:
    • Share the same architecture with the original transformer
    • Encoder:
      • Input sequence: flattened 2D feature (Image => CNN => flatten) + fixed positional encoding (added at each layer)
      • Output: encoder output in sequence
    • Decoder:
      • Input: Object queries (learned positional embeddings) + encoder output (input in the middle)
      • Output: output embeddings
      • The Decoder decodes N objects in parallel at each decoder layer, rather than sequentially one element at a time
      • The model can reason about all objects together using pair-wise relations between them, while being able to use the whole image as context
  • Prediction FFN:
    • 3-layer MLP with ReLU
    • Output embeddings as input
    • Predict normalized center coordinates, height and width of the bounding box
  • Architecture: Image => Backbone (CNN) => 2D representation => Flatten (+ Positional encoding) => Transformer Encoder-Decoder => Prediction FFN => bounding box
  • Code: https://github.com/facebookresearch/detr

AnchorDETR

  • Paper: https://arxiv.org/pdf/2109.07107.pdf
  • Backbone: ResNet-50
  • The encoder/decoder share the same structure as DETR
    • However, the self-attention in the encoder and decoder blocks are replaced by Row-Column Decouple Attention
  • Row-Column Decouple Attention:
    • Helps reduce GPU memory usage when dealing with high-resolution features
    • Main idea:
      • Decouple key feature Kf into row feature Kf,x and column feature Kf,y by 1D global average pooling
      • Then perform the row attention and column attention separately
  • Code: https://github.com/megvii-research/AnchorDETR
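The decoupling idea can be sketched in NumPy as below. This is a simplified single-head reading of the notes above, not the reference implementation: there are no learned projections, the separate row and column queries `qx` and `qy` are an assumption, and so is the choice of which pooled feature is called the row feature.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def row_column_attention(qx, qy, key_feat, value):
    """qx, qy: [Nq, C]; key_feat, value: [H, W, C]. Returns output [Nq, C] and the two 1D maps."""
    c = key_feat.shape[-1]
    kx = key_feat.mean(axis=0)                       # pooled over height -> [W, C]
    ky = key_feat.mean(axis=1)                       # pooled over width  -> [H, C]
    ax = softmax(qx @ kx.T / np.sqrt(c))             # attention along the width axis  [Nq, W]
    ay = softmax(qy @ ky.T / np.sqrt(c))             # attention along the height axis [Nq, H]
    z = np.einsum("qw,hwc->qhc", ax, value)          # weighted sum over width
    out = np.einsum("qh,qhc->qc", ay, z)             # weighted sum over height
    return out, ax, ay

rng = np.random.default_rng(0)
num_queries, height, width, channels = 5, 6, 7, 8    # hypothetical sizes
qx = rng.normal(size=(num_queries, channels))
qy = rng.normal(size=(num_queries, channels))
key_feat = rng.normal(size=(height, width, channels))
value = rng.normal(size=(height, width, channels))

out, ax, ay = row_column_attention(qx, qy, key_feat, value)

# Reference: the same result using the full [Nq, H, W] attention map that the decoupling avoids.
full_map = ay[:, :, None] * ax[:, None, :]
reference = np.einsum("qhw,hwc->qc", full_map, value)
```

The two 1D maps need Nq x (H + W) entries instead of Nq x H x W, which is where the memory saving for high-resolution features comes from.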

MaskFormer

  • Paper: https://arxiv.org/pdf/2107.06278.pdf
  • Pixel-level module:
    • Image => Backbone (ResNet) => image feature F => Pixel decoder (upsampling) => per-pixel embedding E_pixel
  • Transformer module (Decoder only):
    • Standard Transformer decoder
    • N queries (learnable positional embeddings) + F (input in the middle) => Transformer Decoder => N per-segment embeddings Q
    • Prediction in parallel (similar to DETR)
  • Segmentation module:
    • Q => MLP (2 Linears + softmax) => N mask embeddings E_mask & N class predictions
    • Dot_product(E_mask, E_pixel) => sigmoid => Binary mask predictions
    • Matrix_mul(Mask predictions, class predictions) => Semantic segmentation
  • Code: https://github.com/facebookresearch/MaskFormer
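The segmentation module above reduces to two tensor contractions, sketched here with random inputs. All sizes are placeholders, and for simplicity the class predictions cover only the K real classes (handling of a "no object" class is left out).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
num_queries, num_classes, embed_dim, height, width = 10, 4, 16, 8, 8   # hypothetical sizes

e_mask = rng.normal(size=(num_queries, embed_dim))          # N mask embeddings
e_pixel = rng.normal(size=(embed_dim, height, width))       # per-pixel embeddings
class_probs = softmax(rng.normal(size=(num_queries, num_classes)))   # N class predictions

# Dot product between each mask embedding and every pixel embedding, then sigmoid.
mask_preds = sigmoid(np.einsum("nc,chw->nhw", e_mask, e_pixel))      # [N, H, W]

# Combine the N binary masks with their class predictions.
semantic_scores = np.einsum("nk,nhw->khw", class_probs, mask_preds)  # [K, H, W]
segmentation = semantic_scores.argmax(axis=0)                        # [H, W] class index per pixel
```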

SegFormer

  • Paper: https://arxiv.org/pdf/2105.15203.pdf
  • Input: Image [H, W, 3] => patches of size 4x4 (rather than 16x16 like ViT)
    • Using smaller patches => favor the dense prediction task
    • Do not need positional encoding (PE):
      • Not necessary for semantic segmentation
      • The resolution of PE is fixed => needs to be interpolated when facing different test resolutions => dropped accuracy
  • Hierarchical Transformer Encoder: extract coarse and fine-grained features, partly inspired by ViT but optimized for semantic segmentation
    • Overlap patch embeddings => [Transformer Block => Downsampling] x 4 times => CNN-like multi-level feature map Fi
      • Feature map size: [H, W, 3] => F1 [H/4, W/4, C1] => F2 [H/8, W/8, C2] => ... => F4 [H/32, W/32, C4]
        • Provide both high and low-resolution features => boost the performance of semantic segmentation
      • Transformer Block1: Efficient Self-Attn => Mix-FFN => Overlap Patch Merging
        • Efficient Self-Attention: Use the reduction ratio R on sequence length N = H x W (in particular, apply stride R) => reduce the complexity by R times
        • Mix-FFN: Skip[MLP => Conv3x3 => GELU => MLP] which considers the effect of zero padding to leak location information (rather than positional encoding)
        • Overlapped Patch Merging: similar to the image patch in ViT but overlap => combine feature patches
  • Lightweight All-MLP Decoder: fuse the multi-level features => predict semantic segmentation mask
    1. 1st Linear layer: unifying channel dimension of multi-level features Fi (from the encoder)
    2. Fi are upsampled to 1/4 resolution and concatenated together
    3. 2nd Linear layer: fusing concatenated features F
    4. 3rd Linear layer: predicting segmentation mask M [H/4, W/4, N_cls] with F
  • Code: https://github.com/lucidrains/segformer-pytorch
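The sequence reduction behind Efficient Self-Attention can be sketched as a reshape followed by a linear layer. This is a simplified single-head version under stated assumptions: the reduced sequence is used directly as both keys and values, without separate K/V projections, and the sizes and random weights are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def reduce_sequence(x, ratio, weight):
    """x: [N, C] -> reshape to [N / R, C * R] -> linear back to C channels -> [N / R, C]."""
    n, c = x.shape
    assert n % ratio == 0, "sequence length must be divisible by the reduction ratio"
    return x.reshape(n // ratio, c * ratio) @ weight

def efficient_self_attention(x, ratio, weight):
    """Queries keep full length N; keys and values use the reduced length N / R."""
    reduced = reduce_sequence(x, ratio, weight)                       # [N / R, C]
    weights = softmax(x @ reduced.T / np.sqrt(x.shape[-1]), axis=-1)  # [N, N / R]
    return weights @ reduced, weights

rng = np.random.default_rng(0)
seq_len, channels, ratio = 64, 16, 4                 # hypothetical sizes (e.g. an 8 x 8 feature map)
x = rng.normal(size=(seq_len, channels))
weight = rng.normal(size=(channels * ratio, channels)) * 0.1

out, weights = efficient_self_attention(x, ratio, weight)
```

The attention map shrinks from [N, N] to [N, N / R], which is the reduction by a factor of R mentioned above.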

Segmenter

  • Paper: https://arxiv.org/pdf/2105.05633.pdf
  • Architecture:
    • Image patches => Flatten & project => + Positional Encoding => Transformer Encoder => [Patch encoding + class embeddings] => Decoder (Mask Transformer) => scalar product => upsample & argmax => Segmentation Map
    • Encoder:
      • Input z0 + PE => Skip[Norm => MHSA] => Skip[Norm => 2 layers MLP] => Patch encoding Zl
    • Decoder:
      • Input:
        • Patch encoding Zl
        • K learnable class embeddings cls = [cls1,...cls_k], where K is the number of classes. Each class embedding is initialized randomly & assigned to a single semantic class => generate the class mask
      • Mask Transformer:
        • Zl & cls => Transformer encoder => Z'm & c
        • Masks(Z'm, c) = L2_norm(Z'm).c^T => reshaped into 2D S_mask => bilinearly upsample to the original size [H, W, K] => softmax => final segmentation mask
  • Code: https://github.com/rstrudel/segmenter
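The final mask computation can be sketched as below, following the formula in the notes (L2-normalized patch encodings multiplied by the class embeddings). Nearest-neighbor upsampling via `np.repeat` is swapped in for bilinear upsampling to keep the example short, and all sizes and inputs are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def segmenter_masks(patch_emb, class_emb, grid_hw, patch):
    """patch_emb: [N, D]; class_emb: [K, D]. Returns per-pixel class probabilities [H, W, K]."""
    normed = patch_emb / np.linalg.norm(patch_emb, axis=-1, keepdims=True)
    scores = normed @ class_emb.T                          # scalar product -> [N, K]
    h, w = grid_hw
    masks = scores.reshape(h, w, -1)                       # 2D mask per class
    # Nearest-neighbor upsampling as a stand-in for the bilinear upsampling in the paper.
    masks = np.repeat(np.repeat(masks, patch, axis=0), patch, axis=1)
    return softmax(masks, axis=-1)

rng = np.random.default_rng(0)
grid_hw, patch, dim, num_classes = (4, 4), 16, 32, 5       # hypothetical sizes
patch_emb = rng.normal(size=(grid_hw[0] * grid_hw[1], dim))
class_emb = rng.normal(size=(num_classes, dim))

probs = segmenter_masks(patch_emb, class_emb, grid_hw, patch)
segmentation_map = probs.argmax(axis=-1)                   # [H, W] class index per pixel
```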

Fully Transformer Networks

  • Paper: https://arxiv.org/pdf/2106.04108.pdf
  • Fully Transformer Networks for semantic image segmentation, without relying on CNN.
    • Both the encoder and decoder are composed of multiple transformer modules
    • Pyramid Group Transformers (PGT) encoder to divide feature maps into multiple spatial groups => compute the representation for each
      • Capable of handling spatial details or local structures like a CNN
      • Reduce unaffordable computational & memory cost of the standard ViT; reduce feature resolution and increase the receptive field for extracting hierarchical features
    • Feature Pyramid Transformer (FPT) decoder => fuse semantic-level & spatial level information from PGT encoders => high-resolution, high-level semantic output
  • Architecture:
    • Image => Patch => PGT Encoder => FPT Decoder => linear layer => bilinear upsampling => probability map => argmax(prob_map) => Segmentation
    • PGT: four hierarchical stages that generate features at multiple scales, each including Patch Transform (non-overlapping) + PGT Block to extract hierarchical representations
      • PGT Block: Skip[Norm => PG-MSA] => Skip[Norm => MLP]
      • PG-MSA (Pyramid-group transformer block): Head_ij = Attention(Qij, Kij, Vij) => hi = reshape(Head_ij) => PG-MSA = Concat(hi)
    • FPT: aggregate the information from multiple levels of PGT encoder => generate finer semantic image segmentation
      • A larger FPT is not necessarily better for segmentation (with limited segmentation training data) => its scale is determined by depth, embedding dim, and the reduction ratio of SR-MSA
      • SR-MSA (Spatial-reduction transformer block): reduce memory and computation cost by spatially reducing the number of Key & Value tokens, especially for high-resolution representations
      • The multi-level high-resolution feature of each branch => fusing (element-wise summation/channel-wise concatenation) => finer prediction
  • Code: To be updated

TransUNet

  • Paper: https://arxiv.org/pdf/2102.04306.pdf
  • Downsampling (Encoder): using CNN-Transformer Hybrid
    • (Medical) Image [H, W, C] => CNN => 2D feature map => Linear Projection (Flatten into 2D Patch embedding) => Downsampling => Transformer => Hidden feature [n_patch, D]
      • CNN: downsampling by 1/2 => 1/4 => 1/8
      • Transformer: Norm layer before MHSA/FFN (rather than applying Norm layer after MHSA/FFN like the original Transformer), total 12 layers
    • Why using CNN-Transformer hybrid:
      • Leverages the intermediate high-resolution CNN feature maps in the Decoder
      • Performs better than the pure Transformer
  • Upsampling (Decoder): using Cascaded Upsampler
    • Similar to the upsampling part of the standard UNet
      • Upsampling => concat with the corresponding CNN feature map (from the Encoder) => Conv3x3 with ReLU
      • Segmentation head (Conv1x1) at the final layer
    • Hidden Feature [n_patch, D] => reshape [D, H/16, W/16] => [512, H/16, W/16] => [256, H/8, W/8] => [128, H/4, W/4] => [64, H/2, W/2] => [16, H, W] => Segmentation head => Semantic Segmentation
  • Code: https://github.com/KenzaB27/TransUnet

UTNet (U-shape Transformer Networks)

  • Paper: https://arxiv.org/pdf/2107.00781.pdf
  • Pipeline:
    • Apply conv layers to extract local intensity feature, while using self-attention to capture long-range associative information
    • UTNet follows the standard design of UNet, but replace the last conv of the building block in each resolution (except the highest one) with the proposed Transformer module
    • Rather than using the conventional MHSA like the standard Transformer, UTNet develops the Efficient MHSA (quite similar to the one in SegFormer):
      • Efficient MHSA: Sub-sample K and V into low-dimensional embedding (reduce size by 8) using Conv1x1 => bilinear interpolation
    • Using 2-dimensional relative position encoding by adding relative height and width information rather than the standard position encoding
  • Code: https://github.com/yhygao/UTNet

SOTR (Segmenting Objects with Transformer)

  • Paper: https://arxiv.org/pdf/2108.06747.pdf
  • Combines the advantages of CNN and Transformer
  • Architecture:
    • Pipeline:
      • Image => CNN Backbone => feature maps in multi-scale => patch recombination + positional embedding => clip-level feature sequences/blocks => Transformer => global-level semantic feature => functional heads => class & conv kernel prediction
      • Backbone output => Multi-level upsampling model (with dynamic conv) => dynamic conv(output, Kernel head) => instance masks
    • CNN Backbone: Feature pyramid network
    • Transformer: proposed 2 different transformer designs with Twin attention:
      • Twin attention: simplify the attention matrix with a sparse representation (as standard self-attention has quadratic time and memory complexity, which leads to a high computational cost)
        • Calculate attention within each column (independent between columns) => calculate attention within each row (independent between rows) => connect together
        • Twin att. has a global receptive field & covers the information along 2 dimensions
      • Transformer layer:
        • Pure twin layer: Skip[Norm => Twin Att.] => Skip[Norm => MLP]
        • Hybrid twin layer: Skip[Norm => Twin Att.] => Skip[Conv3x3 => Leaky ReLU => Conv3x3]
        • Hybrid Twin comes with the best performance
    • Multi-level upsampling model: P5 feature map + Positional from transformer + P2-P4 from FPN => [Conv3x3 => Group Norm => ReLU, multi stage] => upsample x2, x4, x8 (for P3-P5) => added together => point-wise conv => upsampling => final HxW feature map
  • Code: https://github.com/easton-cau/SOTR

HandsFormer

  • Paper: https://arxiv.org/pdf/2104.14639.pdf
  • Architecture:
    • Image => UNet => Image features (from layers of UNet decoder) + Keypoint heatmap => bilinear interpolation & concat => FFN (3-layer MLP) => concat with keypoint heatmap (with positional encoding) => keypoint representation (likely to correspond to the 2D location of hand joints)
      • Localizing the joints of hands in 2D is more accurate than directly regressing 3D location.
      • The 2D keypoints are a very good starting point to predict an accurate 3D pose for both hands
    • Keypoint representation => Transformer encoder => FFN (2-layer MLP + linear projection) => Keypoint Identity predictor [2 FC layers => linear projection => Softmax] => 2D pose
    • [Joint queries, Transformer encoder output] => Transformer decoder => 2-layer MLP + linear projection => 3D pose
  • Code: To be updated

Unifying Global-Local Representations in Salient Object Detection with Transformer

  • Paper: https://arxiv.org/pdf/2108.02759.pdf
  • Jointly learn global and local features in a layer-wise manner => solving the salient object detection task with the help of Transformer
    • With the self-attention mechanism => the Transformer is able to model the "contrast" => demonstrated to be crucial for saliency perception
  • Architecture:
    • Encoder:
      • Input => split to grid of fixed-size patches => linear projection => feature vector (represent local details) + positional encoding => Encoder => encode global features without diluting the local ones
      • The Encoder includes 17 layers of standard transformer encoder.
    • Decoder:
      • Decode the features with global-local information over the inputs and the previous encoder layers by dense decoding => preserve rich local & global features.
      • The dense decoder contains various types of decoder blocks, including:
        • Naive Decoder: directly upsampling the outputs of last layer => same resolution of inputs => generating the saliency map. Specifically, 3 Conv-norm-ReLU are applied => bilinearly upsample x16 => sigmoid
        • Stage-by-Stage Decoder: upsamples the resolution x2 in each stage => mitigate the loss of spatial details. Specifically, 4 stage x [3 Conv-norm-ReLU => sigmoid]
        • Multi-level Feature Decoder: sparsely fuse multi-level features, similar to the pyramid network. Specifically take feature F3, F6, F9, F12 (from the corresponding layers of encoder) => upsample x4 => several conv layers => fused => saliency maps
      • Dense Decoder: integrate all encoder layer features => upsample to the same spatial resolution of input (include pixel shuffle & bilinear upsampling x2) => concat => salient feature => Conv => sigmoid => saliency map
  • Code: https://github.com/OliverRensu/GLSTR
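
The pixel shuffle step used for upsampling in the decoder can be illustrated with a small NumPy sketch. It assumes the common channel layout in which output pixel (h·r + i, w·r + j) of channel c is read from input channel c·r² + i·r + j; whether GLSTR uses exactly this layout is an assumption, and `pixel_shuffle` is an illustrative name.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Rearrange a (C*r*r, H, W) feature map into (C, H*r, W*r).

    Spatial resolution grows by r in each dimension while the channel
    count shrinks by r*r, so no values are created or discarded.
    """
    c_in, h, w = x.shape
    assert c_in % (r * r) == 0, "channels must be divisible by r*r"
    c = c_in // (r * r)
    x = x.reshape(c, r, r, h, w)       # split channels into (c, i, j)
    x = x.transpose(0, 3, 1, 4, 2)     # reorder to (c, h, i, w, j)
    return x.reshape(c, h * r, w * r)  # merge (h, i) and (w, j)
```

For example, a (4, 2, 2) input with r = 2 becomes a single (4, 4) map in which each 2x2 output block gathers the four input channels at one spatial location.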

Real-time Semantic Segmentation with Fast Attention (FANet)

  • Paper: https://arxiv.org/pdf/2007.03815v2.pdf
  • FA Module:
    • FA module for non-local context aggregation for efficient segmentation.
    • Fast Attention = 1/n*L2_norm(Q)*[L2_norm(K^T)*V]
      • n = height x width
      • L2_norm to ensure the affinity is between [-1, 1]
      • Computational complexity O(n.c^2)
    • n/c times more efficient than the standard self-attention, given n>>c
  • Fast Attention Network (FANet) for real-time semantic segmentation:
    • Encoder:
      • Extract features from image input at different semantic levels
      • Lightweight backbone ResNet-18/34 without last FC layer
        • The first Res-block "Res-1" produces feature map of [h/4, w/4]
        • The other subsequent blocks output feature maps with resolution downsampled by a factor of 2
      • Apply Context Aggregation (FA module) at each stage
      • Down-sampling Res-4 and Res4 for spatial reduction
    • Context Aggregation:
      • Basically, the FA module, composed of 3 Conv1x1 layers (without ReLU) for embedding input feature to be {Q, K, V}
    • Decoder:
      • Merges and upsamples the features (outputs of FA module) in sequence from deep feature maps to shallow ones
      • Skip connection to connect middle features => enhanced decoded feature with high-level content
      • Output with [h/4, w/4]
  • FA module extending for Video Semantic Segmentation:
    • Extend FA module to spatial-temporal contexts => improves video semantic segmentation without increasing computational cost
    • At frame t, the FA module can be calculated (through several transform steps) by:
      • FAt = 1/n*L2_norm(Qt)*[F(Kt, Vt) + sum(F(Kt-i, Vt-i))]
        • Where F(K, V) = L2_norm(K^T)*V
      • However, the F(Kt-i, Vt-i) have already been computed in the prior steps and can be simply reused again for this step
      • Therefore, the spatial-temporal fast attention still maintains the complexity of O(n.c^2), which is not only fast but also free of t
    • Then, the normal FA module can be replaced by this spatial-temporal version for video semantic segmentation without increasing computational cost
  • Code: https://github.com/feinanshan/FANet
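
The complexity argument for fast attention comes from the order of the matrix products, which the following NumPy sketch makes concrete. It is an illustration under assumptions rather than the FANet code: features are flattened to shape (n, c), L2 normalization is applied per pixel along the channel axis, and all function names are hypothetical.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def fast_attention(q, k, v):
    """q, k, v: (n, c). Computes K^T V first, so the cost is O(n * c^2)."""
    n = q.shape[0]
    context = l2_normalize(k).T @ v            # (c, c)
    return l2_normalize(q) @ context / n       # (n, c)

def quadratic_attention(q, k, v):
    """Same result, but builds the (n, n) affinity first: O(n^2 * c)."""
    n = q.shape[0]
    affinity = l2_normalize(q) @ l2_normalize(k).T   # cosine values in [-1, 1]
    return affinity @ v / n

def temporal_fast_attention(q_t, k_t, v_t, cached_contexts):
    """Spatial-temporal variant: reuse the (c, c) contexts of earlier frames.

    cached_contexts is a list of F(K, V) matrices from previous frames.
    Returns the output and the current context so the caller can cache it.
    """
    n = q_t.shape[0]
    context_t = l2_normalize(k_t).T @ v_t
    total = context_t + sum(cached_contexts)
    return l2_normalize(q_t) @ total / n, context_t
```

Because there is no softmax between the two products, matrix multiplication is associative here and both orderings give the same output; only the cost differs. The temporal variant adds cached (c, c) matrices, so its per-frame cost does not depend on n beyond the single-frame computation.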

Few-shot Transformer

This section introduces transformer-based architectures for few-shot learning, mainly but not strictly for object detection and segmentation. Overall, these pipeline architectures are quite complex, so I recommend reading the papers along with these reviews for better understanding. Then again, this list will be updated frequently.

Meta-DETR: Image-Level Few-Shot Object Detection with Inter-Class Correlation Exploitation

  • Paper: https://arxiv.org/pdf/2103.11731.pdf
  • Employ Deformable DETR and original Transformers as basic detection framework
  • Architecture:
    • Query image, Support images => Feature Extractor(ResNet-101) => Positional Encoding => Query/Support features => Correlational Aggregation Module => Support-Aggregated Query Feature => Transformer Encoder-Decoder => Prediction Head => Few-shot detection
  • Correlational Aggregation Module (CAM):
    • Key component in Meta-DETR => aggregates query features with support classes => class-agnostic prediction
      • Can aggregate multiple support classes simultaneously => capture inter-class correlations => reduce misclassification, enhance generalization
    • Pipeline:
      • Query & Support features => MHSA => ROIAlign + Average pooling (on the support feature only) => Query feature map Q & Support prototypes S
      • S = Concat(S, BG-Prototype); Task Encodings T = Concat(T, BG-Encoding)
      • {Q, S, T} => Feature matching & Encoding matching (in parallel) => FFN => Support-Aggregated Query Features
    • Feature matching:
      • {Q, Sigmoid(S), S} => Single-Head Attention => Element-wise multiplication => feature matching output Qf
    • Encoding Matching:
      • {T, Q, S} => Single-Head Attention => encoding output Qe
    • FFN: Element-wise Add(Qf, Qe) => FFN => Support-Aggregated Query features
  • Transformer Encoder-Decoder:
    • Follow the Deformable DETR architecture with 6 layers
      • Adopt multi-scale deformable attention, with the CAM counted as one encoder layer
    • Support-Aggregated Query features => Encoder-Decoder => Embedding E
  • Prediction Head:
    • E => FC, Sigmoid => Confidence score
    • E => FC, ReLU + FC, ReLU + FC, Sigmoid => Bounding Box
  • Code: https://github.com/ZhangGongjie/Meta-DETR
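
The two branches of the prediction head listed above can be sketched in NumPy as follows. This is a rough illustration with randomly initialised weights: the hidden width (equal to the embedding width), the four-value box output, and the names `init_head` and `prediction_head` are assumptions made for the example, not details taken from the Meta-DETR code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(x, 0.0)

def init_head(d, rng, scale=0.1):
    """Random (weight, bias) pairs for one FC score layer and three FC box layers."""
    shapes = {"cls": (d, 1), "box1": (d, d), "box2": (d, d), "box3": (d, 4)}
    return {name: (rng.normal(scale=scale, size=s), np.zeros(s[1]))
            for name, s in shapes.items()}

def prediction_head(e, params):
    """e: (N, d) decoder embeddings E. Returns scores (N,) and boxes (N, 4)."""
    # Confidence branch: FC => Sigmoid
    w, b = params["cls"]
    score = sigmoid(e @ w + b)[:, 0]
    # Box branch: FC, ReLU => FC, ReLU => FC, Sigmoid
    h = e
    for name in ("box1", "box2"):
        w, b = params[name]
        h = relu(h @ w + b)
    w, b = params["box3"]
    box = sigmoid(h @ w + b)   # four values in [0, 1] per embedding
    return score, box
```

The final sigmoid keeps the box values in a normalized range, so they can be read as coordinates relative to the image size.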

Boosting Few-shot Semantic Segmentation with Transformers

  • Paper: https://arxiv.org/pdf/2108.02266.pdf
  • Architecture:
    • Support/Query Images => Backbone (VGG/ResNet) => PANet => PFENet => X
      • PANet: Compute similarity between query features and support prototypes
      • PFENet: Concat(Query features; Expanded Support Prototypes; Prior Mask)
    • X => Multi-scale Processing => Global/Local Enhancement Module (with 2 outputs T and Z respectively) => Concat(T, Z) => MLP => Few-shot Semantic Segmentation
  • Multi-scale Processing:
    • Information over different scales (of the input feature maps X from support/query images) can be utilized
    • Multi-scaling with Global Average Pooling => feature Pyramid Xi = {X1, X2,...,Xn}
  • Global Enhancement Module (GEM):
    • Using Transformer => enhance the feature to exploit the global information
    • Xi => FC layer (for channel reduction) => X'i => Feature Merging Unit (FMU) => Yi
      • FMU:
        • Yi = X'i if i=1;
        • Yi = (Conv1x1(Concat(X'i, Ti-1)) + X'i) if i>1
    • Yi => [MHSA => MLP (2 Linear)] => MHSA => Ti => T
      • MHSA with GELU and Norm
      • [MHSA => MLP] repeat L times with L = 3
      • T = Concat(T1, T2,...Tn) at layer L
  • Local Enhancement Module (LEM):
    • Follow the same pipeline as GEM:
      • Xi => FC layer => FMU => Yi => Conv => output
    • Rather than using a Transformer as in GEM, LEM uses convolution => encode the local information
    • LEM output: {Z1, Z2,...,Zn}
  • Segmentation Mask Prediction:
    • Local/Global output: Z = Concat(Z1, Z2,...,Zn); T = Concat(T1, T2,...,Tn)
    • Concat(T, Z) => MLP => target mask M
  • Code: https://github.com/GuoleiSun/TRFS
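
The Feature Merging Unit rule (Yi = X'i for i = 1, otherwise Conv1x1(Concat(X'i, Ti-1)) + X'i) can be written out as a short NumPy sketch. Two points are assumptions of this example and not stated in the notes above: Ti-1 is brought to the scale of X'i by nearest-neighbour resizing, and the 1x1 convolution has no bias. All function names are illustrative.

```python
import numpy as np

def conv1x1(x, weight):
    """x: (C_in, H, W); weight: (C_out, C_in). A 1x1 conv is a per-pixel matmul."""
    return np.einsum("oc,chw->ohw", weight, x)

def resize_nearest(x, h, w):
    """Nearest-neighbour resize of (C, H0, W0) to (C, h, w) (assumed resizing rule)."""
    rows = np.arange(h) * x.shape[1] // h
    cols = np.arange(w) * x.shape[2] // w
    return x[:, rows][:, :, cols]

def feature_merging_unit(x_i, t_prev=None, weight=None):
    """x_i: (C, H, W) reduced feature X'_i; t_prev: (C, H0, W0) output T_{i-1}."""
    if t_prev is None:                       # i = 1: nothing to merge yet
        return x_i
    t_prev = resize_nearest(t_prev, *x_i.shape[1:])
    merged = np.concatenate([x_i, t_prev], axis=0)   # (2C, H, W)
    return conv1x1(merged, weight) + x_i             # residual connection
```

The residual term means that a zero convolution weight returns X'i unchanged, so the unit can only add information from the previous scale on top of the current one.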

Few-Shot Segmentation via Cycle-Consistent Transformer

  • Paper: https://arxiv.org/pdf/2106.02320.pdf
  • Architecture:
    • Query & Support Images => Backbone (ResNet) => Image features Xq and Xs => Concat the mask averaged support feature to Xq and Xs => Flatten into 1D sequence by Conv1x1 (with shape HW x D) => (CyC-Transformer, L times) => reshaping => Conv-head => segmentation mask
    • Each token in Xq and Xs is the feature z at one pixel location => beneficial for segmentation
  • CyC-Transformer:
    • Self-Alignment block:
      • Just like the original Transformer encoder
      • Flattened query feature as input (query only)
      • Pixel-wise feature of query images => aggregate their global context information
    • Cross-alignment block:
      • Replace MHSA with CyC-MHA
      • Flattened query feature and a sample of the support feature as input
      • Performs attention between query and support pixel-wise features => aggregate relevant support feature into query ones
    • Cycle-Consistent Multi-head Attention (CyC-MHA):
      • Alleviate the excessive harmful support features that confuse pure pixel-level attention
      • Pipeline:
        • Affinity map A is calculated => measure the correspondence relationship between all query and support tokens
        • For a single support token position j (j={0,...Ns}), its most similar point is i* = argmax A(i,j), where i={0,...HqWq} indexes the flattened query feature
        • Construct Cycle-consistency (CyC) relationship for all tokens in the support sequence
          • The cycle-consistency helps avoid being biased by possibly harmful features (encountered when training for few-shot segmentation)
        • Finally, CyC-MHA = softmax(Ai + B)V
          • Where B (with only 2 value -inf and 0) is the additive bias element-wise added to aggregate support feature and V is the value sequence
      • With B, the attention weight of an inconsistent support token tends to zero => irrelevant information will not be considered
      • CyC encourages consistency between the most related features of query and support => produce consistent feature representation
  • Conv-head:
    • Output of CyC-MHA => reshaping to spatial dimensions => Conv-Head (Conv3x3 => ReLU => Conv1x1) => Segmentation Mask
  • Code: To be updated
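
The cycle-consistency bias B can be sketched in NumPy as below. This is a single-head illustration without learned projections. The consistency test used here (a support token j is kept when the support token j* reached by going j => i* => j* carries the same mask label as j) is how this example interprets the cycle described above, and the function name and argument layout are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cyc_attention(q, k, v, support_labels):
    """q: (Nq, d) query tokens; k, v: (Ns, d) support tokens;
    support_labels: (Ns,) foreground/background label of each support token."""
    d = q.shape[1]
    affinity = q @ k.T / np.sqrt(d)               # A: (Nq, Ns)
    i_star = affinity.argmax(axis=0)              # most similar query token per support token j
    j_star = affinity[i_star].argmax(axis=1)      # most similar support token for each i*
    consistent = support_labels[j_star] == support_labels
    bias = np.where(consistent, 0.0, -np.inf)     # B takes only the values 0 and -inf
    weights = softmax(affinity + bias, axis=-1)   # inconsistent tokens get weight 0
    return weights @ v, weights, consistent
```

The support token holding the largest entry of A always maps back to itself, so (ties aside) at least one token survives the mask and the softmax stays well defined.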

Few-shot Semantic Segmentation with Classifier Weight Transformer

  • Paper: https://arxiv.org/pdf/2108.03032.pdf
  • Architecture:
    • First stage: pre-train encoder/decoder (PSPNet pre-trained on ImageNet) with supervised learning => stronger representation
      • Support/Query Image => Encoder-Decoder (PSPNet) => Linear classifier with Support Mask (Support only) => Classifier (Support only) => [Classifier weight Q, Query feature K, Query feature V]
    • Second stage: meta-train the Classifier Weight Transformer (CWT) only (as the encoder-decoder is capable of generalizing to unseen classes)
      • {Q, K, V} => Skip[Linear => MHSA => Norm] => Conv operation with Q => Prediction Mask
  • Code: https://github.com/zhiheLu/CWT-for-FSS
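
The second stage can be sketched as the classifier weights attending over the query-image features and then being applied back to those features. The NumPy example below is a simplified reading of the pipeline above: it uses a single attention head, plain layer normalization without learned scale, and treats the final "Conv operation with Q" as a per-pixel dot product (a 1x1 convolution). The function and parameter names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def classifier_weight_transformer(w, feats, wq, wk, wv):
    """w: (n_cls, d) classifier weights learned on the support set;
    feats: (n, d) flattened query-image features;
    wq, wk: (d, d_a) and wv: (d, d) linear projections."""
    q = w @ wq                                   # classifier weights act as queries
    k = feats @ wk                               # query-image features act as keys
    v = feats @ wv                               # ... and as values
    attn = softmax(q @ k.T / np.sqrt(q.shape[1]), axis=-1)   # (n_cls, n)
    w_adapted = w + layer_norm(attn @ v)         # skip connection around attention + norm
    logits = feats @ w_adapted.T                 # per-pixel class scores, (n, n_cls)
    return w_adapted, logits
```

Taking `logits.argmax(axis=1)` and reshaping to the feature-map size would give the predicted mask; only the projections are trained in this stage, which matches the idea of keeping the encoder-decoder frozen.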

Few-shot Transformation of Common Actions into Time and Space

  • Paper: https://openaccess.thecvf.com/content/CVPR2021/papers/Yang_Few-Shot_Transformation_of_Common_Actions_Into_Time_and_Space_CVPR_2021_paper.pdf
  • The goal is to localize the spatio-temporal tubelet of an action in an untrimmed query video based on the common action in the trimmed support videos
  • Architecture:
    • Pipeline:
      • Untrimmed query video => split into clips + few support video => Video feature extractor (I3D) => spatio-temporal representation => Common Attention block => aligned with previous clip feature => query clip feature
      • Query clip feature => Few-shot Transformer (FST) => fuse the support features into the query clip feature => aggregating with input embedding
      • Top of output embedding (from FST) => Prediction network => final tubelet prediction
    • Video feature extractor:
      • Adapt I3D as backbone to obtain spatio-temporal representation of a single query video & few support videos
      • Support video => feed the whole video to the backbone directly
      • Untrimmed query video => split into multiple clips => backbone network
      • Common attention block: built on self-attention mechanism & non-local structure => model long-term spatio-temporal information
        • Formula: A^(I1,I2) = I1 + Linear(Norm[A(I1,I2)])
          • With A^ is the common attention, A is the standard attention, I1, I2 are 2 inputs
        • Common attention aligns each query clip feature with its previous clip features => contain more motion information => benefit the common action localization
    • Few-shot Transformer (FST):
      • Encoder: standard architecture with MHSA. The input is supplied with fixed spatio-temporal positional encoding.
        • Support branch: the support video => encoder one by one => concat => Es => decoder (along with Eq)
        • Query branch: enhanced query clip => encoder => Eq => decoder
        • The FFN can be a Conv1x1
      • Decoder: 3 input [Es, Eq, input embedding (learnt positional encoding)] => Common attention/MHSA => MHSA => FFN => Add&Norm => output embedding
    • Prediction network: output embedding => 3-layer FFN with ReLU => linear projection => normalized center coordinates => final action tubes for the whole untrimmed query video
  • Code: To be updated
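
The common attention formula A^(I1,I2) = I1 + Linear(Norm[A(I1,I2)]) can be written out directly. The NumPy sketch below assumes that the standard attention A takes its queries from I1 and its keys and values from I2, uses no learned projections inside A, and flattens each input to a (tokens, channels) matrix. These choices and the function names are assumptions made for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def attention(i1, i2):
    """Standard attention A(I1, I2): queries from i1, keys and values from i2."""
    scores = i1 @ i2.T / np.sqrt(i1.shape[1])    # (n1, n2)
    return softmax(scores, axis=-1) @ i2         # (n1, d)

def common_attention(i1, i2, w, b):
    """i1: (n1, d), i2: (n2, d); w: (d, d) and b: (d,) form the Linear layer."""
    return i1 + layer_norm(attention(i1, i2)) @ w + b
```

With I1 as the current query clip and I2 as the previous clip, the residual form keeps the clip's own features and adds what it has in common with its predecessor.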

A Universal Representation Transformer Layer for Few-Shot Image Classification

  • Paper: https://arxiv.org/pdf/2006.11702.pdf
  • URT layer is inspired by the standard Transformer network to effectively integrate the feature representations from the diverse set of training domains
    • Uses an attention mechanism to learn to retrieve or blend the appropriate backbones to use for each task
    • Training URT layer across few-shot tasks from many domains => support transfer across these tasks
  • Architecture:
    • ri(x) is the output vector from the backbone for domain i => r(x) = concat[r1(x),...rm(x)] on m pre-trained backbones
    • representation of Support set class r(Sc) = sum[r(x)]/|Sc|
    • For each class c, query Qc = Linear(r(Sc)) with weight Wq, bias bq; For each domain i and class c, key Kic = Linear(ri(Sc)) with Wk, bk
    • Then, with Qc, Kic, the scaled dot-product attention Aic is calculated as usual
    • The adapted representation for the support & query set examples is computed by O(x) = Sum(Ai*ri(x))
    • Finally, the multi-head URT O(X) is the concatenation of the outputs O(x) of all heads, just like usual.
  • Code: https://github.com/liulu112601/URT
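
A single URT head can be sketched in NumPy by following the steps above: class prototypes, queries and keys, attention over the m backbones, and a weighted sum of backbone outputs. The sketch assumes all backbones share the same output width d and that the per-class scores Aic are averaged over classes to obtain one weight Ai per backbone; both are assumptions of this example, as are the function and parameter names.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def urt_head(support_reps, support_labels, reps, wq, bq, wk, bk):
    """support_reps: (m, Ns, d) outputs r_i(x) of m backbones on the support set;
    support_labels: (Ns,) class label of each support example;
    reps: (m, N, d) backbone outputs of the examples to adapt;
    wq: (m*d, l), bq: (l,), wk: (d, l), bk: (l,)."""
    m, _, d = support_reps.shape
    classes = np.unique(support_labels)
    # r_i(S_c): mean representation of class c under backbone i, shape (m, n_cls, d)
    protos = np.stack(
        [support_reps[:, support_labels == c].mean(axis=1) for c in classes], axis=1)
    # r(S_c): concatenation over the m backbones, shape (n_cls, m*d)
    concat = protos.transpose(1, 0, 2).reshape(len(classes), m * d)
    q = concat @ wq + bq                               # Q_c: (n_cls, l)
    k = protos @ wk + bk                               # K_ic: (m, n_cls, l)
    scores = np.einsum("cl,mcl->cm", q, k) / np.sqrt(q.shape[-1])
    alpha = softmax(scores, axis=-1).mean(axis=0)      # A_i: (m,), averaged over classes
    adapted = np.einsum("m,mnd->nd", alpha, reps)      # O(x) = sum_i A_i * r_i(x)
    return adapted, alpha
```

Running several such heads with separate parameters and concatenating their outputs gives the multi-head version; the weights `alpha` show which backbones the task relies on.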

Resources





These notes were created by quanghuy0497@2021
