This repository summarizes Transformer-based architectures for Computer Vision, from very basic tasks (classification) to complex ones (object detection, segmentation, few-shot learning).
The main purpose of this list is to review and recap only the main approach/pipeline/architecture of these papers to give an overview of transformers for vision, so other parts of the papers, e.g. experimental performance and comparison results, are not presented. For better intuition, please read the original articles and code attached to each recap section. Of course, there might be some mistakes in these reviews, so if something is wrong or inaccurate, please feel free to tell me.
The paper summarization list will be updated frequently.
- Standard Transformer
- Transformer Optimization
- Classification Transformer
- Object Detection/Segmentation Transformer
- DETR (Detection Transformer)
- AnchorDETR
- MaskFormer
- SegFormer
- Segmenter
- Fully Transformer Networks
- TransUNet
- UTNet (U-shape Transformer Networks)
- SOTR (Segmenting Objects with Transformer)
- HandsFormer
- Unifying Global-Local Representations in Salient Object Detection with Transformer
- Real-time Semantic Segmentation with Fast Attention (FANet)
- Few-shot Transformer
- Meta-DETR: Image-Level Few-Shot Object Detection with Inter-Class Correlation Exploitation
- Boosting Few-shot Semantic Segmentation with Transformers
- Few-Shot Segmentation via Cycle-Consistent Transformer
- Few-shot Semantic Segmentation with Classifier Weight Transformer
- Few-shot Transformation of Common Actions into Time and Space
- A Universal Representation Transformer Layer for Few-Shot Image Classification
- Resources
This section introduces the original transformer architecture in NLP as well as its versions in Computer Vision, including ViT for image classification, VTN for video classification, and ViTGAN for generative adversarial networks. Finally, an in-depth comparison between ViT and ResNet is presented, to see whether the attention-based model is similar to the Conv-based model deep down.
- Paper: https://arxiv.org/pdf/1706.03762.pdf
- Input:
- Sequence embedding (e.g. word embeddings of a sentence)
- Positional Encoding (PE) => encodes the position of each word embedding within the sentence; added to the input of the Encoder/Decoder block
- Read here for detailed explanation of PE
- Encoder:
- Embedding words => Skip[MHSA => Norm] => Skip[FFN => Norm] => Patch encoding
- MHSA: Multi-Head Self Attention
- FFN: FeedForward Neural Network
- Repeat N times (N usually 6)
- Decoder:
- Decoder input:
- Previous outputs of the decoder (a start token at the beginning, then the sentence generated so far), shifted right
- Patch encoding (put in the middle of the Decoder)
- (Input + Positional Encoding) => Skip[Masked MHSA => Norm] => Skip[(+ Patch encoding as Key/Value) => MHSA => Norm] => Skip[FFN => Norm] => Linear => Softmax => Decoder Output
- The decoder block is repeated N times (N usually 6); at inference, the decoder output (the predicted token) is appended to the decoder input for the next step
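- Computational complexity of Self-Attention, Conv and RNN layers (n: sequence length, d: representation dimension, k: conv kernel size):
- Self-Attention: O(n^2.d) per layer, O(1) sequential operations, O(1) maximum path length
- Recurrent: O(n.d^2) per layer, O(n) sequential operations, O(n) maximum path length
- Convolutional: O(k.n.d^2) per layer, O(1) sequential operations, O(log_k(n)) maximum path length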
- Codes: https://github.com/SamLynnEvans/Transformer
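A minimal NumPy sketch of two pieces of the pipeline above: the sinusoidal positional encoding defined in the paper and single-head scaled dot-product attention, $\mathrm{softmax}(QK^T/\sqrt{d_k})V$. The optional boolean mask is how the decoder's first (masked) attention sub-layer is usually restricted to earlier positions. Learned Q/K/V projections, multiple heads, and batching are left out, and the function names are our own.

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    pos = np.arange(seq_len)[:, None]                  # [seq_len, 1]
    two_i = np.arange(0, d_model, 2)[None, :]          # even dimension indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)  # [seq_len, d_model / 2]
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)            # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: [n_q, d_k], K: [n_k, d_k], V: [n_k, d_v] -> [n_q, d_v].
    mask (optional): boolean [n_q, n_k], True = allowed to attend."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])            # [n_q, n_k]
    if mask is not None:
        scores = np.where(mask, scores, -1e9)          # masked positions get ~zero weight
    return softmax(scores) @ V

# Decoder-style (masked) self-attention: token i may only look at tokens 0..i
n, d = 5, 8
X = np.random.default_rng(0).normal(size=(n, d)) + sinusoidal_positional_encoding(n, d)
causal_mask = np.tril(np.ones((n, n), dtype=bool))
Y = scaled_dot_product_attention(X, X, X, mask=causal_mask)
```

In a real model, Q, K and V would come from separate learned linear projections of X; X is used directly here to keep the sketch short.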
- Paper: https://arxiv.org/pdf/2010.11929.pdf
- Input:
- Image [H, W, C] => non-overlapping patches (conventionally 16x16 patch size) => flatten into sequence => linear projection (vectorized + Linear) => patch embeddings
- Positional encoding added to the patch embeddings for location information of the patch sequence
- Extra learnable `[CLS]` token (embedding) + position 0 => prepended to the embedding sequence (denoted as $Z_0$)
- Architecture: (Image Patches + Position Embedding) => Transformer Encoder => MLP Head for Classification
- Transformer Encoder: Skip[Norm => MHSA] => Skip[Norm + MLP(Linear, GELU, Linear)] => output
- MLP Head for classification: $C_0$ (the output of $Z_0$ after the Transformer Encoder) => MLP Head (Linear + Softmax) => class label
- Good video explanation: https://www.youtube.com/watch?v=HZ4j_U3FC94
- Code: https://github.com/lucidrains/vit-pytorch
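The input step above (patchify => flatten => linear projection => prepend `[CLS]` => add position embeddings) can be sketched in NumPy as follows. The random matrices are stand-ins for learned parameters and the function name is hypothetical; a 224x224 image with 16x16 patches gives 14 x 14 = 196 patch tokens plus one `[CLS]` token.

```python
import numpy as np

def image_to_patch_embeddings(img, patch_size, W_proj, cls_token, pos_emb):
    """img: [H, W, C] -> z0: [1 + N, D], with N = (H / P) * (W / P) patches."""
    H, W, C = img.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "H and W must be divisible by the patch size"
    # [H, W, C] -> [H/P, P, W/P, P, C] -> [H/P, W/P, P, P, C]: a grid of non-overlapping patches
    patches = img.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(-1, P * P * C)                        # flatten each patch, row-major order
    tokens = patches @ W_proj                                       # linear projection -> [N, D]
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)  # prepend [CLS] at position 0
    return tokens + pos_emb                                         # add position embeddings [1 + N, D]

rng = np.random.default_rng(0)
P, D = 16, 64
img = rng.normal(size=(224, 224, 3))
N = (224 // P) ** 2                                                 # 14 * 14 = 196 patches
z0 = image_to_patch_embeddings(img, P, rng.normal(size=(P * P * 3, D)) * 0.02,
                               rng.normal(size=D), rng.normal(size=(N + 1, D)))
```

The reshape/transpose trick avoids an explicit loop over patches; frameworks often do the same thing with a strided convolution whose kernel size and stride both equal the patch size.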
- Paper: https://arxiv.org/pdf/2102.00719.pdf
- Based on Longformer, a transformer-based model that can process long sequences of thousands of tokens
- Pipeline: Quite similar to the standard ViT
- The 2D spatial backbone $f(x)$ can be replaced with any given backbone for 2D images => feature extraction
- The temporal attention-based encoder can stack more layers or more heads, or can be replaced by a different Transformer model that can process long sequences
- Note that a special classification token `[CLS]` is added in front of the feature sequence => final representation of the video => classification head for video classification
- The classification head can be modified to facilitate different video-based tasks, e.g. temporal action localization
- Code: https://github.com/bomri/SlowFast/tree/master/projects/vtn
- Paper: https://arxiv.org/pdf/2107.04589.pdf
- Both the generator and the discriminator are designed based on the standard ViT, but with modifications
- Architecture:
- Generator: Input latent $z$ => Mapping Network (MLP) => latent vector $w$ => Affine transform $A$ => Transformer Encoder => Fourier Encoding (sinusoidal/sine activation) => $E_{fou}$ => 2-layer MLP => Generated Patches
- Transformer Encoder:
- Embedding position => Skip[SLN => MHSA] => Skip[SLN =>MLP] => output
- SLN is called Self-modulated LayerNorm (as the modulation depends on no external information)
- The SLN formula is described in the paper (Equation 14)
- Embedding $E$ as the initial input; $A$ as the intermediate input for the norm layers
- Discriminator: the pipeline architecture is similar to the standard ViT model, but with several changes:
- Adopt overlapping patches at the beginning (rather than non-overlapping ones)
- Replace the dot product between $Q$ and $K$ with the Euclidean (L2) distance in the Attention formula
- Apply spectral normalization (read the paper for more information)
- Code: To be updated
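A rough single-head NumPy sketch of the discriminator's L2 attention. Two details are assumptions on our part rather than something stated in the notes above: the logits are the negative squared L2 distances (so nearer keys get larger weights), and the query and key projections share one weight matrix `W_q`. Check the paper for the exact formulation; the function name is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def l2_attention(X, W_q, W_v):
    """X: [n, d_in]. Attention logits are negative squared L2 distances
    between queries and keys instead of dot products.
    Assumption: the key projection is tied to the query projection (W_k = W_q)."""
    Q = X @ W_q
    K = X @ W_q                                       # tied weights (assumed)
    d_h = Q.shape[-1]
    # ||q_i - k_j||^2 for every pair (i, j) -> [n, n]
    sq_dist = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(axis=-1)
    weights = softmax(-sq_dist / np.sqrt(d_h))        # closer keys get larger weights
    return weights @ (X @ W_v), weights

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
out, w = l2_attention(X, rng.normal(size=(8, 4)), rng.normal(size=(8, 4)))
```

Because each query is at distance zero from its own key when the projections are tied, every token attends most strongly to itself in this sketch.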
- Paper: https://arxiv.org/pdf/1807.06521v2.pdf
- Convolutional Block Attention Module (CBAM): a lightweight and general attention-based module that can be plugged into any feed-forward CNN
- Architecture:
- Intermediate feature map => infer a 1D channel attention map and a 2D spatial attention map => multiplied with the input feature map => adaptive feature refinement
- Channel attention module: exploiting the inter-channel relationship of features
- Feature map F => AvgPool || MaxPool => Shared MLP [one hidden layer, ReLU] => element-wise summation => sigmoid => Channel attention Mc
- Spatial attention module: utilizing the inter-spatial relationship of features
- Channel-refined feature F' => [AvgPool, MaxPool] along the channel axis (concatenated) => Conv7x7 => sigmoid => Spatial attention Ms
- CBAM can be combined with a ResBlock:
- Conv => Skip_connection(CBAM) => Next Conv
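The two CBAM modules above can be sketched in NumPy for a single [C, H, W] feature map as follows. The reduction ratio `r` of the shared MLP, the omission of bias terms, and the loop-based 7x7 convolution are simplifications for readability, not the paper's implementation; the function names are our own.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention(F, W1, W2):
    """F: [C, H, W]; shared MLP C -> C/r -> C (W1: [C/r, C], W2: [C, C/r]). Returns Mc: [C, 1, 1]."""
    avg = F.mean(axis=(1, 2))                    # [C] average-pooled descriptor
    mx = F.max(axis=(1, 2))                      # [C] max-pooled descriptor

    def shared_mlp(v):                           # same weights for both descriptors
        return W2 @ np.maximum(W1 @ v, 0.0)

    return sigmoid(shared_mlp(avg) + shared_mlp(mx))[:, None, None]

def spatial_attention(F, kernel):
    """F: [C, H, W]; kernel: [2, k, k] (k = 7 in the paper). Returns Ms: [1, H, W]."""
    pooled = np.stack([F.mean(axis=0), F.max(axis=0)])          # [2, H, W], pooled along channels
    k = kernel.shape[-1]
    pad = k // 2
    padded = np.pad(pooled, ((0, 0), (pad, pad), (pad, pad)))   # zero padding keeps H and W
    H, W = F.shape[1:]
    out = np.zeros((H, W))
    for i in range(H):
        for j in range(W):
            out[i, j] = (padded[:, i:i + k, j:j + k] * kernel).sum()  # conv with 1 output channel
    return sigmoid(out)[None]

def cbam(F, W1, W2, kernel):
    F1 = channel_attention(F, W1, W2) * F        # F'  = Mc(F)  * F   (broadcast over H, W)
    return spatial_attention(F1, kernel) * F1    # F'' = Ms(F') * F'  (broadcast over C)

rng = np.random.default_rng(0)
C, H, W, r = 8, 6, 5, 2
F = rng.normal(size=(C, H, W))
W1, W2 = rng.normal(size=(C // r, C)), rng.normal(size=(C, C // r))
kernel = rng.normal(size=(2, 7, 7))
out = cbam(F, W1, W2, kernel)
```

In the ResBlock setting above, `cbam(...)` would be applied to the output of the residual branch's convolutions, before it is added back to the shortcut.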
- Paper: https://arxiv.org/pdf/2108.08810.pdf
- Representation Structure:
- ViTs have highly similar representations throughout the model, while ResNet models show much lower similarity between lower and higher layers.
- Local and Global Information in layer Representations:
- ViTs have access to more global information than CNNs in their lower layers, leading to quantitatively different features from those computed by the local receptive fields in the lower layers of ResNet.
- Even in the lowest layers of ViT, self-attention layers have a mix of local heads (small distances) and global heads (large distances) => in contrast to CNNs, which are hard-coded to attend only locally in the lower layers.
- At higher layers, all self-attention heads are global.
- Lower layer effective receptive fields for ViT are larger than in ResNets, and while ResNet effective receptive fields grow gradually, ViT receptive fields become much more global midway through the network.
- Representation Propagation through Skip connections:
- Skip connections in ViT are even more influential than in ResNet => strong effects on performance and representation similarity
- Spatial Information and Localization:
- Higher layers of ViT maintain spatial location information more faithfully than ResNets.
- ViTs with CLS tokens show strong preservation of spatial information — promising for future uses in object detection.
- When trained with global average pooling (GAP) instead of a CLS token, ViTs show less clear localization.
- ResNet50 and ViT with GAP model tokens perform well at higher layers, while in the standard ViT trained with a CLS token the spatial tokens do poorly – likely because their representations remain spatially localized at higher layers, which makes global classification challenging.
- Effects of Scale on Transfer Learning:
- For larger models, the larger dataset is especially helpful in learning high-quality intermediate representations.
- While lower layer representations have high similarity even with 10% of the data, higher layers and larger models require significantly more data to learn similar representations.
- Larger ViT models develop significantly stronger intermediate representations through larger pre-training datasets than the ResNets.
This section introduces techniques for training vision-transformer-based models effectively with optimization methods (data, augmentation, regularization,...). As Scaled Dot-Product Attention comes with quadratic complexity O(N^2), several approaches (Efficient Attention, Linformer) are introduced to reduce the computational complexity to linear O(N).
- Paper: https://arxiv.org/pdf/2106.10270.pdf
- Experimental hyperparameters:
- Pre-training:
- Adam optimizer with β1 = 0.9 and β2 = 0.999
- Batch size 4096
- Cosine learning-rate decay with a 10k-step linear warmup
- Gradient clipping at global norm 1
- Fine-tuning:
- SGD with momentum 0.9
- Batch size of 512
- Cosine decay learning rate schedule with a linear warmup
- Gradient clipping at global norm 1
- Regularization & augmentation:
- With a judicious amount of regularization and image augmentation, one can (pre-)train a model to an accuracy similar to that obtained by increasing the dataset size by about an order of magnitude.
- The paper measures the improvement or deterioration in validation accuracy when using various amounts of augmentation (RandAugment, Mixup) and regularization (Dropout, StochasticDepth).
- Generally speaking, there are significantly more cases where adding augmentation helps than where adding regularization helps
- For a relatively small amount of data, almost everything helps. But with large-scale data, almost everything hurts; only when compute is also increased does augmentation help again
- Transfer:
- No matter how much training time is spent, it does not seem possible to train ViT models from scratch to reach accuracy anywhere near that of the transferred model. => transfer is the better option
- Furthermore, since pre-trained models are freely available to download, the pre-training cost for practitioners is effectively zero
- Adapting only the best pre-trained model works as well as adapting all pre-trained models (and then selecting the best)
- Thus, selecting a single pre-trained model based on its upstream score is a cost-effective practical strategy
- Data:
- More data yields more generic models => the recommended design choice is to use more data with a fixed compute budget
- Patch-size:
- Prefer increasing the patch size to shrinking the model size
- Using a larger patch-size (/32) significantly outperforms making the model thinner (/16)
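The pre-training recipe above mentions a cosine learning-rate schedule with a linear warmup and gradient clipping at global norm 1. A minimal pure-Python sketch of both follows; the function names are our own, and the peak learning rate and total step count used in the checks are placeholder values, not the paper's settings.

```python
import math

def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps, end_lr=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay to end_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale a list of gradient vectors so that their joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total + 1e-12))   # leave small gradients unchanged
    return [[g * scale for g in vec] for vec in grads]

# Example: 10k warmup steps (as in the notes) out of a hypothetical 100k-step run
schedule = [warmup_cosine_lr(s, 1e-3, 10_000, 100_000) for s in (0, 5_000, 10_000, 55_000, 100_000)]
```

Clipping by the global norm (rather than per tensor) keeps the direction of the full update unchanged and only shrinks its length.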
- Paper: https://arxiv.org/pdf/1812.01243v9.pdf
- Efficient Attention:
- Linear memory and computational complexity O(n)
- Possesses the same representational power as conventional dot-product attention
- In the paper's experiments, it even performs better than conventional attention
- Method:
- Initially, feature X => 3 linear layers => $Q$ [n, k], $K$ [n, k], $V$ [n, d], where $k$ and $d$ are the dimensionalities of the keys and of the input representation (embedding), respectively.
- The Dot-product Attention is calculated by: $D(Q,K,V) = \mathrm{softmax}\left(QK^T/\sqrt{k}\right)V$
- The product $QK^T$ (the pairwise similarity $S$) has shape [n, n] => $SV$ has shape [n, d] => $O(n^2 d)$
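- The Efficient Attention is calculated by: $E(Q,K,V) = \rho_q(Q)\left(\rho_k(K)^T V\right)$
- $\rho_q$ and $\rho_k$ are the normalizations: either softmax (over the k features of each query, and over the n positions for each key feature) or scaling by $1/\sqrt{n}$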
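- The product $K^T V$ (with $K$ normalized; the global context vectors $G$) has shape [k, d]; since k and d are constants independent of n, its size is O(1) with respect to n (computing it costs $O(nkd)$)
- Then, $QG$ (with $Q$ normalized) has shape [n, d] => $O(nkd)$, i.e. linear O(n)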
- Dot-product Attention and Efficient Attention are mathematically equivalent when scaling normalization is used; the softmax versions are approximately, but not exactly, equivalent
- Explanation from the author: https://cmsflash.github.io/ai/2019/12/02/efficient-attention.html
- Code:
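The equivalence argument can be checked numerically. This NumPy sketch (function names are ours) computes dot-product attention with scaling normalization $\rho(Y) = Y/n$ and efficient attention with $\rho_q(Y) = \rho_k(Y) = Y/\sqrt{n}$; the two agree up to floating-point error because $(QK^T)V = Q(K^TV)$. The softmax variant is included for completeness; it does not match softmax dot-product attention exactly.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention_scaled(Q, K, V):
    """D(Q, K, V) = rho(Q K^T) V with rho(Y) = Y / n. Builds an [n, n] matrix: O(n^2 d)."""
    n = Q.shape[0]
    return (Q @ K.T / n) @ V

def efficient_attention_scaled(Q, K, V):
    """E(Q, K, V) = rho_q(Q) (rho_k(K)^T V) with rho(Y) = Y / sqrt(n). Cost O(n k d)."""
    n = Q.shape[0]
    G = (K / np.sqrt(n)).T @ V             # global context vectors, [k, d]
    return (Q / np.sqrt(n)) @ G            # [n, d]

def efficient_attention_softmax(Q, K, V):
    """Softmax variant: Q normalized over its k features, K normalized over the n positions."""
    G = softmax(K, axis=0).T @ V           # [k, d]
    return softmax(Q, axis=1) @ G          # [n, d]

rng = np.random.default_rng(0)
n, k, d = 50, 8, 16
Q, K, V = rng.normal(size=(n, k)), rng.normal(size=(n, k)), rng.normal(size=(n, d))
```

For long sequences the saving comes from never forming the [n, n] matrix: `G` is only [k, d] regardless of n.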
- Paper: https://arxiv.org/pdf/2006.04768.pdf
- The conventional Scaled Dot-Product Attention is decomposed into multiple smaller attentions through linear projections, such that the combination of these operations forms a low-rank factorization of the original attention. This reduces the complexity to O(n) in time and space
- Method:
- Add two linear projection matrices $E_i$ and $F_i$ of shape [k, n] when computing $K$ & $V$
- From $K$, $V$ with shape [n, d] => $E_i K$, $F_i V$ with shape [k, d]
- Then, calculate the Scaled Dot-Product Attention as usual. The operation only requires O(n.k) time and space complexity.
- If the projected dimension k << n (with k kept fixed), then the complexity is O(n)
- Code:
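A single-head NumPy sketch of the projected attention, assuming `E` and `F` are [k, n] matrices (random here, learned in the model) and a hypothetical k = 64. The attention map is [n, k] instead of [n, n], so memory and time grow linearly in n for a fixed k.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(Q, K, V, E, F):
    """Q, K, V: [n, d]; E, F: [k, n] projections along the sequence axis -> [n, d]."""
    K_proj = E @ K                                   # [k, d]
    V_proj = F @ V                                   # [k, d]
    scores = Q @ K_proj.T / np.sqrt(Q.shape[-1])     # [n, k] instead of [n, n]
    return softmax(scores) @ V_proj                  # [n, d]

rng = np.random.default_rng(0)
n, d, k = 1024, 32, 64
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
E = rng.normal(size=(k, n)) / np.sqrt(n)
F = rng.normal(size=(k, n)) / np.sqrt(n)
out = linformer_attention(Q, K, V, E, F)
```

With k = n and identity projections the function reduces to ordinary scaled dot-product attention, which is a handy sanity check.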
- Paper: https://arxiv.org/pdf/2004.05150.pdf
- Longformer scales linearly, O(n), for long-sequence processing in NLP, but it can also be applied to video processing
- Method:
- Sliding Window:
- With an arbitrary window size $w$, each token in the sequence only attends to $w$ tokens (roughly $w/2$ on each side) => the computational complexity is $O(n \cdot w)$
- With $l$ layers of the transformer, the receptive field of the sliding window attention is $l \times w$
- Dilated Sliding Window:
- To further increase the receptive field without increasing computation, the sliding window can be “dilated”, similar to the dilated CNNs
- With a gap of $d$ between consecutive tokens in the window, the dilated attention has a dilation size of $d$
- Then, the receptive field of the dilated sliding window attention is $l \times d \times w$
- Global Attention (full self-attention):
- The windowed and dilated attention are not flexible enough to learn task-specific representation
- Global attention is added at a few pre-selected input locations to tackle this problem.
- Linear Projection for Global Attention:
- 2 separate sets {Qs, Ks, Vs} and {Qg, Kg, Vg} are used for sliding window and global attention, respectively
- This provides flexibility to model the different types of attention patterns
- Code: https://github.com/allenai/longformer
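The attention patterns above can be expressed as a boolean mask. This NumPy sketch builds a dense [n, n] mask only for illustration; an efficient implementation would avoid materializing it. It uses `w // 2` neighbors on each side, so each window row has w + 1 positions including the token itself; the function names are hypothetical.

```python
import numpy as np

def longformer_mask(n, w, dilation=1, global_idx=()):
    """Boolean [n, n] mask, True where query i may attend to key j:
    a sliding window of w // 2 neighbors on each side spaced `dilation` apart,
    plus symmetric global attention for the indices in global_idx."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    offset = j - i
    half = w // 2
    mask = (np.abs(offset) <= half * dilation) & (offset % dilation == 0)
    for g in global_idx:
        mask[g, :] = True          # a global token attends to every token
        mask[:, g] = True          # and every token attends to it
    return mask

def masked_attention(Q, K, V, mask):
    scores = np.where(mask, Q @ K.T / np.sqrt(Q.shape[-1]), -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)   # the diagonal is always allowed
    w = np.exp(scores)                                      # exp(-inf) = 0 for masked keys
    return (w / w.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 4))
out = masked_attention(X, X, X, longformer_mask(10, w=2, global_idx=(0,)))
```

Each non-global row has about w + 1 allowed keys, which is where the O(n.w) cost comes from when only those entries are computed.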
- I wonder whether applying the row/column multiplication method (read here for more details) might reduce the computational complexity of matrix multiplication.
- With A and B being [N, N] matrices, the normal matrix multiplication has O(N^3) complexity
- However, I believe that with row/column multiplication, the computational complexity might reduce to O(N^2):
- Just a thought, maybe I'm wrong. Need to verify
- Then again, the multiplication inside the Scaled Dot-Product Attention is between two embeddings of shape [c, N], which has complexity O(N^2.c), so I do not think we can reduce the computational complexity with this simple row/column multiplication.
- Another option is applying the FFT (Fast Fourier Transform) to reduce the computation time
- In fact, the FFT reduces the cost of a convolution (e.g. multiplying two length-N polynomials) from O(N^2) down to O(N.logN); it does not directly speed up a general dense matrix product
- But does it generalize well to the input embeddings of the Scaled Dot-Product Attention? Of course, the input embeddings have to be normalized beforehand, but what if we want to work with different input shapes, e.g. high-resolution images?
- Furthermore, there are more techniques to reduce the complexity of Scaled Dot-Product Attention, presented in detail in their own sections: Efficient Attention, Linformer, and Longformer.
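A quick way to check the row/column idea above is to count multiplications. The pure-Python sketch below (function name is ours) computes AB as a sum of N column-row outer products, column k of A times row k of B. It still performs N^3 scalar multiplications, the same as the usual row-by-column dot products, so reordering the computation this way changes the memory access pattern but not the asymptotic cost.

```python
def matmul_outer_product(A, B):
    """C = A @ B for square N x N lists, computed as a sum of N column-row outer products.
    Returns (C, number_of_scalar_multiplications)."""
    n = len(A)
    C = [[0.0] * n for _ in range(n)]
    mults = 0
    for k in range(n):               # one outer product per column of A / row of B
        for i in range(n):
            for j in range(n):
                C[i][j] += A[i][k] * B[k][j]
                mults += 1
    return C, mults

C, m = matmul_outer_product([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]])
# C == [[19.0, 22.0], [43.0, 50.0]] and m == 8 == 2 ** 3
```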
This section introduces transformer-based models for image classification and its sub-tasks (image pairing or multi-label classification). The list of reviewed papers will be updated frequently.
- Paper: https://arxiv.org/pdf/2103.12236.pdf
- Reranking Transformers (RRTs): a small, lightweight, and effective model that learns to predict the similarity of an image pair directly from global & local descriptors
- Pipeline:
- 2 images X, X' => Global/Local feature description => preparation => RRTs => z[cls] => Binary Classifier => Do X and X' represent the same object/scene?
- Preparation:
- Attach 2 special tokens at the head of X and X':
- [CLS]: summarizes the signal from both images
- [SEP]: an extra separator token (distinguishes X and X')
- Positional encoding
- Global/local representation: ResNet50 backbone; extra linear projector to reduce global descriptor dimension; L2 norm to unit norm
- RRTs:
- Same as the standard transformer layer: Skip[Input => MHSA] => Norm => MLP => Norm => ReLU; with 4 layers
- 6 layers with 4 MHSA heads
- Binary Classifier:
- Feature vector Z[Cls] from the last transformer layer as input
- 1 Linear layer with sigmoid
- Training with binary cross entropy
- Code: https://github.com/uvavision/RerankingTransformer
- Paper: https://arxiv.org/pdf/2011.14027.pdf
- C-Tran (Classification Transformer): Transformer-based model for multi-label image classification that exploits dependencies among a target set of labels
- Training: Image & randomly masked labels => C-Tran => predict the masked labels
- Learn to reconstruct a partial set of labels given randomly masked input label embeddings
- Inference: Image & all labels masked => C-Tran => predict all labels
- Predict a set of target labels by masking all the input labels as unknown
- Architecture:
- Image => ResNet-101 => feature embeddings $Z$
- Labels => label embeddings $L$ (representing the $l$ possible labels) => added to state embeddings $s$ (with 3 states: unknown $U$, negative $N$, positive $P$)
- $Z$ & $L + s$ => C-Tran => $L'$ (predicted label embeddings) => FFN => $\hat{Y}$ (predicted label probabilities)
- C-Tran:
- Skip[MHSA => Norm] => Linear => ReLU => Linear
- 3 transformer layers with 4 attention heads
- FFN: label inference classifier with single linear layer & sigmoid activation
- Label Mask Training:
- During training:
- Randomly mask a certain number of labels
- Given $L$ possible labels => the number of "unknown" (masked) labels $n$ satisfies $0.25L \le n \le L$
- Use the ground truth of the other labels (via state embeddings) => predict the masked labels (with cross-entropy loss)
- Code: https://github.com/QData/C-Tran/
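A pure-Python sketch of the label-masking step during training. The state ids and the uniform sampling of n are assumptions for illustration (the notes only fix the range 0.25L <= n <= L), and `sample_label_states` is a hypothetical helper. At inference, every state would simply be UNKNOWN.

```python
import random

UNKNOWN, NEGATIVE, POSITIVE = 0, 1, 2   # hypothetical ids for the 3 state embeddings

def sample_label_states(targets, rng=random):
    """targets: list of 0/1 ground-truth labels (length L).
    Returns (states, masked_idx): n labels, with n drawn uniformly from
    [ceil(0.25 L), L], are set to UNKNOWN; the others reveal their
    ground truth as NEGATIVE / POSITIVE."""
    L = len(targets)
    n_min = -(-L // 4)                            # ceil(0.25 * L) with integer arithmetic
    n = rng.randint(n_min, L)                     # inclusive on both ends
    masked_idx = sorted(rng.sample(range(L), n))
    masked = set(masked_idx)
    states = [UNKNOWN if i in masked else (POSITIVE if t else NEGATIVE)
              for i, t in enumerate(targets)]
    return states, masked_idx

random.seed(0)
states, masked_idx = sample_label_states([1, 0, 0, 1, 1, 0, 0, 0])
```

The loss would then be computed only on the positions in `masked_idx`, since the other labels are given to the model as input.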
This section introduces several attention-based architectures for object detection (DETR, AnchorDETR,...) and segmentation (MaskFormer, TransUNet,...), and even for specific tasks such as hand pose estimation (HandsFormer). These architectures are mostly combinations of a Transformer and a CNN backbone, but a few are based solely on the Transformer architecture (FTN).
- Paper: https://arxiv.org/pdf/2005.12872.pdf
- Transformer Encoder & Decoder:
- Share the same architecture with the original transformer
- Encoder:
- Input sequence: flattened 2D feature (Image => CNN => flatten) + fixed (sine) positional encoding (added at each attention layer)
- Output: encoder output in sequence
- Decoder:
- Input: Object queries (learned positional embeddings) + encoder output (input in the middle)
- Output: output embeddings
- The Decoder decodes N objects in parallel at each decoder layer, rather than sequentially one element at a time
- The model can reason about all objects together using pair-wise relations between them, while being able to use the whole image as context
- Prediction FFN:
- 3-layer MLP with ReLU
- Output embeddings as input
- Predict the normalized center coordinates, height and width of the bounding box
- Architecture: Image => Backbone (CNN) => 2D representation => Flatten (+ Positional encoding) => Transformer Encoder-Decoder => Prediction FFN => bounding box
- Code: https://github.com/facebookresearch/detr
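A NumPy sketch of the box prediction FFN and a typical post-processing step. The sigmoid on the last layer keeps the normalized (cx, cy, w, h) in (0, 1); the sizes (256-d embeddings, 100 queries) and the random weights are example values, and both function names are our own.

```python
import numpy as np

def box_ffn(h, params):
    """h: [N, D] decoder output embeddings -> [N, 4] normalized (cx, cy, w, h)."""
    (W1, b1), (W2, b2), (W3, b3) = params
    x = np.maximum(h @ W1 + b1, 0.0)                 # Linear + ReLU
    x = np.maximum(x @ W2 + b2, 0.0)                 # Linear + ReLU
    return 1.0 / (1.0 + np.exp(-(x @ W3 + b3)))      # Linear + sigmoid -> values in (0, 1)

def cxcywh_to_xyxy(boxes, img_w, img_h):
    """[N, 4] normalized (cx, cy, w, h) -> [N, 4] absolute (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = boxes.T
    return np.stack([(cx - w / 2) * img_w, (cy - h / 2) * img_h,
                     (cx + w / 2) * img_w, (cy + h / 2) * img_h], axis=-1)

rng = np.random.default_rng(0)
D, N = 256, 100                                      # example embedding size / number of queries
params = [(rng.normal(size=(D, D)) * 0.05, np.zeros(D)),
          (rng.normal(size=(D, D)) * 0.05, np.zeros(D)),
          (rng.normal(size=(D, 4)) * 0.05, np.zeros(4))]
boxes = box_ffn(rng.normal(size=(N, D)), params)
pixel_boxes = cxcywh_to_xyxy(boxes, 640, 480)        # e.g. for a 640x480 image
```

Predicting normalized coordinates makes the head independent of the input image size; the conversion to pixels happens only at the end.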
- Paper: https://arxiv.org/pdf/2109.07107.pdf
- Backbone: ResNet-50
- The encoder/decoder share the same structure as DETR
- However, the self-attention in the encoder and decoder blocks is replaced by Row-Column Decoupled Attention
- Row-Column Decoupled Attention:
- Helps reduce GPU memory when dealing with high-resolution features
- Main idea:
- Decouple the key feature $K_f$ into a row feature $K_{f,x}$ and a column feature $K_{f,y}$ by 1D global average pooling
- Then perform the row attention and the column attention separately
- Code: https://github.com/megvii-research/AnchorDETR
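A single-head NumPy sketch of Row-Column Decoupled Attention as summarized above: the key map is averaged into a row feature and a column feature, one softmax attention is computed over the W columns and one over the H rows, and the value map is aggregated along the width and then along the height. Multi-head splitting and positional terms are omitted, the function name is ours, and the exact formulation in Anchor DETR may differ.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def row_column_decoupled_attention(Q, K, V):
    """Q: [Nq, C] queries; K, V: [H, W, C] key/value feature maps -> [Nq, C]."""
    C = Q.shape[-1]
    K_x = K.mean(axis=0)                       # 1D global average pooling over rows    -> [W, C]
    K_y = K.mean(axis=1)                       # 1D global average pooling over columns -> [H, C]
    A_x = softmax(Q @ K_x.T / np.sqrt(C))      # attention over the W columns, [Nq, W]
    A_y = softmax(Q @ K_y.T / np.sqrt(C))      # attention over the H rows,    [Nq, H]
    Z = np.einsum('qw,hwc->qhc', A_x, V)       # aggregate along the width  -> [Nq, H, C]
    return np.einsum('qh,qhc->qc', A_y, Z)     # then along the height      -> [Nq, C]

rng = np.random.default_rng(0)
Nq, H, W, C = 5, 6, 7, 8
Q = rng.normal(size=(Nq, C))
K, V = rng.normal(size=(H, W, C)), rng.normal(size=(H, W, C))
out = row_column_decoupled_attention(Q, K, V)
```

No [Nq, H*W] attention map is formed; the largest intermediate here is the [Nq, H, C] tensor `Z`.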
- Paper: https://arxiv.org/pdf/2107.06278.pdf
- Pixel-level module:
- Image => Backbone (ResNet) => image feature $F$ => Pixel decoder (upsampling) => per-pixel embedding $E_{pixel}$
- Transformer module (Decoder only):
- Standard Transformer decoder
- N queries (learnable positional embeddings) + $F$ (input in the middle) => Transformer Decoder => N per-segment embeddings $Q$
- Prediction in parallel (similar to DETR)
- Segmentation module:
- $Q$ => MLP (2 Linears + softmax) => N mask embeddings $E_{mask}$ & N class predictions
- Dot_product($E_{mask}$, $E_{pixel}$) => sigmoid => Binary mask predictions
- Matrix_mul(Mask predictions, class predictions) => Semantic segmentation
- Code: https://github.com/facebookresearch/MaskFormer
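The segmentation module can be sketched in NumPy as below. The class predictions are assumed to include a final "no object" column, which is dropped before combining; each pixel is then assigned the class c maximizing $\sum_i p_i(c)\, m_i[h, w]$. Shapes and the function name are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def maskformer_semantic_inference(class_probs, E_mask, E_pixel):
    """class_probs: [N, K+1] softmax over K classes + "no object" (assumed last column);
    E_mask: [N, D] mask embeddings; E_pixel: [D, H, W] per-pixel embeddings.
    Returns an [H, W] map of class ids."""
    masks = sigmoid(np.einsum('nd,dhw->nhw', E_mask, E_pixel))    # [N, H, W] binary mask probabilities
    # drop "no object", then combine: seg[c, h, w] = sum_i p_i(c) * m_i[h, w]
    seg = np.einsum('nk,nhw->khw', class_probs[:, :-1], masks)    # [K, H, W]
    return seg.argmax(axis=0)

# Toy example: query 0 predicts class 0 on the left half, query 1 predicts class 1 on the right half
H, W = 2, 4
E_pixel = np.zeros((1, H, W))
E_pixel[0, :, :2], E_pixel[0, :, 2:] = 5.0, -5.0
E_mask = np.array([[1.0], [-1.0]])
class_probs = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05]])
seg = maskformer_semantic_inference(class_probs, E_mask, E_pixel)
```

Because every query produces a full mask plus one class distribution, the same outputs can also be read as instance-level segments instead of being merged per pixel.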
- Paper: https://arxiv.org/pdf/2105.15203.pdf
- Input: Image [H, W, 3] => patches of size 4x4 (rather than 16x16 like ViT)
- Using smaller patches => favors dense prediction tasks
- Do not need positional encoding (PE):
- Not necessary for semantic segmentation
- The resolution of PE is fixed => it needs to be interpolated when the test resolution differs => drop in accuracy
- Hierarchical Transformer Encoder: extract coarse and fine-grained features, partly inspired by ViT but optimized for semantic segmentation
- Overlap patch embeddings => [Transformer Block => Downsampling] x 4 times => CNN-like multi-level feature maps $F_i$
- Feature map size: [H, W, 3] => $F_1$ [H/4, W/4, C1] => $F_2$ [H/8, W/8, C2] => ... => $F_4$ [H/32, W/32, C4]
- Provide both high and low-resolution features => boost the performance of semantic segmentation
- Transformer Block1: Efficient Self-Attn => Mix-FFN => Overlap Patch Merging
- Efficient Self-Attention: Use the reduction ratio R on the sequence length N = H x W (in particular, apply stride R) => reduces the complexity by R times
- Mix-FFN: Skip[MLP => Conv3x3 => GELU => MLP], which relies on zero padding to leak location information (rather than positional encoding)
- Overlapped Patch Merging: similar to the image patch in ViT but overlap => combine feature patches
- Lightweight All-MLP Decoder: fuse the multi-level features => predict semantic segmentation mask
- 1st Linear layer: unifies the channel dimension of the multi-level features $F_i$ (from the encoder); the $F_i$ are then upsampled to 1/4 resolution and concatenated
- 2nd Linear layer: fuses the concatenated features $F$
- 3rd Linear layer: predicts the segmentation mask $M$ [H/4, W/4, N_cls] from $F$
- Code: https://github.com/lucidrains/segformer-pytorch
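A NumPy sketch of SegFormer-style efficient self-attention with reduction ratio R: the N tokens are reshaped to [N/R, C*R] and linearly mapped back to C channels, so keys and values have only N/R tokens and the attention map is [N, N/R]. Applying the reduction once to X before both the K and V projections is a simplification on our part, and the function name is ours.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def efficient_self_attention(X, W_q, W_k, W_v, W_reduce, R):
    """X: [N, C] tokens; W_q, W_k, W_v: [C, C]; W_reduce: [C*R, C]. Returns [N, C]."""
    N, C = X.shape
    assert N % R == 0, "N must be divisible by the reduction ratio R"
    X_red = X.reshape(N // R, C * R) @ W_reduce      # Reshape(N/R, C*R), then Linear(C*R -> C)
    Q = X @ W_q                                      # queries keep all N tokens  [N, C]
    K = X_red @ W_k                                  # keys use N/R tokens        [N/R, C]
    V = X_red @ W_v                                  # values use N/R tokens      [N/R, C]
    attn = softmax(Q @ K.T / np.sqrt(C))             # [N, N/R] instead of [N, N]
    return attn @ V

rng = np.random.default_rng(0)
N, C, R = 64, 16, 4
X = rng.normal(size=(N, C))
W_q, W_k, W_v = (rng.normal(size=(C, C)) for _ in range(3))
W_reduce = rng.normal(size=(C * R, C)) / np.sqrt(C * R)
out = efficient_self_attention(X, W_q, W_k, W_v, W_reduce, R)
```

With R = 1 and an identity `W_reduce`, the function reduces to standard single-head self-attention.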
- Paper: https://arxiv.org/pdf/2105.05633.pdf
- Architecture:
- Image patches => Flatten & project => + Positional Encoding => Transformer Encoder => [Patch encoding + class embeddings] => Decoder (Mask Transformer) => scalar product => upsample & argmax => Segmentation Map
- Encoder:
- Input $z_0$ + PE => Skip[Norm => MHSA] => Skip[Norm => 2-layer MLP] => Patch encoding $Z_L$
- Decoder:
- Input:
- Patch encoding $Z_L$
- K learnable class embeddings $cls = [cls_1, \dots, cls_K]$, where K is the number of classes. Each class embedding is initialized randomly & assigned to a single semantic class => used to generate the class mask
- Mask Transformer:
- $Z_L$ & $cls$ => Transformer encoder => $Z'_M$ & $c$
- $\mathrm{Masks}(Z'_M, c) = \mathrm{L2norm}(Z'_M)\,\mathrm{L2norm}(c)^T$ => reshaped into a 2D mask $s_{mask}$ => bilinearly upsampled to the original size [H, W, K] => softmax => final segmentation mask
- Code: https://github.com/rstrudel/segmenter
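A NumPy sketch of the mask computation. In this sketch both the decoded patch embeddings and the class embeddings are L2-normalized, their scalar products give one score per patch and class, and a softmax over classes gives per-patch probabilities. The bilinear upsampling to [H, W, K], which comes before the softmax in the pipeline above, is omitted, and the function name is hypothetical.

```python
import numpy as np

def segmenter_masks(Z_m, cls_emb, grid_h, grid_w):
    """Z_m: [N, D] decoded patch embeddings (N = grid_h * grid_w); cls_emb: [K, D].
    Returns per-patch class probabilities of shape [grid_h, grid_w, K]."""
    Z = Z_m / np.linalg.norm(Z_m, axis=-1, keepdims=True)            # L2-normalize each patch
    c = cls_emb / np.linalg.norm(cls_emb, axis=-1, keepdims=True)    # L2-normalize each class
    masks = (Z @ c.T).reshape(grid_h, grid_w, -1)                    # scalar products [grid_h, grid_w, K]
    # (bilinear upsampling to the image size would happen here; omitted in this sketch)
    e = np.exp(masks - masks.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)                         # softmax over the K classes

cls_emb = np.eye(3)                                                  # 3 toy class embeddings
Z_m = np.array([[2.0, 0.1, 0.0], [0.0, 0.0, 5.0], [0.1, 3.0, 0.0], [1.0, 0.0, 0.0]])
probs = segmenter_masks(Z_m, cls_emb, 2, 2)                          # a 2 x 2 patch grid
```

Normalizing both sides turns the scores into cosine similarities, so a patch is assigned to the class embedding it points towards, regardless of vector length.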
- Paper: https://arxiv.org/pdf/2106.04108.pdf
- Fully Transformer Networks for semantic image segmentation, without relying on CNN.
- Both the encoder and decoder are composed of multiple transformer modules
- Pyramid Group Transformers (PGT) encoder to divide feature maps into multiple spatial groups => compute the representation for each
- Capable of handling spatial details or local structures like a CNN
- Reduce unaffordable computational & memory cost of the standard ViT; reduce feature resolution and increase the receptive field for extracting hierarchical features
- Feature Pyramid Transformer (FPT) decoder => fuse semantic-level & spatial level information from PGT encoders => high-resolution, high-level semantic output
- Architecture:
- Image => Patch => PGT Encoder => FPT Decoder => linear layer => bilinear upsampling => probability map => argmax(prob_map) => Segmentation
- PGT: four hierarchical stages that generate features at multiple scales, each including a Patch Transform (non-overlapping) + PGT Blocks to extract hierarchical representations
- PGT Block: Skip[Norm => PG-MSA] => Skip[Norm => MLP]
- PG-MSA (Pyramid Group Multi-head Self-Attention): $\mathrm{Head}_{ij} = \mathrm{Attention}(Q_{ij}, K_{ij}, V_{ij})$ => $h_i = \mathrm{reshape}(\mathrm{Head}_{ij})$ => $\text{PG-MSA} = \mathrm{Concat}(h_i)$
- FPT: aggregate the information from multiple levels of PGT encoder => generate finer semantic image segmentation
- A larger FPT is not necessarily better for segmentation (given limited segmentation training data) => its scale is determined by the depth, the embedding dimension, and the reduction ratio of SR-MSA
- SR-MSA (Spatial-reduction transformer block): reduces memory and computation cost by spatially reducing the number of Key & Value tokens, especially for high-resolution representations
- The multi-level high-resolution feature of each branch => fusing (element-wise summation/channel-wise concatenation) => finer prediction
- Code: To be updated
- Paper: https://arxiv.org/pdf/2102.04306.pdf
- Downsampling (Encoder): using CNN-Transformer Hybrid
- (Medical) Image [H, W, C] => CNN => 2D feature map => Linear Projection (flatten into 2D patch embeddings) => Downsampling => Transformer => Hidden feature [n_patch, D]
- CNN: downsampling by 1/2 => 1/4 => 1/8
- Transformer: Norm layer before MHSA/FFN (rather than applying Norm layer after MHSA/FFN like the original Transformer), total 12 layers
- Why using CNN-Transformer hybrid:
- Leverages the intermediate high-resolution CNN feature maps in the Decoder
- Performs better than the pure transformer
- Upsampling (Decoder): using Cascaded Upsampler
- Similar to the upsampling part of the standard UNet
- Upsampling => concat with the corresponding CNN feature map (from the Encoder) => Conv3x3 with ReLU
- Segmentation head (Conv1x1) at the final layer
- Hidden feature [n_patch, D] => reshape [D, H/16, W/16] => [512, H/16, W/16] => [256, H/8, W/8] => [128, H/4, W/4] => [64, H/2, W/2] => [16, H, W] => Segmentation head => Semantic segmentation
- Code: https://github.com/KenzaB27/TransUnet
- Paper: https://arxiv.org/pdf/2107.00781.pdf
- Pipeline:
- Apply conv layers to extract local intensity feature, while using self-attention to capture long-range associative information
- UTNet follows the standard design of UNet, but replaces the last conv of the building block at each resolution (except the highest one) with the proposed Transformer module
- Rather than using the conventional MHSA of the standard Transformer, UTNet develops an Efficient MHSA (quite similar to the one in SegFormer):
- Using 2-dimensional relative position encoding by adding relative height and width information rather than the standard position encoding
- Code: https://github.com/yhygao/UTNet
- Paper: https://arxiv.org/pdf/2108.06747.pdf
- Combines the advantages of CNN and Transformer
- Architecture:
- Pipeline:
- Image => CNN Backbone => multi-scale feature maps => patch recombination + positional embedding => clip-level feature sequences/blocks => Transformer => global-level semantic features => functional heads => class & conv kernel prediction
- Backbone output => Multi-level upsampling module (with dynamic conv) => dynamic conv(output, kernel head) => instance masks
- CNN Backbone: Feature pyramid network
- Transformer: proposes 2 different transformer designs with Twin attention:
- Twin attention: simplifies the attention matrix with a sparse representation, since standard self-attention has quadratic time and memory complexity => lower computational cost
- Transformer layer:
- Pure twin layer: Skip[Norm => Twin Att.] => Skip[Norm => MLP]
- Hybrid twin layer: Skip[Norm => Twin Att.] => Skip[Conv3x3 => Leaky ReLU => Conv3x3]
- Hybrid Twin comes with the best performance
- Multi-level upsampling module: P5 feature map + positional information from the transformer + P2-P4 from FPN => [Conv3x3 => Group Norm => ReLU, multiple stages] => upsample x2, x4, x8 (for P3-P5) => added together => point-wise conv => upsampling => final $H \times W$ feature map
- Code: https://github.com/easton-cau/SOTR
- Paper: https://arxiv.org/pdf/2104.14639.pdf
- Architecture:
- Image => UNet => Image features (from layers of UNet decoder) + Keypoint heatmap => bilinear interpolation & concat => FFN (3-layer MLP) => concat with keypoint heatmap (with positional encoding) => keypoint representation (likely to correspond to the 2D location of hand joints)
- Localizing the joints of hands in 2D is more accurate than directly regressing 3D location.
- The 2D keypoints are a very good starting point to predict an accurate 3D pose for both hands
- Keypoint representation => Transformer encoder => FFN (2-layer MLP + linear projection) => Keypoint Identity predictor [2 FC layers => linear projection => Softmax] => 2D pose
- [Joint queries, Transformer encoder output] => Transformer decoder => 2-layer MLP + linear projection => 3D pose
- Code: To be updated
- Paper: https://arxiv.org/pdf/2108.02759.pdf
- Jointly learn global and local features in a layer-wise manner => solving the salient object detection task with the help of Transformer
- With the self-attention mechanism => the transformer is capable of modeling "contrast", which has been demonstrated to be crucial for saliency perception
- Architecture:
- Encoder:
- Input => split to grid of fixed-size patches => linear projection => feature vector (represent local details) + positional encoding => Encoder => encode global features without diluting the local ones
- The Encoder includes 17 layers of standard transformer encoder.
- Decoder:
- Densely decode the features carrying global-local information from the input and from the encoder layers => preserve rich local & global features.
- Several decoder designs are considered, including:
- Naive Decoder: directly upsamples the output of the last layer to the input resolution to generate the saliency map. Specifically, 3 Conv-norm-ReLU are applied => bilinear upsample x16 => sigmoid
- Stage-by-Stage Decoder: upsamples the resolution x2 in each stage => mitigates the loss of spatial details. Specifically, 4 stages x [3 Conv-norm-ReLU => sigmoid]
- Multi-level Feature Decoder: sparsely fuses multi-level features, similar to a feature pyramid network. Specifically, take features F3, F6, F9, F12 (from the corresponding encoder layers) => upsample x4 => several conv layers => fuse => saliency map
- Density Decoder: integrates the features of all encoder layers => upsamples them to the input's spatial resolution (using pixel shuffle & bilinear upsampling x2) => concat => salient feature => Conv => sigmoid => saliency map
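Two pieces of this pipeline are sketched below in NumPy. The first is the encoder's input step (patches => linear projection => + positional encoding). The second is the pixel-shuffle upsampling used by the density decoder. The patch size, embedding width, random weights and function names are illustrative assumptions, and the 17-layer encoder itself is not reproduced. `pixel_shuffle` assumes a channel-first layout with the same index order as PyTorch's `nn.PixelShuffle`.

```python
import numpy as np

def patchify(img, p):
    """Split an (H, W, C) image into non-overlapping p x p patches -> (N, p*p*C)."""
    h, w, c = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    x = img.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape((h // p) * (w // p), p * p * c)

def embed(img, p, w_proj, pos):
    """Linear projection of the flattened patches plus a positional encoding -> (N, D)."""
    return patchify(img, p) @ w_proj + pos

def pixel_shuffle(x, r):
    """Depth-to-space: (C*r*r, H, W) -> (C, H*r, W*r); out[c, h*r+i, w*r+j] = x[c*r*r + i*r + j, h, w]."""
    c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    x = x.reshape(c, r, r, h, w).transpose(0, 3, 1, 4, 2)
    return x.reshape(c, h * r, w * r)
```

Pixel shuffle trades channels for resolution without interpolation. This is why the notes pair it with bilinear upsampling in the density decoder.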
- Code: https://github.com/OliverRensu/GLSTR
- Paper: https://arxiv.org/pdf/2007.03815v2.pdf
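- FA Module:
- FA module for non-local context aggregation for efficient segmentation.
- Fast Attention = $\frac{1}{n}\,\tilde{Q}\,\big[\tilde{K}^\top V\big]$, where $\tilde{Q}$, $\tilde{K}$ are Q, K L2-normalized along the channel dimension
- $n$ = height x width (number of pixels), $c$ = number of channels
- L2 normalization ensures the affinity $\tilde{Q}\tilde{K}^\top$ lies in [-1, 1] (cosine similarity)
- Computational complexity $O(nc^2)$, i.e. $n/c$ times more efficient than standard self-attention ($O(n^2c)$), given $n \gg c$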
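A small NumPy check of where the saving comes from. Computing $\tilde{K}^\top V$ first gives a c x c matrix, so the $n \times n$ affinity never has to be built. Associativity guarantees the same result as the naive order. `embed_qkv` models the three Conv1x1 layers (no ReLU) as per-pixel linear maps. This is a sketch under those assumptions; the function names and shapes are hypothetical, not FANet's code.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Normalize each row (one pixel's channel vector) to unit L2 norm."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def embed_qkv(feat, wq, wk, wv):
    """Three 1x1 convolutions without ReLU == per-pixel linear maps on a (C, H, W) map."""
    x = feat.reshape(feat.shape[0], -1).T  # (n = H*W, C)
    return x @ wq, x @ wk, x @ wv

def fast_attention(q, k, v):
    """1/n * norm(Q) @ (norm(K)^T @ V): the c x c product comes first, O(n c^2)."""
    n = q.shape[0]
    return l2_normalize(q) @ (l2_normalize(k).T @ v) / n

def naive_attention(q, k, v):
    """Same quantity in the slow order: builds the (n, n) affinity first, O(n^2 c)."""
    n = q.shape[0]
    affinity = l2_normalize(q) @ l2_normalize(k).T  # entries in [-1, 1]
    return affinity @ v / n
```

Both functions return the same values. Only the order of the matrix products changes, and that ordering is where the $n/c$ factor comes from. Dropping the softmax is what makes the reordering possible.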
- Fast Attention Network (FANet) for real-time semantic segmentation:
- Encoder:
- Extract features from image input at different semantic levels
- Lightweight backbone ResNet-18/34 without last FC layer
- The first Res-block "Res-1" produces a feature map of size [h/4, w/4]
- Each subsequent block downsamples the resolution by a further factor of 2
- Apply Context Aggregation (FA module) at each stage
- Down-sampling Res-4 and Res4 for spatial reduction
- Context Aggregation:
- Essentially the FA module: 3 Conv1x1 layers (without ReLU) embed the input feature into {Q, K, V}
- Decoder:
- Merges and upsamples the features (outputs of FA module) in sequence from deep feature maps to shallow ones
- Skip connection to connect middle features => enhanced decoded feature with high-level content
- Output with [h/4, w/4]
- Extending the FA module to Video Semantic Segmentation:
- Extend FA module to spatial-temporal contexts => improves video semantic segmentation without increasing computational cost
- At frame t, the FA module can be rewritten (after a few transformation steps) as:
$\text{FA}_t = \frac{1}{n}\,\tilde{Q}_t\big[F(K_t, V_t) + \sum_i F(K_{t-i}, V_{t-i})\big]$
- where $F(K, V) = \tilde{K}^\top V$, and $\tilde{Q}$, $\tilde{K}$ denote Q, K L2-normalized along the channel dimension
- The terms $F(K_{t-i}, V_{t-i})$ have already been computed at the previous frames and can simply be reused at this step
- Therefore, the spatial-temporal fast attention keeps the complexity of $O(nc^2)$, which is not only fast but also independent of $t$
- The standard FA module can then be replaced by this spatial-temporal version for video semantic segmentation without increasing computational cost
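A minimal sketch of the reuse trick. Each frame's $F(K, V)$ is a c x c matrix computed once and cached, so later frames only add cached matrices instead of revisiting old pixels. The sliding `window` (how many previous frames the sum covers) and the class name are hypothetical; the notes do not state the summation range.

```python
import numpy as np
from collections import deque

def l2_normalize(x, axis=-1, eps=1e-12):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

class TemporalFastAttention:
    """Caches F(K, V) = norm(K)^T V for the last `window` frames (hypothetical parameter)."""

    def __init__(self, window=2):
        self.cache = deque(maxlen=window)  # F(K_{t-i}, V_{t-i}) of previous frames

    def step(self, q, k, v):
        """q, k, v: (n, c) for the current frame t; returns FA_t of shape (n, c)."""
        n = q.shape[0]
        f_t = l2_normalize(k).T @ v  # O(n c^2), computed once per frame
        # Reusing cached c x c terms costs O(window * c^2), independent of n.
        context = f_t + sum(self.cache, np.zeros_like(f_t))
        self.cache.append(f_t)
        return l2_normalize(q) @ context / n
```

Keeping a single running sum instead of a deque would also work, at the cost of losing the sliding window.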
- Code: https://github.com/feinanshan/FANet
This section introduces transformer-based architectures for few-shot learning, mainly but not strictly for object detection and segmentation. Overall, these pipelines are quite complex, so I recommend reading the papers along with these reviews for a better understanding. Again, this list will be updated frequently.
- Paper: https://arxiv.org/pdf/2103.11731.pdf
- Employs Deformable DETR and the original Transformer as the basic detection framework
- Architecture:
- Query image, Support images => Feature Extractor (ResNet-101) => Positional Encoding => Query/Support features => Correlation Aggregation Module (CAM) => Support-Aggregated Query Feature => Transformer Encoder-Decoder => Prediction Head => Few-shot detection
- Correlation Aggregation Module (CAM):
- Key component of Meta-DETR => aggregates query features with support classes => class-agnostic prediction
- Can aggregate multiple support classes simultaneously => capture inter-class correlations => reduce misclassification, enhance generalization
- Pipeline:
- Query & Support features => MHSA => RoIAlign + Average pooling (on the support feature only) => Query feature map $Q$ & Support prototypes $S$
- $S = \text{Concat}(S, \text{BG-Prototype})$; Task Encodings $T = \text{Concat}(T, \text{BG-Encoding})$
- $\{Q, S, T\}$ => Feature matching & Encoding matching (in parallel) => FFN => Support-Aggregated Query Features
- Feature matching:
- $\{Q, \text{Sigmoid}(S), S\}$ => Single-Head Attention => Element-wise multiplication => feature matching output $Q_f$
- Encoding Matching:
- $\{T, Q, S\}$ => Single-Head Attention => encoding output $Q_e$
- FFN: Element-wise Add($Q_f$, $Q_e$) => FFN => Support-Aggregated Query features
- Transformer Encoder-Decoder:
- Follows the Deformable DETR architecture with 6 layers
- Adopts multi-scale deformable attention; the CAM counts as one encoder layer
- Support-Aggregated Query features => Encoder-Decoder => Embedding $E$
- Prediction Head:
- $E$ => FC, Sigmoid => Confidence score
- $E$ => FC, ReLU + FC, ReLU + FC, Sigmoid => Bounding Box
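The two prediction branches, sketched in NumPy with random weights. The parameter dict, `init_params`, the number of scores per query (`num_scores`) and the 4-value box output are assumptions; DETR-style heads usually read those four values as a normalized (cx, cy, w, h).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(x, 0.0)

def prediction_head(e, params):
    """e: (num_queries, d) decoder embeddings E.

    Confidence: FC -> sigmoid.  Box: (FC, ReLU) x 2 -> FC -> sigmoid.
    """
    score = sigmoid(e @ params["w_cls"] + params["b_cls"])
    h = relu(e @ params["w1"] + params["b1"])
    h = relu(h @ params["w2"] + params["b2"])
    box = sigmoid(h @ params["w3"] + params["b3"])  # 4 normalized values in (0, 1)
    return score, box

def init_params(d, num_scores, rng):
    """Hypothetical small random weights and zero biases for both branches."""
    shapes = {"w_cls": (d, num_scores), "w1": (d, d), "w2": (d, d), "w3": (d, 4)}
    params = {}
    for name, shape in shapes.items():
        params[name] = rng.standard_normal(shape) * 0.1
        params["b" + name[1:]] = np.zeros(shape[1])
    return params
```

The final sigmoid on the box branch keeps coordinates normalized to the image size, so no extra clipping is needed.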
- Code: https://github.com/ZhangGongjie/Meta-DETR
- Paper: https://arxiv.org/pdf/2108.02266.pdf
- Architecture:
- Support/Query Images => Backbone (VGG/ResNet) => PANet => PFENet => $X$
- PANet: compute the similarity between query features and support prototypes
- PFENet: Concat(Query features; Expanded Support Prototypes; Prior Mask)
- $X$ => Multi-scale Processing => Global/Local Enhancement Modules (with outputs $T$ and $Z$, respectively) => Concat(T, Z) => MLP => Few-shot Semantic Segmentation
- Multi-scale Processing:
- Information over different scales (of the input feature maps $X$ from support/query images) can be exploited
- Multi-scaling with Global Average Pooling => feature pyramid $\{X_i\} = \{X_1, X_2, \dots, X_n\}$
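The notes only say the pyramid comes from average pooling. This sketch assumes adaptive average pooling of $X$ to a few hypothetical output sizes, using the floor/ceil bin edges of PyTorch's `AdaptiveAvgPool2d`. The sizes and function names are illustrative only.

```python
import numpy as np

def adaptive_avg_pool2d(x, out_h, out_w):
    """Average-pool a (C, H, W) map to (C, out_h, out_w) with floor/ceil bin edges."""
    c, h, w = x.shape
    out = np.empty((c, out_h, out_w), dtype=float)
    for i in range(out_h):
        h0, h1 = (i * h) // out_h, -(-((i + 1) * h) // out_h)  # floor, ceil
        for j in range(out_w):
            w0, w1 = (j * w) // out_w, -(-((j + 1) * w) // out_w)
            out[:, i, j] = x[:, h0:h1, w0:w1].mean(axis=(1, 2))
    return out

def feature_pyramid(x, sizes=(8, 4, 2)):
    """Hypothetical pyramid {X_1, ..., X_n}: one pooled copy of X per target size."""
    return [adaptive_avg_pool2d(x, s, s) for s in sizes]
```

A global average pooling is the special case `adaptive_avg_pool2d(x, 1, 1)`.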
- Global Enhancement Module (GEM):
- Uses a Transformer => enhances the features by exploiting global information
- $X_i$ => FC layer (for channel reduction) => $X'_i$ => Feature Merging Unit (FMU) => $Y_i$
- FMU: $Y_i = X'_i$ if $i = 1$; $Y_i = \text{Conv}_{1\times1}(\text{Concat}(X'_i, T_{i-1})) + X'_i$ if $i > 1$
- $Y_i$ => [MHSA => MLP (2 Linear)] => MHSA => $T_i$ => $T$
- MHSA with GELU and Norm
- [MHSA => MLP] is repeated L times, with L = 3
- $T = \text{Concat}(T_1, T_2, \dots, T_n)$ at layer L
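The Feature Merging Unit rule for flattened tokens, as a minimal sketch. A 1x1 convolution is just a per-token linear map over channels. The notes do not say how $T_{i-1}$ is brought to the resolution of $X'_i$, so the sketch assumes it has already been resized. The function name and argument layout are hypothetical.

```python
import numpy as np

def fmu(x_prime, t_prev=None, w=None):
    """Feature Merging Unit on flattened tokens.

    x_prime: (N, d) channel-reduced features X'_i of scale i.
    t_prev:  (N, d) output T_{i-1} of the previous scale, ASSUMED already resized
             to X'_i's resolution (not described in the notes).
    w:       (2d, d) weights of the 1x1 convolution.
    """
    if t_prev is None:  # i = 1: nothing to merge yet
        return x_prime
    merged = np.concatenate([x_prime, t_prev], axis=-1) @ w  # Conv1x1(Concat(X'_i, T_{i-1}))
    return merged + x_prime  # residual connection to X'_i
```

The residual term keeps the current scale's own features intact, while the concatenated branch injects context from the previous scale.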
- Local Enhancement Module (LEM):
- Follows the same pipeline as GEM: $X_i$ => FC layer => FMU => $Y_i$ => Conv => output
- Instead of the Transformer used in GEM, LEM uses convolutions => encodes the local information
- LEM output: $\{Z_1, Z_2, \dots, Z_n\}$
- Segmentation Mask Prediction:
- Local/Global output: $Z = \text{Concat}(Z_1, Z_2, \dots, Z_n)$; Concat(T, Z) => MLP => target mask $M$
- Code: https://github.com/GuoleiSun/TRFS
- Paper: https://arxiv.org/pdf/2106.02320.pdf
- Architecture:
- Query & Support Images => Backbone (ResNet) => Image features $X_q$ and $X_s$ => concat the mask-averaged support feature to $X_q$ and $X_s$ => flatten into 1D sequences by Conv1x1 (with shape HW x D) => (CyC-Transformer, L times) => reshaping => Conv-head => segmentation mask
- Each token of $X_q$ and $X_s$ is the feature $z$ at one pixel location => beneficial for segmentation
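The "mask-averaged support feature" step can be read as masked average pooling. It averages support features over the foreground pixels, tiles the resulting prototype over the grid, concatenates it, and then a 1x1 conv and a flatten give one token per pixel. A minimal sketch under that reading; the function names and shapes are hypothetical.

```python
import numpy as np

def masked_average_pooling(feat, mask, eps=1e-6):
    """Support prototype: mean of (C, H, W) features over foreground pixels of a (H, W) mask."""
    return (feat * mask[None]).sum(axis=(1, 2)) / (mask.sum() + eps)

def append_prototype(feat, proto):
    """Tile the (C,) prototype over the spatial grid and concatenate -> (2C, H, W)."""
    c, h, w = feat.shape
    tiled = np.broadcast_to(proto[:, None, None], (c, h, w))
    return np.concatenate([feat, tiled], axis=0)

def flatten_tokens(feat, w_1x1):
    """1x1 conv (per-pixel linear map) + flatten: (C', H, W) -> (H*W, D), one token per pixel."""
    return feat.reshape(feat.shape[0], -1).T @ w_1x1
```

The same prototype is appended to both $X_q$ and $X_s$, so every token carries a summary of the support foreground next to its own pixel feature.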
- CyC-Transformer:
- Self-Alignment block:
- Just like the original Transformer encoder
- Takes the flattened query feature as input (query only)
- Pixel-wise features of the query image => aggregate their global context information
- Cross-alignment block:
- Replace MHSA with CyC-MHA
- Takes the flattened query feature and sampled support features as input
- Performs attention between query and support pixel-wise features => aggregate relevant support feature into query ones
- Cycle-Consistent Multi-head Attention (CyC-MHA):
- Alleviates harmful support features that can confuse pure pixel-level attention
- Pipeline:
- Affinity map A is calculated => measure the correspondence relationship between all query and support tokens
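- For a single support token position $j$ ($j \in \{0, \dots, N_s\}$), its most similar query point is $i^* = \arg\max_i A(i, j)$, where $i \in \{0, \dots, H_qW_q\}$ is the index of the flattened query feature
- Construct the cycle-consistency (CyC) relationship for all tokens in the support sequence: going back from $i^*$ to $j^* = \arg\max_j A(i^*, j)$, token $j$ is cycle-consistent if $j$ and $j^*$ have the same support-mask label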
- The cycle-consistency effectively helps avoid being biased by possibly harmful features (encountered when training for few-shot segmentation)
- Finally, $\text{CyC-MHA} = \text{softmax}(A_i + B)V$
- where $B$ is an additive bias taking only two values ($0$ for cycle-consistent support tokens, $-\infty$ otherwise), added element-wise to the affinity when aggregating support features, and $V$ is the value sequence
- Where $B = -\infty$, the attention weight becomes zero => irrelevant information is not considered
- CyC encourages consistency between the most related features of query and support => produces consistent feature representations
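A single-head NumPy sketch of the cycle-consistent bias and the masked attention. For each support token $j$, it finds $i^* = \arg\max_i A(i, j)$, then $j^* = \arg\max_j A(i^*, j)$, and sets $B_j = 0$ when $j$ and $j^*$ share a support-mask label, otherwise $-\infty$. The rule follows my reading of the CyCTR paper. Multi-head splitting and the Q/K/V projections are omitted, and the function names are hypothetical. Barring exact ties, the support token holding the largest affinity is always consistent with itself, so each softmax row keeps at least one finite entry.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cycle_bias(affinity, support_labels):
    """affinity: (Nq, Ns) query-support affinities; support_labels: (Ns,) mask labels."""
    i_star = affinity.argmax(axis=0)          # best query token for each support token
    j_star = affinity[i_star].argmax(axis=1)  # best support token going back from i*
    consistent = support_labels == support_labels[j_star]
    return np.where(consistent, 0.0, -np.inf)

def cyc_attention(q, k, v, support_labels):
    """softmax(A + B) V with A = q k^T / sqrt(d); q: (Nq, d), k: (Ns, d), v: (Ns, dv)."""
    a = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(a + cycle_bias(a, support_labels)[None, :], axis=-1) @ v
```

Usage: support token 2 below looks like the foreground token 0 but is labeled background, so it gets $B = -\infty$ and receives zero attention weight from every query token.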
- Conv-head:
- Output of CyC-MHA => reshaping to spatial dimensions => Conv-Head (Conv3x3 => ReLU => Conv1x1) => Segmentation Mask
- Code: To be updated
- Paper: https://arxiv.org/pdf/2108.03032.pdf
- Architecture:
- First stage: pre-train encoder/decoder (PSPNet pre-trained on ImageNet) with supervised learning => stronger representation
- Support/Query Image => Encoder-Decoder (PSPNet) => Linear classifier with Support Mask (Support only) => Classifier (Support only) => [Classifier weight $Q$, Query feature $K$, Query feature $V$]
- Second stage: meta-train only the Classifier Weight Transformer (CWT) (as the encoder-decoder is capable of generalizing to unseen classes)
- {Q, K, V} => Skip[Linear => MHSA => Norm] => Conv operation with $Q$ => Prediction Mask
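A single-head NumPy sketch of the second stage. The support-learned classifier weights act as the attention query over the flattened query-image features (key and value). A residual connection and a normalization produce the adapted weights, which then act as a 1x1 convolution on the query features. The single head, the affine-free LayerNorm, and all names and shapes are simplifying assumptions, not the CWT code.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adapt_classifier(w_cls, query_feat, wq, wk, wv):
    """w_cls: (n_cls, d) classifier weights (attention query).
    query_feat: (HW, d) flattened query-image features (attention key and value).
    Returns Norm(w_cls + Attention(...)): the residual "Skip" plus normalization.
    """
    q, k, v = w_cls @ wq, query_feat @ wk, query_feat @ wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # (n_cls, HW)
    return layer_norm(w_cls + attn @ v)

def predict_mask(w_adapted, query_feat, h, w):
    """1x1 'conv' of the query features with the adapted weights -> per-pixel labels."""
    logits = query_feat @ w_adapted.T  # (HW, n_cls)
    return logits.argmax(axis=-1).reshape(h, w)
```

Only the small attention block is meta-trained here, which fits the notes' point that the frozen encoder-decoder already generalizes to unseen classes.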
- Code: https://github.com/zhiheLu/CWT-for-FSS
- Paper: https://openaccess.thecvf.com/content/CVPR2021/papers/Yang_Few-Shot_Transformation_of_Common_Actions_Into_Time_and_Space_CVPR_2021_paper.pdf
- The goal is to localize the spatio-temporal tubelet of an action in an untrimmed query video based on the common action in the trimmed support videos
- Architecture:
- Pipeline:
- Untrimmed query video => split into clips + a few support videos => Video feature extractor (I3D) => spatio-temporal representation => Common Attention block => aligned with previous clip features => query clip feature
- Query clip feature => Few-shot Transformer (FST) => fuse the support features into the query clip feature => aggregate with the input embedding
- On top of the output embedding (from FST) => Prediction network => final tubelet prediction
- Video feature extractor:
- Adopt I3D as the backbone to obtain spatio-temporal representations of a single query video & a few support videos
- Support video => feed the whole video to the backbone directly
- Untrimmed query video => split into multiple clips => backbone network
- Common attention block: built on the self-attention mechanism & non-local structure => models long-term spatio-temporal information
- Formula:
$\hat{A}(I_1, I_2) = I_1 + \text{Linear}(\text{Norm}[A(I_1, I_2)])$
- where $\hat{A}$ is the common attention, $A$ is the standard attention, and $I_1$, $I_2$ are the 2 inputs
- Common attention aligns each query clip feature with its previous clip features => the aligned feature contains more motion information => benefits common action localization
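The common attention formula, $\hat{A}(I_1, I_2) = I_1 + \text{Linear}(\text{Norm}[A(I_1, I_2)])$, as a minimal NumPy sketch. Several choices here are assumptions: $I_1$ (e.g. the current clip's tokens) provides the queries, $I_2$ (e.g. the previous clip's tokens) provides both keys and values, there are no separate Q/K/V projections, and Norm is a LayerNorm without learnable scale and shift.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu, var = x.mean(axis=-1, keepdims=True), x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def attention(i1, i2):
    """Standard scaled dot-product attention A(I1, I2): I1 as queries, I2 as keys and values."""
    scores = i1 @ i2.T / np.sqrt(i1.shape[-1])
    return softmax(scores, axis=-1) @ i2

def common_attention(i1, i2, w_lin, b_lin):
    """Common attention: I1 + Linear(Norm[A(I1, I2)]), a residual connection to I1."""
    return i1 + layer_norm(attention(i1, i2)) @ w_lin + b_lin
```

Because of the residual term, a zero-initialized Linear layer makes the block an identity on $I_1$. The current clip feature is then only gradually enriched with information from the previous clip.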
- Few-shot Transformer (FST):
- Encoder: standard architecture with MHSA. The input is supplied with fixed spatio-temporal positional encoding.
- Support branch: each support video => encoder (one by one) => concat => $E_s$ => decoder (along with $E_q$)
- Query branch: enhanced query clip => encoder => $E_q$ => decoder
- The FFN can be a Conv1x1
- Decoder: 3 inputs [$E_s$, $E_q$, input embedding (learnt positional encoding)] => Common attention/MHSA => MHSA => FFN => Add&Norm => output embedding
- Prediction network: output embedding => 3-layer FFN with ReLU => linear projection => normalized center coordinates => final action tubes for the whole untrimmed query video
- Code: To be updated
- Paper: https://arxiv.org/pdf/2006.11702.pdf
- The URT layer is inspired by the standard Transformer network to effectively integrate the feature representations from the diverse set of training domains
- Uses an attention mechanism to learn to retrieve or blend the appropriate backbones to use for each task
- Training URT layer across few-shot tasks from many domains => support transfer across these tasks
- Architecture:
- $r_i(x)$ is the output vector of the backbone for domain $i$ => $r(x) = \text{concat}[r_1(x), \dots, r_m(x)]$ over the $m$ pre-trained backbones
- Representation of the support set of class $c$: $r(S_c) = \frac{1}{|S_c|}\sum_{x \in S_c} r(x)$
- For each class $c$, query $Q_c = W_q\, r(S_c) + b_q$ (one linear layer shared by all classes); for each domain $i$ and class $c$, key $K_{i,c} = W_k\, r_i(S_c) + b_k$
- Then, with $Q_c$ and $K_{i,c}$, the scaled dot-product attention $A_{i,c}$ is calculated as usual (softmax over the domains $i$)
- The adapted representation for the support & query set examples is computed by $O(x) = \sum_i A_i\, r_i(x)$, where $A_i$ is $A_{i,c}$ averaged over the classes $c$
- Finally, the multi-head URT output is the concatenation of the $O(x)$ produced by each head, as usual.
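One URT head, sketched in NumPy: class prototypes per domain, queries from the concatenated prototype, keys from each single-domain prototype, a softmax over domains, and an average over classes. The weighted sum then gives $O(x)$. Array layouts, weight shapes and function names are assumptions for illustration. Running several heads with separate weights and concatenating their outputs would give the multi-head layer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def urt_head(reps, labels, wq, bq, wk, bk):
    """reps: (m, N, d) support representations r_i(x) from m domain backbones.
    labels: (N,) support class ids.  wq: (m*d, l), wk: (d, l).
    Returns the (m,) backbone weights A_i.
    """
    m, _, d = reps.shape
    classes = np.unique(labels)
    # r_i(S_c): per-domain class means -> (m, C, d)
    protos = np.stack([reps[:, labels == c].mean(axis=1) for c in classes], axis=1)
    r_sc = protos.transpose(1, 0, 2).reshape(len(classes), m * d)  # r(S_c), concat over domains
    q = r_sc @ wq + bq                                              # (C, l)
    k = protos @ wk + bk                                            # (m, C, l)
    scores = np.einsum("cl,icl->ic", q, k) / np.sqrt(q.shape[-1])  # (m, C)
    a_ic = softmax(scores, axis=0)  # normalize over domains for each class
    return a_ic.mean(axis=1)        # A_i: average over classes

def urt_output(reps_x, a):
    """O(x) = sum_i A_i r_i(x); reps_x: (m, d) or (m, N, d)."""
    return np.tensordot(a, reps_x, axes=1)
```

The weights $A_i$ are computed once per task from the support set. They are then applied to every support and query example, so each task gets its own blend of backbones.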
- Code: https://github.com/liulu112601/URT
- Paper collections about Transformer in Computer Vision:
- There are plenty of ViT-based models and variants in this repository, in PyTorch:
- Paper collections about improving the Attention block (computational complexity):
These notes were created by quanghuy0497@2021