This is the repository for Improving Molecular Representation Learning with Metric Learning-enhanced Optimal Transport [arXiv].
We design an optimal transport (OT) based algorithm, MROT, to tackle regression tasks in molecular representation learning and
improve the generalization capability of deep learning models. This work is published in Patterns.
python >= 3.8.5
pytorch >= 1.9.0
# Install packages
pip install torch
pip install einops
pip install POT
pip install kmeans-pytorch
pip install pytorch-metric-learning
pip install mendeleev rdkit-pypi
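Before running anything, a quick check can confirm that the environment matches the requirements listed above. The script below is a minimal sketch and is not part of the repository; it assumes the usual import names for each pip package (for example, POT is imported as `ot` and rdkit-pypi as `rdkit`), and the helper names `python_ok` and `missing_packages` are hypothetical.

```python
import importlib.util
import sys

# Mapping from the pip package names used in the install commands to the module
# names they are imported under (assumed from each package's usual convention).
REQUIRED = {
    "torch": "torch",
    "einops": "einops",
    "POT": "ot",
    "kmeans-pytorch": "kmeans_pytorch",
    "pytorch-metric-learning": "pytorch_metric_learning",
    "mendeleev": "mendeleev",
    "rdkit-pypi": "rdkit",
}

MIN_PYTHON = (3, 8, 5)


def python_ok(version=sys.version_info[:3]):
    """Return True if the given (major, minor, micro) meets the minimum Python version."""
    return tuple(version) >= MIN_PYTHON


def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose top-level module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    print("Python version OK:", python_ok())
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

`find_spec` only locates the module without importing it, so the check stays fast even for heavy packages such as torch; it does not verify the installed versions.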
Awesome packages for implementing OT:
- POT (recommended)
  POT provides machine-learning-related solvers for domain adaptation (DA) problems. Moreover, you can add any sort of regularization to OT, such as the group lasso; a good example of adding generic regularization is available in its documentation.
  Specifically, you can use the conditional gradient solver through ot.optim.cg or the generalized conditional gradient solver through ot.optim.gcg. Notably, ot.sinkhorn2 returns the OT loss, while ot.sinkhorn returns the transport plan instead.
- PyTorchOT
  It supports PyTorch and is friendlier to gradient propagation.
- Sinkhorn-solver
  A convenient solver written in PyTorch; its author provides only visualizations.
- OT meets ML
  A great collection of OT resources.
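To illustrate the difference noted above between ot.sinkhorn (which returns the transport plan) and ot.sinkhorn2 (which returns the loss), here is a minimal pure-Python sketch of the Sinkhorn-Knopp iteration for entropically regularized OT. It is neither POT's implementation nor the MROT code; the function names `sinkhorn` and `sinkhorn_cost` are hypothetical, and in practice POT's solvers are faster and numerically more robust.

```python
import math


def sinkhorn(a, b, M, reg, num_iters=1000):
    """Entropic OT via Sinkhorn-Knopp scaling.

    a, b: source/target histograms (lists summing to 1)
    M: cost matrix as a list of lists with shape len(a) x len(b)
    reg: entropic regularization strength (> 0)
    Returns the transport plan, analogous to what ot.sinkhorn returns.
    """
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-M_ij / reg).
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(num_iters):
        # Rescale rows so the plan's row sums match a.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so the plan's column sums match b.
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Plan P = diag(u) K diag(v).
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


def sinkhorn_cost(a, b, M, reg, num_iters=1000):
    """Transport cost <P, M> of the Sinkhorn plan, analogous to ot.sinkhorn2."""
    P = sinkhorn(a, b, M, reg, num_iters)
    return sum(P[i][j] * M[i][j] for i in range(len(a)) for j in range(len(b)))


a, b = [0.5, 0.5], [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
print(sinkhorn(a, b, M, reg=0.05))       # close to [[0.5, 0], [0, 0.5]]
print(sinkhorn_cost(a, b, M, reg=1.0))   # equals 1 / (e + 1) for this toy case
```

With a very small `reg`, the kernel entries `exp(-M / reg)` can underflow to zero; log-domain variants (also offered by POT) avoid this.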
Awesome packages for implementing K-means clustering:
- kmeans_pytorch (recommended)
  It supports GPU-accelerated K-means with PyTorch.
- KeOps
  It offers an excellent, simple implementation of the K-means algorithm together with its visualization.
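For readers unfamiliar with the algorithm behind these packages, the following is a minimal pure-Python sketch of Lloyd's K-means iteration (alternating assignment and centroid-update steps). It is not code from kmeans_pytorch or KeOps, which run the same kind of steps on tensors and GPUs; the function name `kmeans` and its parameters are hypothetical.

```python
import random


def kmeans(points, k, num_iters=100, seed=0):
    """Lloyd's K-means on a list of equal-length tuples.

    Returns (assignments, centroids). Centroids are initialized from k points
    sampled without replacement.
    """
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    assignments = [0] * len(points)
    for _ in range(num_iters):
        # Assignment step: each point goes to its nearest centroid (squared Euclidean distance).
        new_assignments = [
            min(range(k), key=lambda c: sum((x - y) ** 2 for x, y in zip(p, centroids[c])))
            for p in points
        ]
        # Update step: move each centroid to the mean of its assigned points.
        for c in range(k):
            members = [p for p, label in zip(points, new_assignments) if label == c]
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
        # Stop once no point changes cluster.
        if new_assignments == assignments:
            break
        assignments = new_assignments
    return assignments, centroids


labels, centers = kmeans([(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)], k=2)
print(labels, centers)
```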
Awesome packages for implementing metric learning:
- pytorch-metric-learning (recommended)
  It offers straightforward, easy-to-understand APIs for setting up a triplet loss. Please make sure its version is compatible with your PyTorch version.
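The triplet loss mentioned above pulls an anchor embedding toward a positive example and pushes it away from a negative one by at least a margin. The pure-Python sketch below shows the standard form max(0, d(a, p) - d(a, n) + margin) with Euclidean distance; it is not pytorch-metric-learning's API, the margin value 0.1 is arbitrary, and the library's loss reducers may average differently (for example, only over non-zero triplets). The function names are hypothetical.

```python
import math


def euclidean(x, y):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_margin_loss(anchors, positives, negatives, margin=0.1):
    """Mean over all triplets of max(0, d(a, p) - d(a, n) + margin).

    Each argument is a list of embedding vectors; the i-th entries form one triplet.
    """
    losses = [
        max(0.0, euclidean(a, p) - euclidean(a, n) + margin)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)


# First triplet is already well separated (loss 0); the second violates the margin.
print(triplet_margin_loss([[0, 0], [0, 0]], [[0, 1], [0, 2]], [[0, 3], [0, 1]]))
```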
We evaluate our method on a wide range of benchmark datasets, including QM7/8/9, ESOL, FreeSolv, Lipophilicity, and an adsorption dataset. All datasets except the adsorption dataset are openly accessible and widely used by the AI4Science community. The adsorption dataset is newly collected; please cite our paper if you use it.
- QM7
  Download (Official Website): http://quantum-machine.org/datasets/
  Download (DeepChem, recommended): https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/qm7_datasets.py#L30-L107
  Description (DeepChem): https://deepchem.readthedocs.io/en/latest/api_reference/moleculenet.html#qm7-datasets
- QM8
  Download (DeepChem): https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/qm8_datasets.py
  Description (DeepChem): https://deepchem.readthedocs.io/en/latest/api_reference/moleculenet.html?highlight=qm7#qm8-datasets
- QM9
  Download (Official Website): https://ndownloader.figshare.com/files/3195389
  Download (Atom3D, recommended): https://www.atom3d.ai/smp.html
  Download (DeepChem): https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/qm9_datasets.py
  Download (MPNN Supplement): https://drive.google.com/file/d/0Bzn36Iqm8hZscHFJcVh5aC1mZFU/view?resourcekey=0-86oyPL3e3l2ZTiRpwtPDBg
  Download (SchNetPack): https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html#Loading-the-data
- ESOL & FreeSolv & Lipophilicity
  Download and Description (MoleculeNet, recommended): https://moleculenet.org/datasets-1
  Download (GLambard): https://github.com/GLambard/Molecules_Dataset_Collection
  Description (DeepChem): https://deepchem.readthedocs.io/en/latest/api_reference/moleculenet.html
  Description (DGL-LifeSci): https://lifesci.dgl.ai/api/data.html#tox21
- Adsorption Dataset
  Raw Dataset: Baidu Cloud, via this link (password: 8uqy)
  Post-processed Dataset: Baidu Cloud, via this link (password: 82kq)
We refer to some excellent implementations of the baselines used in our paper.
- DANN
  https://github.com/Yangyangii/DANN-pytorch
  https://github.com/fungtion/DANN_py3
  https://github.com/NaJaeMin92/pytorch_DANN
- CDAN
  https://github.com/thuml/CDAN
  https://github.com/YBZh/CDAN-re-implement
  https://github.com/agrija9/Deep-Unsupervised-Domain-Adaptation
- IRM
  IRM is not used as a baseline in our experiments because we regard the whole training set as a single domain.
  https://github.com/facebookresearch/InvariantRiskMinimization
  https://github.com/reiinakano/invariant-risk-minimization
- MLDG
  https://github.com/HAHA-DL/MLDG
  https://github.com/yogeshbalaji/Meta-Learning-Domain-Generalization/blob/master/MLDG
First, preprocess the molecular datasets into the Molformer format. Then run main.py to reproduce the results.
For instance, to run the domain adaptation task with MROT on QM7, use the following command:
python main.py --data=qm7 --method=mrot
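To run several experiments in a row, a small driver script can build the same command for each dataset. This is a hypothetical helper, not part of the repository: it uses only the --data and --method flags shown above, and any dataset value other than qm7 (such as qm8 below) is an assumption that must match what main.py actually accepts. By default it only prints the commands.

```python
import subprocess
import sys


def build_command(data, method="mrot", python="python"):
    """Build the argument list for one main.py run (only --data and --method are used)."""
    return [python, "main.py", f"--data={data}", f"--method={method}"]


def run_all(datasets, method="mrot", dry_run=True):
    """Print (dry_run=True) or execute main.py once per dataset; return the commands."""
    commands = [build_command(d, method, sys.executable) for d in datasets]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the loop if any run exits with a non-zero status.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # "qm8" is a hypothetical dataset value; only "qm7" appears in this README.
    run_all(["qm7", "qm8"])
```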
Awesome-domain-adaptation: a collection of domain adaptation papers and their corresponding code.
Proof that the Wasserstein distance is a metric: Paper; Lecture
Proof that the Gromov-Wasserstein distance is a metric: Paper; Paper
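For reference, these are the standard textbook definitions whose metric properties the proofs above establish; the notation here is generic and may differ from the MROT paper. For probability measures $\mu, \nu$ on a metric space $(\mathcal{X}, d)$, with $\Pi(\mu,\nu)$ the set of couplings whose marginals are $\mu$ and $\nu$, the $p$-Wasserstein distance is the first expression below. The Gromov-Wasserstein distance compares measures living on two different metric spaces $(\mathcal{X}, d_X)$ and $(\mathcal{Y}, d_Y)$ by matching their pairwise distances; some references include an extra factor of $1/2$.

```latex
W_p(\mu,\nu) = \left( \inf_{\pi \in \Pi(\mu,\nu)} \int_{\mathcal{X}\times\mathcal{X}} d(x,y)^p \, \mathrm{d}\pi(x,y) \right)^{1/p}

GW_p(\mu,\nu) = \left( \inf_{\pi \in \Pi(\mu,\nu)} \iint \big| d_X(x,x') - d_Y(y,y') \big|^p \, \mathrm{d}\pi(x,y) \, \mathrm{d}\pi(x',y') \right)^{1/p}
```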
Thank you for your interest in our MROT code repository. If you have any problems, please do not hesitate to contact the author at fw2359@columbia.edu.
We would really appreciate it if you cite our work when you find it beneficial!
@article{wu2022metric,
title={Metric Learning-enhanced Optimal Transport for Biochemical Regression Domain Adaptation},
author={Wu, Fang and Courty, Nicolas and Qiang, Zhang and Li, Ziqing and others},
journal={arXiv preprint arXiv:2202.06208},
year={2022}
}