This repository provides the implementation to evaluate Hessian-based quantities (e.g., traces, top eigenvalues, Hessian-vector products) of graph neural networks. We observe that Hessian-based measurements correlate better with the observed generalization gaps of fine-tuned GNNs.
We use the Python packages listed in `requirements.txt` for development. To install the requirements:

```bash
pip install -r requirements.txt
```
- Dataset: Please download the dataset from the link and unzip it under the `./src` folder.
- Pretrained models: Please download the pretrained models from the link and put them in the `./src/model_gin/` folder.
Our code builds on the project "Strategies for Pre-training Graph Neural Networks" by Hu et al. We thank the authors for making their implementation and dataset available online.
Use the following scripts to compute the Hessian-based measurements. We use the Hessian-vector product tools from PyHessian (Yao et al., 2020).
- `compute_hessian_spectra.py` computes the trace and the eigenvalues of the loss Hessian for each layer of a neural network (see the PyHessian sketch after this list).
- `compute_hessian_norms.py` computes the Hessian-based vector products.
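For orientation, here is a minimal, self-contained sketch of the kind of PyHessian calls these scripts rely on. The toy model and random batch below are placeholders, not the repository's exact code; the actual scripts apply this to a GIN on molecular graph batches.

```python
import numpy as np
import torch
from pyhessian import hessian  # https://github.com/amirgholami/PyHessian

# Placeholder model and batch; the scripts use a GIN on molecular graphs.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
criterion = torch.nn.BCEWithLogitsLoss()
inputs = torch.randn(64, 16)
targets = torch.randint(0, 2, (64, 1)).float()

# PyHessian never materializes the full Hessian; every quantity is estimated
# from Hessian-vector products computed via double backpropagation.
hess = hessian(model, criterion, data=(inputs, targets), cuda=False)

top_eigenvalues, _ = hess.eigenvalues(top_n=3)  # power iteration
trace_samples = hess.trace()                    # Hutchinson's stochastic estimator
print("top-3 eigenvalues:", top_eigenvalues)
print("trace estimate:", np.mean(trace_samples))
```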
Please follow the bash script examples to run the commands. Specify the `checkpoint_name` and `dataset` for computing the quantities:
```bash
python compute_hessian_spectra.py --input_model_file model_gin/supervised_contextpred.pth --split scaffold --gnn_type gin --dataset $dataset --batch_size 32 --device 0 --checkpoint_name $checkpoint_name

python compute_hessian_trace.py --input_model_file model_gin/supervised_contextpred.pth --split scaffold --gnn_type gin --dataset $dataset --batch_size 32 --device 0 --checkpoint_name $checkpoint_name
```
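Under the hood, a single Hessian-vector product needs only two passes of automatic differentiation, which is the primitive PyHessian builds on. A minimal sketch with a toy model (not the repository's exact code):

```python
import torch

# Toy setup; the scripts apply this to the fine-tuning loss of a GIN.
model = torch.nn.Linear(8, 1)
x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

params = [p for p in model.parameters() if p.requires_grad]
# First backward pass: keep the graph so the gradient itself stays differentiable.
grads = torch.autograd.grad(loss, params, create_graph=True)

# Second backward pass: d/dw <grad(w), v> equals the Hessian-vector product H v.
v = [torch.randn_like(p) for p in params]
inner = sum((g * vi).sum() for g, vi in zip(grads, v))
hvp = torch.autograd.grad(inner, params)
print([h.shape for h in hvp])
```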
We also provide an algorithm that performs gradient updates on the perturbed weights of a graph neural network. Use `finetune.py` to run fine-tuning experiments on the pretrained GNN models. Choose the dataset from sider, clintox, bace, bbbp, and tox21. Use `--nsm_sigma`, `--nsm_lam`, and `--num_perturbs` to set the hyper-parameters: sigma, lambda, and the number of perturbations. We search sigma per dataset; the example commands below use the selected values. A hypothetical sketch of one such perturbed-gradient step is given first.
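The exact update implemented in `finetune.py` may differ in its details; the following is a hypothetical sketch of one perturbed-gradient step under our reading of the flags: sample Gaussian weight perturbations with standard deviation sigma (and their negations when `--use_neg` is set), average the gradients at the perturbed weights, and mix them with the unperturbed gradient using weight lambda. The function name and the fixed learning rate are illustrative only.

```python
import torch

def noise_stability_step(model, loss_fn, batch, sigma=0.1, lam=0.6,
                         num_perturbs=1, use_neg=True, lr=1e-3):
    """Hypothetical perturbed-gradient step; not the exact finetune.py code."""
    params = [p for p in model.parameters() if p.requires_grad]

    # Gradient at the unperturbed weights.
    model.zero_grad()
    loss_fn(model, batch).backward()
    base_grads = [p.grad.detach().clone() for p in params]

    # Average gradients at Gaussian-perturbed weights (and their negations).
    avg_grads = [torch.zeros_like(p) for p in params]
    signs = [1.0, -1.0] if use_neg else [1.0]
    for _ in range(num_perturbs):
        noise = [sigma * torch.randn_like(p) for p in params]
        for s in signs:
            with torch.no_grad():
                for p, n in zip(params, noise):
                    p.add_(s * n)          # perturb the weights in place
            model.zero_grad()
            loss_fn(model, batch).backward()
            for g, p in zip(avg_grads, params):
                g.add_(p.grad.detach() / (num_perturbs * len(signs)))
            with torch.no_grad():
                for p, n in zip(params, noise):
                    p.sub_(s * n)          # restore the original weights

    # Combine the two gradient estimates with weight lambda; plain SGD step.
    with torch.no_grad():
        for p, g0, g1 in zip(params, base_grads, avg_grads):
            p.sub_(lr * ((1 - lam) * g0 + lam * g1))
```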
```bash
python finetune.py --input_model_file model_gin/supervised_masking.pth --split scaffold --gnn_type gin --dataset sider --device 0 \
    --train_nsm --nsm_sigma 0.1 --nsm_lam 0.6 --use_neg --reg_method penalty --lam_gnn 1e-4 --lam_pred 1e-4

python finetune.py --input_model_file model_gin/supervised_masking.pth --split scaffold --gnn_type gin --dataset clintox --device 0 \
    --train_nsm --nsm_sigma 0.05 --nsm_lam 0.4 --use_neg --reg_method penalty --lam_gnn 1e-4 --lam_pred 1e-4

python finetune.py --input_model_file model_gin/supervised_contextpred.pth --split scaffold --gnn_type gin --dataset bace --device 0 \
    --train_nsm --nsm_sigma 0.1 --nsm_lam 0.6 --use_neg --reg_method penalty --lam_gnn 1e-4 --lam_pred 1e-4

python finetune.py --input_model_file model_gin/supervised_contextpred.pth --split scaffold --gnn_type gin --dataset bbbp --device 0 \
    --train_nsm --nsm_sigma 0.05 --nsm_lam 0.6 --use_neg --reg_method penalty --lam_gnn 1e-4 --lam_pred 1e-4

python finetune.py --input_model_file model_gin/supervised_edgepred.pth --split scaffold --gnn_type gin --dataset tox21 --device 0 \
    --train_nsm --nsm_sigma 0.1 --nsm_lam 0.4 --use_neg --reg_method penalty --lam_gnn 1e-4 --lam_pred 1e-4
```
If you find this repository useful or use it in a research paper, please cite our work with the following bib entry:
```bibtex
@article{ju2023generalization,
  title={Generalization in Graph Neural Networks: Improved PAC-Bayesian Bounds on Graph Diffusion},
  author={Ju, Haotian and Li, Dongyue and Sharma, Aneesh and Zhang, Hongyang R},
  journal={Artificial Intelligence and Statistics},
  year={2023}
}
```
Thanks to the authors of the following repositories for making their implementations publicly available, which greatly helped us develop this code.