Official PyTorch Implementation
Avi Ben-Cohen, Nadav Zamir, Emanuel Ben Baruch, Itamar Friedman, Lihi Zelnik-Manor
DAMO Academy, Alibaba Group
Abstract
Training a neural network model for recognizing multiple labels associated with an image, including identifying unseen labels, is challenging, especially for images that portray numerous semantically diverse labels. As challenging as this task is, it is an essential task to tackle since it represents many real-world cases, such as image retrieval of natural images. We argue that using a single embedding vector to represent an image, as commonly practiced, is not sufficient to rank both relevant seen and unseen labels accurately. This study introduces an end-to-end model training for multi-label zero-shot learning that supports semantic diversity of the images and labels. We propose to use an embedding matrix having principal embedding vectors trained using a tailored loss function. In addition, during training, we suggest up-weighting in the loss function image samples presenting higher semantic diversity to encourage the diversity of the embedding matrix. Extensive experiments show that our proposed method improves the zero-shot model’s quality in tag-based image retrieval achieving SoTA results on several common datasets (NUS-Wide, COCO, Open Images).
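To make the embedding-matrix idea more concrete, here is a minimal pure-Python sketch of how several embedding vectors per image could be used to rank labels. The scoring rule (dot product with each label word vector, then a max over the embedding vectors) and the names `label_scores`, `image_embeddings`, and `word_vectors` are assumptions made for illustration; they are not taken from this repository's code.

```python
def label_scores(embedding_matrix, word_vectors):
    """Score each label by its best-matching embedding vector (assumed rule)."""
    scores = {}
    for label, w in word_vectors.items():
        # Dot product of the label word vector with every row of the matrix.
        per_vector = [sum(a * b for a, b in zip(row, w)) for row in embedding_matrix]
        # Assumption: a label is as relevant as its closest embedding vector.
        scores[label] = max(per_vector)
    return scores


# Toy data (hypothetical): two embedding vectors for one image, three labels.
image_embeddings = [[1.0, 0.0], [0.0, 1.0]]
word_vectors = {"sea": [0.9, 0.1], "car": [0.1, 0.8], "cat": [-0.5, -0.5]}

scores = label_scores(image_embeddings, word_vectors)
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)
```

In this toy case, "sea" and "car" point in nearly orthogonal directions, so a single embedding vector could not align well with both, while two vectors can each cover one of them.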
In this PyTorch file, we provide an implementation of our semantic diversity learning (SDL) loss for zero-shot multi-label classification.
We provide a pre-trained model on the NUS-WIDE dataset, which can be found here
We provide inference code that demonstrates how to load our model, pre-process an image, and run inference. Example run:
python infer.py \
--model_path=./models_local/NUS_mtresnet_224.pth \
--model_name=tresnet_m \
--pic_path=./pics/140016_215548610_422b79b4d7_m.jpg \
--top_k=10
The script outputs the predicted tags for the image. Note that predicted "unseen" tags are indicated by * tag-name *.
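The top-k selection and the unseen-tag marking described above can be sketched as follows. This is only an illustration: the function `format_top_k` and the example scores are hypothetical and are not taken from `infer.py`.

```python
def format_top_k(scores, unseen_tags, top_k=10):
    """Return the top_k tag names by score, marking unseen tags as '* name *'."""
    # Sort (tag, score) pairs from highest to lowest score.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    formatted = []
    for name, _score in ranked[:top_k]:
        formatted.append(f"* {name} *" if name in unseen_tags else name)
    return formatted


# Hypothetical scores for one image; "beach" plays the role of an unseen tag.
example_scores = {"sky": 0.91, "clouds": 0.85, "beach": 0.40, "dog": 0.05}
example_unseen = {"beach"}

print(format_top_k(example_scores, example_unseen, top_k=3))
# prints ['sky', 'clouds', '* beach *']
```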
We provide training code that can be used to train our model.
- The implementation in the provided training script is based on the ASL repository.
- The annotations should be provided in COCO format.
- To reproduce results similar to those reported in our paper on COCO, use the split provided in: Zero-Shot Object Detection.
- The annotation files are expected to be in the metadata path under the "zs_split" folder.
- wordvec_array.pickle and cls_ids.pickle include the COCO word vectors and the seen-unseen class ids respectively, and should be located in the metadata path.
- The pretrained ImageNet-based backbone can be downloaded here
- Run training with the following arguments:
python train.py \
--data=./data/COCO/ \
--model-path=./models/tresnet_m.pth \
--image-size=608 \
--pretrain-backbone=1
Note: we use a higher resolution because we compare against object-detection-based methods that use a similar or larger input size.
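Before launching training, it can help to check that the metadata files listed above are in place. The sketch below only tests for the file and folder names mentioned in this README; the helper name `missing_metadata` is hypothetical, and the location of the metadata path depends on your setup.

```python
from pathlib import Path

# Names taken from the list above; everything else here is illustrative.
REQUIRED_FILES = ("wordvec_array.pickle", "cls_ids.pickle")
REQUIRED_DIRS = ("zs_split",)


def missing_metadata(metadata_dir):
    """Return the required metadata entries that are absent from metadata_dir."""
    root = Path(metadata_dir)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    missing += [name for name in REQUIRED_DIRS if not (root / name).is_dir()]
    return missing
```

Calling `missing_metadata(path)` with your metadata path returns an empty list when all three entries are present, and the names of the absent ones otherwise.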
@misc{bencohen2021semantic,
title={Semantic Diversity Learning for Zero-Shot Multi-label Classification},
author={Avi Ben-Cohen and Nadav Zamir and Emanuel Ben Baruch and Itamar Friedman and Lihi Zelnik-Manor},
year={2021},
eprint={2105.05926},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Several images from the NUS-WIDE dataset are used in this project. Some components of this code implementation are adapted from the repository https://github.com/Alibaba-MIIL/ASL. We would like to thank Tal Ridnik for his valuable comments and suggestions.