SeqCSG

license arxiv badge ICANN Pytorch

This repository contains code for:

Model Architecture

Dependencies

  • Python 3
  • PyTorch >= 1.7.1
  • Transformers >= 4.19.2
  • NumPy
  • All experiments were run on a single RTX 3090 Ti GPU.
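The version requirements above can be checked before training. The following is a minimal sketch, not part of the repository: the helper names (`parse_version`, `meets_minimum`, `check_dependencies`) are hypothetical, and it assumes the packages are installed under their usual distribution names `torch`, `transformers`, and `numpy`.

```python
import re
from importlib import metadata

# Minimum versions taken from the list above; None means "any version".
MINIMUMS = {"torch": "1.7.1", "transformers": "4.19.2", "numpy": None}


def parse_version(text):
    """Return the leading numeric components, e.g. '1.7.1+cu110' -> (1, 7, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if the installed version is at least the minimum (or no minimum is set)."""
    if minimum is None:
        return True
    return parse_version(installed) >= parse_version(minimum)


def check_dependencies(minimums=MINIMUMS):
    """Map each package name to 'ok', 'missing', or 'too old (<version>)'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if meets_minimum(installed, minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_dependencies().items():
        print(f"{name}: {status}")
```

The comparison only looks at the numeric prefix of each version string, so local build tags such as `+cu110` are ignored.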

Prerequisites

  • Dataset: Download the Twitter2015 and Twitter2017 datasets by following the instructions in the TomBERT repo, and place them in ./data/.
  • Cache: We use Scene-Graph-Benchmark to parse scene graphs from the images. Place the resulting imageid2triple.json and sub_images in ./cache/.
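Before launching a run, it can help to confirm that the files named above are where the code expects them. This is a small sketch under stated assumptions: the function `missing_prerequisites` is hypothetical, the paths are taken from the two items above, and the check only tests for existence (it does not validate the contents of imageid2triple.json or whether sub_images is a directory).

```python
from pathlib import Path

# Paths relative to the repository root, as described in the prerequisites.
EXPECTED_PATHS = (
    "data",
    "cache/imageid2triple.json",
    "cache/sub_images",
)


def missing_prerequisites(root="."):
    """Return the expected paths that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED_PATHS if not (base / rel).exists()]


if __name__ == "__main__":
    missing = missing_prerequisites()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All prerequisite paths found.")
```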

Code Structures

The code is organized into four parts.

  • model: The main files for the SeqCSG network.
  • data: The data splits for the different datasets.
  • cache: Cached files, such as the parsed scene graphs.
  • script: The training scripts for SeqCSG.

Train & Eval

The training script for Twitter2015:

bash scripts/run_2015.sh
[--dataset {twitter2015, twitter2017}] [--EPOCHS epoch] [--BATCH_SIZE batch_size] [--RANDOM_SEEDS seeds]
[--LEARNING_RATE learning_rate] [--image_feature image_encoder] [--triple_number numbers]

Note:

  • You can edit the .sh file to change these parameters.
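For scripted runs, the invocation above can be assembled programmatically. The sketch below only builds and prints the command line; it assumes the script accepts the options in the `--NAME value` form shown in the usage text, which this README does not confirm (the note above suggests editing the .sh file instead). The helper `build_command` and the example values are hypothetical placeholders, not recommended settings.

```python
import shlex


def build_command(script="scripts/run_2015.sh", **options):
    """Render the options as '--NAME value' pairs after the script path."""
    argv = ["bash", script]
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    return argv


if __name__ == "__main__":
    # Placeholder values for illustration only.
    argv = build_command(dataset="twitter2015", EPOCHS=10, BATCH_SIZE=16)
    print(shlex.join(argv))
```

Keeping the command as a list makes it straightforward to pass to `subprocess.run` once the script's actual argument handling has been checked.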