
Few-Shot Graph Classification

In this Deep Learning and Applied AI project, I explore graph classification when there is not enough training data to build an accurate model with standard deep learning techniques. More precisely, I give a brief description of Few-Shot Learning (FSL) and Meta-Learning (ML), and then present two approaches to few-shot graph classification. First, a Meta-Learning framework based on Fast Weight Adaptation and MAML (Model-Agnostic Meta-Learning), taken from the paper Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification (Ning Ma et al.). Second, I compare it with different GDA (graph data augmentation) techniques, taken from the paper Graph Data Augmentation for Graph Machine Learning: A Survey (Tong Zhao et al.), which are used to enrich the dataset for the novel classes (i.e., those with the least data).


1. Introduction

Few-Shot Learning

Most graph classification methods overlook the fact that labeled graphs are scarce in many situations. Few-Shot Learning has started being used to overcome this problem. It is a type of machine learning in which the training data contains limited information. The usual practice is to feed a model as much data as possible, since this leads to better predictions. FSL instead aims to build accurate models from less training data, reducing the cost of gathering and labeling huge amounts of data. In this project I concentrate on few-shot graph classification.

What is the idea behind few-shot classification? Suppose we have train, validation and test sets, where the train/validation sets and the test set do not share any label. To build a task, we first sample $N$ classes from those of the train/validation set and then, for each class, sample $K + Q$ graphs, for a total of $N \times (K + Q)$ graphs. The first $K$ graphs of each class ($N \times K$ in total) form the support set, while the remaining $Q$ per class ($N \times Q$ in total) form the query set. Given the labeled support data, the goal is to predict the labels of the query data. Note that within a single task, support and query data share the same class space. This setting is called N-way K-shot learning. At test time, when performing classification on unseen classes, we first fine-tune the learner on the support data of the test classes, then report classification performance on the test query set. You can find more about FSL in A Survey on Few-Shot Learning (YAQING WANG et al.).
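The episode construction described above can be sketched in a few lines. This is a minimal illustration, not the project's `src/data/sampler.py`: the function name `sample_episode` and the input format (a plain list where `labels[i]` is the class of graph `i`) are assumptions made for the example.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, q_query, rng=random):
    """Sample one N-way K-shot episode (hypothetical helper).

    labels: list where labels[i] is the class of graph i.
    Returns (support, query) as lists of (graph_index, episode_label) pairs,
    where episode_label is the class re-indexed to 0..n_way-1.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with at least K + Q graphs can take part in an episode
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + q_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + q_query)
        # First K graphs of the class go to the support set, the other Q to the query set
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


rng = random.Random(0)
labels = [i % 5 for i in range(200)]  # toy data: 5 classes, 40 graphs each
support, query = sample_episode(labels, n_way=3, k_shot=10, q_query=15, rng=rng)
print(len(support), len(query))  # 30 45, i.e. 3 * (10 + 15) = 75 graphs
```

Re-indexing the sampled classes to `0..N-1` lets the same N-output classifier head be reused across episodes whose real class identities differ.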

Meta-Learning

Humans learn really quickly from few examples, but what about computers? We can easily classify real-world objects after seeing very few examples, whereas current deep learning methods need a huge amount of data to build a precise model. Moreover, what if the test set contains classes that are absent from the training set? Or what if we want to test the model on a completely different task? Meta-Learning offers solutions to these situations. It is also known as learning to learn: the goal is not only to learn a model that correctly classifies already seen samples, but also one that quickly adapts to new classes and/or tasks from few samples. One of the most famous meta-learners is MAML.
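To make the two nested loops of MAML concrete, here is a toy sketch on 1-D linear regression rather than on graphs (this is not the AS-MAML code of the project). Each task is `y = a * x` with its own slope `a`; the inner loop takes one gradient step on the support set, and the outer loop updates the shared initialisation `w` with the gradient of the query loss taken *through* that inner step. The constants `ALPHA`, `BETA`, the task distribution and all function names are assumptions chosen for illustration.

```python
import random

ALPHA = 0.1   # inner-loop (task adaptation) learning rate
BETA = 0.01   # outer-loop (meta) learning rate


def make_task(rng, n=10):
    """A toy regression task y = a * x with a task-specific slope a."""
    a = rng.uniform(1.0, 3.0)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(2 * n)]
    data = [(x, a * x) for x in xs]
    return data[:n], data[n:]  # (support, query)


def loss_and_grad(w, data):
    """Mean squared error of y_hat = w * x and its derivative w.r.t. w."""
    n = len(data)
    loss = sum((w * x - y) ** 2 for x, y in data) / n
    grad = sum(2 * x * (w * x - y) for x, y in data) / n
    return loss, grad


def adapt(w, support):
    """Inner loop: one gradient step on the support set (fast weight)."""
    _, g = loss_and_grad(w, support)
    return w - ALPHA * g


def meta_gradient(w, support, query):
    """Exact MAML gradient of the query loss at the adapted weight w.r.t. w.

    For this scalar model d(w_adapted)/dw = 1 - ALPHA * mean(2 x^2), so the
    second-order term of MAML can be written in closed form.
    """
    w_adapted = adapt(w, support)
    _, g_query = loss_and_grad(w_adapted, query)
    hessian = sum(2 * x * x for x, _ in support) / len(support)
    return g_query * (1 - ALPHA * hessian)


rng = random.Random(0)
w = 0.0  # meta-learned initialisation shared by all tasks
for _ in range(3000):
    tasks = [make_task(rng) for _ in range(4)]  # meta-batch of tasks
    w -= BETA * sum(meta_gradient(w, s, q) for s, q in tasks) / len(tasks)

# On a new task, one adaptation step from the meta-learned w lowers the query loss
support, query = make_task(rng)
before, _ = loss_and_grad(w, query)
after, _ = loss_and_grad(adapt(w, support), query)
print(round(w, 2), after < before)
```

With slopes drawn uniformly from [1, 3], the learned initialisation settles near the mean slope, which is the point from which a single inner step helps most tasks on average.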


2. Used Datasets

I decided to use the same datasets considered in the AS-MAML paper: TRIANGLES, COIL-DEL, and Letter-High. All of them can be downloaded directly from this page, which is their original source. Each download is a ZIP file containing:

  • <dataname>_node_attributes.txt with the attribute vector for each node of each graph
  • <dataname>_graph_labels.txt with the class for each graph
  • <dataname>_graph_edges.txt with the edges for each graph expressed as a pair (node x, node y)
  • <dataname>_graph_indicator.txt that maps each node to its corresponding graph
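The files listed above describe all graphs with global, 1-based node ids, so they have to be regrouped per graph. The sketch below assumes one integer per line in the indicator and label files and one comma-separated `x, y` pair per line in the edge file; the function name and the output record layout (`label`, `num_nodes`, `edges`) are hypothetical and not necessarily what the project's `dataset.py` produces.

```python
from collections import defaultdict


def parse_tu_graphs(indicator_lines, label_lines, edge_lines):
    """Group global node ids and edges into per-graph records (hypothetical layout).

    indicator_lines[i] holds the graph id of (1-based) node i + 1,
    label_lines[g] holds the class of (1-based) graph g + 1,
    edge_lines are "x, y" pairs of 1-based global node ids.
    """
    node_to_graph = [int(line) for line in indicator_lines]
    nodes_per_graph = defaultdict(list)
    for node_id, graph_id in enumerate(node_to_graph, start=1):
        nodes_per_graph[graph_id].append(node_id)

    # Map each global node id to a 0-based index inside its own graph
    local_index = {}
    for node_ids in nodes_per_graph.values():
        for local, node_id in enumerate(node_ids):
            local_index[node_id] = local

    edges_per_graph = defaultdict(list)
    for line in edge_lines:
        x, y = (int(tok) for tok in line.split(","))
        graph_id = node_to_graph[x - 1]
        edges_per_graph[graph_id].append((local_index[x], local_index[y]))

    return [
        {"label": int(label_lines[g - 1]),
         "num_nodes": len(nodes_per_graph[g]),
         "edges": edges_per_graph[g]}
        for g in sorted(nodes_per_graph)
    ]


# Toy input: graph 1 is a path on nodes 1-2-3, graph 2 is a single edge 4-5
graphs = parse_tu_graphs(["1", "1", "1", "2", "2"], ["0", "1"], ["1, 2", "2, 3", "4, 5"])
print(graphs)
```

With real files, each argument would come from something like `open("<dataname>_graph_indicator.txt").read().splitlines()`.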

Each dataset has been split into train, test and validation sets, transformed into Python dictionaries and finally saved as .pickle files, so that it is ready to use. Moreover, each ZIP dataset contains four files:

  • <dataname>_node_attributes.pickle with the node attributes saved as a List or a torch Tensor
  • <dataname>_train_set.pickle with all the train data as python dictionaries
  • <dataname>_test_set.pickle with all the test data as python dictionaries
  • <dataname>_val_set.pickle with all the validation data as python dictionaries
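Since few-shot splits must not share classes, one way to produce such train/validation/test dictionaries is to split by class rather than by graph and then pickle each part. This is only a sketch: the dictionary layout `{class_label: [graph, ...]}`, the helper name `split_by_class` and the class counts are assumptions, not the documented structure of the released `.pickle` files.

```python
import pickle
import random
from collections import defaultdict


def split_by_class(graphs, n_train, n_val, seed=0):
    """Split graphs so that train, val and test have disjoint class sets.

    graphs: list of dicts with at least a "label" key (hypothetical layout).
    Returns three dicts mapping class label -> list of graphs.
    """
    by_class = defaultdict(list)
    for g in graphs:
        by_class[g["label"]].append(g)
    classes = sorted(by_class)
    random.Random(seed).shuffle(classes)
    groups = (classes[:n_train],
              classes[n_train:n_train + n_val],
              classes[n_train + n_val:])
    return tuple({c: by_class[c] for c in group} for group in groups)


graphs = [{"label": i % 6, "num_nodes": 3, "edges": [(0, 1), (1, 2)]} for i in range(60)]
train_set, val_set, test_set = split_by_class(graphs, n_train=3, n_val=1)

# Serialise and reload one split; on disk this would be pickle.dump / pickle.load
# on a file opened in "wb" / "rb" mode, e.g. "<dataname>_train_set.pickle".
restored = pickle.loads(pickle.dumps(train_set))
print(sorted(train_set), sorted(val_set), sorted(test_set))
```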

These are the links from which you can directly download the datasets: TRIANGLES, COIL-DEL and Letter-High.

These are the statistics of the three datasets:

| Dataset | Graphs | Avg. nodes | Avg. edges |
|---|---|---|---|
| TRIANGLES | 45000 | 28.85 | 35.50 |
| Letter-High | 2250 | 4.67 | 4.50 |
| COIL-DEL | 3900 | 21.54 | 54.22 |
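The three columns of the statistics table are simple aggregates over the graphs of a dataset. A small sketch, assuming each graph is a record with a `num_nodes` count and an `edges` list in which every undirected edge appears exactly once (if edges are stored in both directions, the edge average doubles):

```python
def dataset_statistics(graphs):
    """Return (number of graphs, average nodes, average edges), rounded to 2 decimals.

    Each record is assumed to have "num_nodes" and an "edges" list in which
    every undirected edge appears exactly once.
    """
    n_graphs = len(graphs)
    avg_nodes = sum(g["num_nodes"] for g in graphs) / n_graphs
    avg_edges = sum(len(g["edges"]) for g in graphs) / n_graphs
    return n_graphs, round(avg_nodes, 2), round(avg_edges, 2)


toy = [
    {"num_nodes": 3, "edges": [(0, 1), (1, 2)]},                  # path
    {"num_nodes": 4, "edges": [(0, 1), (1, 2), (2, 3), (3, 0)]},  # 4-cycle
]
print(dataset_statistics(toy))  # (2, 3.5, 3.0)
```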

3. Project structure

In this section I'm going to describe the structure of this project.

.
├── data                       # Contains the datasets (TRIANGLES, COIL-DEL, R52 and Letter-High)
├── models                     # Contains pre-trained models for each of the different tests done
├── src                        # Source files of the project
│   ├── algorithms             # Contains all the algorithms used in the project
│   │   ├── asmaml             # Contains code for the AS-MAML
│   │   │   ├── __init__.py    
│   │   │   ├── README.md       
│   │   │   └── asmaml.py      
│   │   ├── mevolve            # Contains code for M-Evolve
│   │   │   ├── __init__.py
│   │   │   ├── README.md
│   │   │   └── mevolve.py     
│   │   ├── flag               # Contains code for FLAG
│   │   │   ├── __init__.py
│   │   │   ├── README.md
│   │   │   └── flag.py
│   │   ├── gmixup             # Contains code for G-Mixup
│   │   │   ├── __init__.py
│   │   │   ├── README.md
│   │   │   └── gmixup.py
│   ├── data                   # Contains code for dataset, dataloader and sampler
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── sampler.py
│   │   └── dataloader.py
│   ├── models                 # Contains various convolutional layers and models
│   │   ├── __init__.py
│   │   ├── conv.py
│   │   ├── gcn4maml.py
│   │   ├── linear.py
│   │   ├── nis.py
│   │   ├── pool.py
│   │   ├── sage4maml.py
│   │   ├── stopcontrol.py
│   │   └── utils.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── kfold.py
│   │   ├── testers.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── config.py
│   └── main.py
└── README.md

4. Installation and usage

To run the project you will need to install all the dependencies, so the suggested procedure is to create a virtual environment first, with the command python -m venv <venv_name>, and then install the required libraries:

  • torch==1.12.1 or torch==1.12.1+cu116 (or other versions of CUDA) (more info)
  • torch-geometric (more info)
  • numpy (latest)
  • matplotlib (latest)
  • networkx (latest)
  • scikit-learn (latest; imported as sklearn)
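Before running the project, the installed versions can be checked from Python with the standard `importlib.metadata` module. This is a convenience sketch, not part of the repository; it uses distribution names (scikit-learn rather than sklearn) and treats a CUDA build such as `1.12.1+cu116` as matching `1.12.1`.

```python
from importlib import metadata

# Distribution names (not import names); None means "any version is fine".
REQUIREMENTS = {
    "torch": "1.12.1",
    "torch_geometric": None,
    "numpy": None,
    "matplotlib": None,
    "networkx": None,
    "scikit-learn": None,
}


def check_requirements(requirements):
    """Return a list of human-readable problems (empty if everything is fine)."""
    problems = []
    for dist, wanted in requirements.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist}: not installed")
            continue
        # "1.12.1+cu116" still counts as 1.12.1: compare the part before "+"
        if wanted is not None and installed.split("+")[0] != wanted:
            problems.append(f"{dist}: found {installed}, expected {wanted}")
    return problems


for problem in check_requirements(REQUIREMENTS):
    print(problem)
```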

Then you can type python main.py --help (or python main.py -h) to obtain the following output, which lists the options used to configure and run the project:

usage: main.py [-h] [-p PATH] [-n NAME] [-d DEVICE] [-l LOG_PATH] [-f] [-s SAVE_PATH] [-m MODEL] [--not-as-maml] [--gmixup] [--flag] [--mevolve]
               [--batch-size BATCH_SIZE] [--outer_lr OUTER_LR] [--inner_lr INNER_LR] [--stop_lr STOP_LR] [--w-decay W_DECAY] [--max-step MAX_STEP]
               [--min-step MIN_STEP] [--penalty PENALTY] [--train-shot TRAIN_SHOT] [--val-shot VAL_SHOT] [--train-query TRAIN_QUERY]
               [--val-query VAL_QUERY] [--train-way TRAIN_WAY] [--test-way TEST_WAY] [--val-episode VAL_EPISODE] [--train-episode TRAIN_EPISODE]
               [--batch-episode BATCH_EPISODE] [--epochs EPOCHS] [--patience PATIENCE] [--grad-clip GRAD_CLIP] [--scis SCIS] [--schs SCHS] [--beta BETA]
               [--n-fold N_FOLD] [--n-xval N_XVAL] [--iters ITERS] [--heuristic HEURISTIC] [--lrts LRTS] [--lrtb LRTB] [--flag-m FLAG_M] [--ass ASS]

options:
  -h, --help            show this help message and exit
  -p PATH, --path PATH  The path of the dataset (default: /home/fscg/app/data)
  -n NAME, --name NAME  The name of the dataset (default: COIL-DEL)
  -d DEVICE, --device DEVICE
                        The device to use (default: cpu)
  -l LOG_PATH, --log-path LOG_PATH
                        The path where to log (default: None)
  -f, --file-log        If logging to file or not (default: False)
  -s SAVE_PATH, --save-path SAVE_PATH
                        The path where to save pre-trained models (default: /home/fscg/app/models)
  -m MODEL, --model MODEL
                        The name of the model (sage or gcn) (default: sage)
  --not-as-maml         Use AS-MAML or not (default: True)
  --gmixup              Use G-Mixup or not (default: False)
  --flag                Use FLAG or not (default: False)
  --mevolve             Use M-Evolve or not (default: False)
  --batch-size BATCH_SIZE
                        Dimension of a batch (default: 1)
  --outer_lr OUTER_LR   Initial LR for the model (default: 0.001)
  --inner_lr INNER_LR   Initial LR for the meta model (default: 0.01)
  --stop_lr STOP_LR     Initial LR for the Stop model (default: 0.0001)
  --w-decay W_DECAY     The Weight Decay for optimizer (default: 1e-05)
  --max-step MAX_STEP   The Max Step of the meta model (default: 15)
  --min-step MIN_STEP   The Min Step of the meta model (default: 5)
  --penalty PENALTY     Step Penality for the RL model (default: 0.001)
  --train-shot TRAIN_SHOT
                        The number of Shot per Training (default: 10)
  --val-shot VAL_SHOT   The number of shot per Validation (default: 10)
  --train-query TRAIN_QUERY
                        The number of query per Training (default: 15)
  --val-query VAL_QUERY
                        The number of query per Validation (default: 15)
  --train-way TRAIN_WAY
                        The number of way for Training (default: 3)
  --test-way TEST_WAY   The number of way for Test and Val (default: 3)
  --val-episode VAL_EPISODE
                        The number of episode for Val (default: 200)
  --train-episode TRAIN_EPISODE
                        The number of episode for Training (default: 200)
  --batch-episode BATCH_EPISODE
                        The number of batch per episode (default: 5)
  --epochs EPOCHS       The total number of epochs (default: 500)
  --patience PATIENCE   The patience (default: 35)
  --grad-clip GRAD_CLIP
                        The clipping for the gradient (default: 5)
  --scis SCIS           The input dimension for the Stop Control model (default: 2)
  --schs SCHS           The hidden dimension for the Stop Control model (default: 20)
  --beta BETA           The beta used in heuristics of M-Evolve (default: 0.15)
  --n-fold N_FOLD       The number of Fold for the nX-fol-validation (default: 5)
  --n-xval N_XVAL       Number of Cross-fold Validation to run (default: 10)
  --iters ITERS         Number of iterations of M-Evolve (default: 5)
  --heuristic HEURISTIC
                        The Heuristic to use (default: random_mapping)
  --lrts LRTS           The label reliability step thresholds (default: 1000)
  --lrtb LRTB           The beta used for approximation of the tanh (default: 30)
  --flag-m FLAG_M       The number of iterations of FLAG (default: 3)
  --ass ASS             The attack step size (default: 0.008)
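To show how a few of these options fit together, the sketch below re-declares a subset of them with the standard `argparse` module, using the names and defaults from the help text above. It is a hypothetical reconstruction, not the project's actual parser (for instance, the `dest="as_maml"` name is an assumption); it also computes the number of graphs per training episode from the way/shot/query options.

```python
import argparse


def build_parser():
    """Re-declare a handful of the options listed in the help text (hypothetical)."""
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("-n", "--name", default="COIL-DEL", help="The name of the dataset")
    parser.add_argument("-d", "--device", default="cpu", help="The device to use")
    parser.add_argument("--not-as-maml", dest="as_maml", action="store_false",
                        help="Use AS-MAML or not")
    parser.add_argument("--train-way", type=int, default=3)
    parser.add_argument("--train-shot", type=int, default=10)
    parser.add_argument("--train-query", type=int, default=15)
    parser.add_argument("--epochs", type=int, default=500)
    return parser


args = build_parser().parse_args(["--name", "TRIANGLES", "--epochs", "200"])
# Graphs per training episode: N-way x (K support + Q query)
episode_size = args.train_way * (args.train_shot + args.train_query)
print(args.name, args.epochs, args.as_maml, episode_size)  # TRIANGLES 200 True 75
```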

4.1 Examples of runs

Let's assume you want to run a simple AS-MAML training with 200 epochs on the Letter-High dataset using the GPU. Then you have to run:

$ python main.py --name Letter-High --epochs 200 --device gpu

Another example: run AS-MAML plus M-Evolve with 200 training epochs on TRIANGLES, changing the number of M-Evolve iterations:

$ python main.py --name TRIANGLES --epochs 200 --mevolve --use-pretrained --iters 10

Note that M-Evolve requires a pretrained plain AS-MAML model in the models folder. Finally, to run tests, use:

$ python main.py --name TRIANGLES --test ../models/TRIANGLES_AdaptiveStepMAML_SAGE4MAML_MEvolve_bestModel.pth

4.2 Docker

Alternatively, I have already created a Docker image that can be pulled with docker pull lmriccardo/fsgc:1.0. Then run the container with docker run --rm -it lmriccardo/fsgc:1.0 and, inside it, run the same Python command given above: python main.py [--flags ...].


5. Algorithms and Models

As I said, the goal of this project is to compare different Graph Data Augmentation techniques for few-shot learning, and more precisely for few-shot classification. For this reason, all the chosen techniques target augmentation for classification, i.e., techniques that try to preserve the structural properties of the original graph. The overall idea is to generate new data for already defined labels, so that the new data does not need to be labeled, by dropping edges or nodes, changing node features, or randomly generating graphs from so-called graphons. To this end I decided to use these three GDA techniques, one per type:

  • Model-Evolution. GDA technique that uses edge dropping based on motifs similarity (paper)
  • FLAG. GDA technique that uses perturbation attacks to perturb node features (paper)
  • G-Mixup. GDA technique that uses Mixup on graphs via graphons (paper)
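The core loop of FLAG can be illustrated on a toy model. In this sketch a linear regressor `y_hat = w . x` on a single feature vector stands in for the GNN and its node features, so it is not the project's `flag.py`. The perturbation `delta` starts uniformly in `[-step_size, step_size]` and is pushed `m` times by signed gradient ascent on the loss, while the weight gradients are accumulated and averaged over the `m` steps. The parameters `m` and `step_size` play the role of the `--flag-m` and `--ass` options; the function name is hypothetical.

```python
import random


def flag_step(w, x, y, m=3, step_size=0.008, rng=random):
    """One FLAG-style step for a toy linear model y_hat = w . x.

    Returns the weight gradient averaged over the m ascent steps and the
    perturbed loss observed at each step.
    """
    d = len(x)
    delta = [rng.uniform(-step_size, step_size) for _ in range(d)]
    grad_w = [0.0] * d
    losses = []
    for _ in range(m):
        x_adv = [xi + di for xi, di in zip(x, delta)]
        residual = sum(wi * xi for wi, xi in zip(w, x_adv)) - y
        losses.append(residual ** 2)
        # Squared error: d(loss)/dw = 2 r x_adv and d(loss)/d(delta) = 2 r w
        grad_w = [g + 2 * residual * xa / m for g, xa in zip(grad_w, x_adv)]
        grad_delta = [2 * residual * wi for wi in w]
        # Adversarial update: move delta along the sign of its gradient (ascent)
        delta = [di + step_size * (1 if g > 0 else -1 if g < 0 else 0)
                 for di, g in zip(delta, grad_delta)]
    return grad_w, losses


rng = random.Random(0)
grad_w, losses = flag_step([0.5, -1.0, 2.0], [1.0, 0.3, -0.2], y=1.0, m=3, step_size=0.008, rng=rng)
print(grad_w, losses)
```

Because the squared error is convex in `delta`, each signed ascent step cannot decrease the perturbed loss, so the model is trained on progressively harder versions of the same input.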

For further information about each of the three techniques, have a look at their respective READMEs in ./src/algorithms/mevolve (for M-Evolve), ./src/algorithms/flag (for FLAG) and ./src/algorithms/gmixup (for G-Mixup).
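As an illustration of the graphon-based family (G-Mixup in the list above), the sketch below estimates one graphon per class, mixes two graphons and their one-hot labels with a coefficient `lam`, and samples a new graph from the mixed graphon. The graphon estimator here is deliberately crude (align nodes by decreasing degree, pad or truncate to a fixed resolution, average the adjacency matrices) and is an assumption standing in for the estimators used in practice; all names are hypothetical and this is not the project's `gmixup.py`.

```python
import random


def estimate_graphon(adjs, resolution):
    """Crude step-function graphon estimate for one class.

    Each adjacency matrix is aligned by sorting nodes by decreasing degree,
    zero-padded or truncated to `resolution` nodes, then averaged entry-wise.
    """
    total = [[0.0] * resolution for _ in range(resolution)]
    for adj in adjs:
        order = sorted(range(len(adj)), key=lambda i: -sum(adj[i]))[:resolution]
        for a, i in enumerate(order):
            for b, j in enumerate(order):
                total[a][b] += adj[i][j]
    return [[v / len(adjs) for v in row] for row in total]


def mix_graphons(w_i, w_j, lam):
    """Convex combination lam * W_i + (1 - lam) * W_j (same resolution)."""
    return [[lam * a + (1 - lam) * b for a, b in zip(ri, rj)] for ri, rj in zip(w_i, w_j)]


def sample_graph(graphon, n_nodes, rng=random):
    """Give each node a latent step of the graphon; edge (u, v) appears
    with probability W[step_u][step_v]. Returns a symmetric 0/1 matrix."""
    k = len(graphon)
    pos = [rng.randrange(k) for _ in range(n_nodes)]
    adj = [[0] * n_nodes for _ in range(n_nodes)]
    for u in range(n_nodes):
        for v in range(u + 1, n_nodes):
            if rng.random() < graphon[pos[u]][pos[v]]:
                adj[u][v] = adj[v][u] = 1
    return adj


rng = random.Random(0)
star = [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]   # toy graph of class 0
cycle = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]  # toy graph of class 1
w_star = estimate_graphon([star], resolution=4)
w_cycle = estimate_graphon([cycle], resolution=4)
lam = 0.5
w_mix = mix_graphons(w_star, w_cycle, lam)
new_graph = sample_graph(w_mix, n_nodes=6, rng=rng)
new_label = [lam, 1 - lam]  # soft label mixing one-hot [1, 0] and [0, 1]
print(new_graph, new_label)
```

The synthetic graph carries a soft label, so it enriches training without any additional manual labeling.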

Finally, as the baseline for comparison I use the results and performance of AS-MAML run with the SAGE (for MAML) model.


6. Results

Each dataset has been trained, validated and tested using the GraphSAGE model: 3 SAGE convolutional layers, 3 SAGPool layers and 3 final FC layers, all using the LeakyReLU activation function except the last linear layer, which uses softmax. Each model has been trained for 200 epochs, each running 200 training and 200 validation episodes, and then tested on 200 test episodes. The few-shot sampling configuration is the same for all datasets: 3 classes (way) for train, test and validation, 10 samples (shot) for the support set and 15 samples (query) for the query set, in both training and test/validation. Each episode therefore contains 3 × (10 + 15) = 75 graphs. These are the obtained results:

| Accuracy (%) | AS-MAML | +M-Evolve | +G-Mixup | +FLAG |
|---|---|---|---|---|
| TRIANGLES | 82.47 | 85.48 | 84.91 | 83.73 |
| COIL-DEL | 86.51 | 92.12 | 90.73 | 89.98 |
| Letter-High | 59.86 | 62.39 | 61.07 | 60.68 |
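The gain of each augmentation technique over the plain AS-MAML baseline can be read off the results table with a short computation. The snippet only re-uses the accuracies reported above; the dictionary layout is an arbitrary choice for the example.

```python
# Accuracies (%) copied from the results table
results = {
    "TRIANGLES":   {"AS-MAML": 82.47, "M-Evolve": 85.48, "G-Mixup": 84.91, "FLAG": 83.73},
    "COIL-DEL":    {"AS-MAML": 86.51, "M-Evolve": 92.12, "G-Mixup": 90.73, "FLAG": 89.98},
    "Letter-High": {"AS-MAML": 59.86, "M-Evolve": 62.39, "G-Mixup": 61.07, "FLAG": 60.68},
}

# Percentage-point improvement of each technique over the AS-MAML baseline
gains = {
    dataset: {method: round(acc - row["AS-MAML"], 2)
              for method, acc in row.items() if method != "AS-MAML"}
    for dataset, row in results.items()
}
for dataset, row in gains.items():
    print(dataset, row)
```

On all three datasets every technique improves on the baseline, with M-Evolve giving the largest gain (+5.61 points on COIL-DEL).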