Skip to content

Commit

Permalink
unimo_tools 0.1.0 (#233)
Browse files Browse the repository at this point in the history
* 0.1.0

* 0.1.0

---------

Co-authored-by: gaozf <gaozf@dp.tech>
  • Loading branch information
Naplessss and gaozf authored Jun 20, 2024
1 parent 43f71d5 commit 6520ed4
Show file tree
Hide file tree
Showing 23 changed files with 481 additions and 683 deletions.
11 changes: 5 additions & 6 deletions unimol_tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ Documentation of Uni-Mol tools is available at https://unimol.readthedocs.io/en/
* [unimol representation](https://bohrium.dp.tech/notebook/f39a7a8836134cca8e22c099dc9654f8)

## install
- Notice: [Uni-Core](https://github.com/dptech-corp/Uni-Core) is needed, please install it first. Current Uni-Core requires torch>=2.0.0 by default, if you want to install other version, please check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
- PyTorch is required; please install it according to your environment. If you are using CUDA, install a CUDA-enabled build of PyTorch. More details can be found at https://pytorch.org/get-started/locally/
- Currently, rdkit requires numpy<2.0.0; please install rdkit together with numpy<2.0.0.
```python
## unicore and other dependencies installation
## dependencies installation
pip install -r requirements.txt
## clone repo
git clone https://github.com/dptech-corp/Uni-Mol.git
Expand All @@ -18,9 +19,6 @@ cd Uni-Mol/unimol_tools/unimol_tools
## download pretrained weights
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mol_pre_all_h_220816.pt
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mol_pre_no_h_220816.pt
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mof_pre_no_h_CORE_MAP_20230505.pt
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mp_all_h_230313.pt
wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/oled_pre_no_h_230101.pt

mkdir -p weights
Expand All @@ -32,7 +30,8 @@ python setup.py install
```

## News
- unimol_tools documents is coming soon.
- 2024-06-20: unimol_tools v0.1.0 released. The dependency on Uni-Core has been removed, and the package will be published to PyPI soon.
- 2024-03-20: unimol_tools documentation is available at https://unimol.readthedocs.io/en/latest/

## molecule property prediction
```python
Expand Down
21 changes: 7 additions & 14 deletions unimol_tools/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
pandas
numpy==1.22.4
pandas==1.4.0
scikit-learn==1.5.0
torch
joblib
rdkit
pymatgen
pyyaml
addict
tqdm
yacs
transformers
wandb
iopath
lmdb
ml_collections
numpy
scipy
tensorboardX
tokenizers
git+https://github.com/dptech-corp/Uni-Core.git
tqdm
14 changes: 11 additions & 3 deletions unimol_tools/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,28 @@

setup(
name="unimol_tools",
version="1.0.0",
version="0.1.0",
description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."),
author="DP Technology",
author_email="unimol@dp.tech",
license="The MIT License",
url="https://github.com/dptech-corp/Uni-Mol",
url="https://github.com/dptech-corp/Uni-Mol/unimol_tools",
packages=find_packages(
where='.',
exclude=[
"build",
"dist",
],
),
install_requires=["yacs", "addict", "tqdm", "transformers", "pymatgen"],
install_requires=["numpy<2.0.0,>=1.22.4",
"pandas<2.0.0",
"torch",
"joblib",
"rdkit",
"pyyaml",
"addict",
"scikit-learn",
"tqdm"],
python_requires=">=3.6",
include_package_data=True,
classifiers=[
Expand Down
2 changes: 1 addition & 1 deletion unimol_tools/unimol_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .train import MolTrain
from .predict import MolPredict
from .predictor import MOFPredictor, UniMolRepr
from .predictor import UniMolRepr
2 changes: 0 additions & 2 deletions unimol_tools/unimol_tools/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
"molecule_no_h": "mol_pre_no_h_220816.pt",
"molecule_all_h": "mol_pre_all_h_220816.pt",
"crystal": "mp_all_h_230313.pt",
"mof": "mof_pre_no_h_CORE_MAP_20230505.pt",
"oled": "oled_pre_no_h_230101.pt",
},
"dict":{
"protein": "poc.dict.txt",
"molecule_no_h": "mol.dict.txt",
"molecule_all_h": "mol.dict.txt",
"crystal": "mp.dict.txt",
"mof": "mof.dict.txt",
"oled": "oled.dict.txt",
},
}
2 changes: 1 addition & 1 deletion unimol_tools/unimol_tools/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .datahub import DataHub
from .datareader import MOFReader
from .dictionary import Dictionary
46 changes: 1 addition & 45 deletions unimol_tools/unimol_tools/data/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
from __future__ import absolute_import, division, print_function

import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
import warnings
from scipy.spatial import distance_matrix
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings(action='ignore')
from unicore.data import Dictionary
from .dictionary import Dictionary
from multiprocessing import Pool
from tqdm import tqdm
import pathlib
Expand Down Expand Up @@ -209,48 +207,6 @@ def coords2unimol(atoms, coordinates, dictionary, max_atoms=256, remove_hs=True,
# edge type
src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(1, -1)

return {
'src_tokens': src_tokens.astype(int),
'src_distance': src_distance.astype(np.float32),
'src_coord': src_coord.astype(np.float32),
'src_edge_type': src_edge_type.astype(int),
}


def coords2unimol_mof(atoms, coordinates, dictionary, max_atoms=256):
'''
Converts atomic symbols and their coordinates to a unimolecular metal-organic framework (MOF) representation that is suitable for input to a neural network.
This function handles cropping of atoms and coordinates if the number exceeds the maximum allowed, tokenization of atomic symbols, normalization and padding of coordinates, and computation of a distance matrix.
:param atoms: (list or np.ndarray) A list of atomic symbols (e.g., ['C', 'H', 'O']).
:param coordinates: (list or np.ndarray) A list of 3D coordinates corresponding to the atoms (shape: [num_atoms, 3]).
:param dictionary: A dictionary-like object that maps atomic symbols to unique integer tokens and provides methods to access special tokens such as 'bos' (beginning of sequence) and 'eos' (end of sequence).
:param max_atoms: (int) The maximum number of atoms to consider; atoms beyond this number are randomly cropped.
:return: A dictionary containing tokenized atomic symbols ('src_tokens'), a distance matrix ('src_distance'), normalized and padded coordinates ('src_coord'), and edge types ('src_edge_type').
'''
atoms = np.array(atoms)
coordinates = np.array(coordinates).astype(np.float32)
### cropping atoms and coordinates
if len(atoms)>max_atoms:
idx = np.random.choice(len(atoms), max_atoms, replace=False)
atoms = atoms[idx]
coordinates = coordinates[idx]
### tokens padding
src_tokens = np.array([dictionary.bos()] + [dictionary.index(atom) for atom in atoms] + [dictionary.eos()])
src_distance = np.zeros((len(src_tokens), len(src_tokens)))
### coordinates normalize & padding
src_coord = coordinates - coordinates.mean(axis=0)
src_coord = np.concatenate([np.zeros((1,3)), src_coord, np.zeros((1,3))], axis=0)
### distance matrix
# src_distance = distance_matrix(src_coord, src_coord)
src_distance = np.zeros((len(src_tokens), len(src_tokens)))
src_distance[1:-1,1:-1] = distance_matrix(src_coord[1:-1], src_coord[1:-1])

### edge type
src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(1, -1)

return {
'src_tokens': src_tokens.astype(int),
'src_distance': src_distance.astype(np.float32),
Expand Down
9 changes: 0 additions & 9 deletions unimol_tools/unimol_tools/data/datahub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,10 @@
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import logging
import copy
import os
import pandas as pd
import numpy as np
import csv
from typing import List, Optional
from collections import defaultdict
from .datareader import MolDataReader
from .datascaler import TargetScaler
from .conformer import ConformerGen
from ..utils import logger

class DataHub(object):
"""
Expand Down
115 changes: 0 additions & 115 deletions unimol_tools/unimol_tools/data/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,11 @@

from __future__ import absolute_import, division, print_function

import logging
import copy
import os
import pandas as pd
import re
from pymatgen.core import Structure
from .conformer import inner_coords, coords2unimol_mof
from unicore.data import Dictionary
import numpy as np
import csv
from typing import List, Optional
from rdkit import Chem
from ..utils import logger
from ..config import MODEL_CONFIG
import pathlib
from rdkit.Chem.Scaffolds import MurckoScaffold
WEIGHT_DIR = os.path.join(pathlib.Path(__file__).resolve().parents[1], 'weights')
Expand Down Expand Up @@ -197,109 +188,3 @@ def anomaly_clean_regression(self, data, target_cols):
data = data[(data[target_col] > _mean - 3 * _std) & (data[target_col] < _mean + 3 * _std)]
logger.info('Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0]))
return data


class MOFReader(object):
    '''Read MOF structures from CIF files and combine them with gas descriptors.'''

    def __init__(self):
        """
        Initialize the MOFReader object with predefined gas lists, gas ID mappings,
        gas attributes, dictionary name from the model configuration, and a loaded
        dictionary for atom types. Sets the maximum number of atoms in a structure.
        """
        # Gases accepted by `gas_parser`.
        self.gas_list = ['CH4', 'CO2', 'Ar', 'Kr', 'Xe', 'O2', 'He', 'N2', 'H2']
        # Integer id per gas; 0 is reserved for the "UNK" entry.
        self.GAS2ID = {
            "UNK": 0,
            "CH4": 1,
            "CO2": 2,
            "Ar": 3,
            "Kr": 4,
            "Xe": 5,
            "O2": 6,
            "He": 7,
            "N2": 8,
            "H2": 9,
        }
        # Six numeric attributes per gas.
        # NOTE(review): the physical meaning of each column is not visible here - confirm.
        self.GAS2ATTR = {
            "CH4": [0.295589, 0.165132, 0.251511019, -0.61518, 0.026952, 0.25887781],
            "CO2": [1.475242, 1.475921, 1.620478155, 0.086439, 1.976795, 1.69928074],
            "Ar": [-0.11632, 0.294448, 0.1914686, -0.01667, -0.07999, -0.1631478],
            "Kr": [0.48802, 0.602454, 0.215485568, 1.084671, 0.415991, 0.39885917],
            "Xe": [1.324657, 0.751519, 0.233498293, 2.276323, 1.12122, 1.18462811],
            "O2": [-0.08095, 0.37909, 0.335570404, -0.61626, -0.5363, -0.1130181],
            "He": [-1.66617, -1.88746, -2.15618995, -0.9173, -1.36413, -1.6042445],
            "N2": [-0.37636, -0.3968, 0.41962979, -0.31495, -0.40022, -0.3355659],
            "H2": [-1.34371, -1.3843, -1.11145188, -0.96708, -1.16031, -1.3256695],
        }
        self.dict_name = MODEL_CONFIG['dict']['mof']
        self.dictionary = Dictionary.load(os.path.join(WEIGHT_DIR, self.dict_name))
        self.dictionary.add_symbol("[MASK]", is_special=True)
        self.max_atoms = 512

    @staticmethod
    def _structure_id(cif_path):
        """Return the file name of ``cif_path`` without its directory and extension."""
        # os.path handles the platform's separators and extensions of any length,
        # unlike slicing off a fixed four characters after splitting on '/'.
        return os.path.splitext(os.path.basename(cif_path))[0]

    def cif_parser(self, cif_path, primitive=False):
        """
        Parses a single CIF file to extract structural information.

        :param cif_path: (str) Path to the CIF file.
        :param primitive: (bool) Whether to use the primitive cell.
        :return: A dictionary containing structural information such as ID, atoms,
            coordinates, lattice parameters, and volume.
        """
        s = Structure.from_file(cif_path, primitive=primitive)
        structure_id = self._structure_id(cif_path)
        lattice = s.lattice
        abc = lattice.abc  # lattice.abc (lengths of the lattice vectors, per pymatgen)
        angles = lattice.angles  # lattice angles
        volume = lattice.volume  # lattice volume
        lattice_matrix = lattice.matrix  # lattice 3x3 matrix

        df = s.as_dataframe()
        # Strip digits from the species labels; a raw string avoids the invalid
        # "\d" escape warning.
        atoms = df['Species'].astype(str).map(lambda x: re.sub(r"\d+", "", x)).tolist()
        coordinates = df[['x', 'y', 'z']].values.astype(np.float32)
        abc_coordinates = df[['a', 'b', 'c']].values.astype(np.float32)
        # Internal sanity checks: one coordinate row per atom.
        assert len(atoms) == coordinates.shape[0]
        assert len(atoms) == abc_coordinates.shape[0]

        return {
            'ID': structure_id,
            'atoms': atoms,
            'coordinates': coordinates,
            'abc': abc,
            'angles': angles,
            'volume': volume,
            'lattice_matrix': lattice_matrix,
            'abc_coordinates': abc_coordinates,
        }

    def gas_parser(self, gas='CH4'):
        """
        Parses information about a specific gas.

        :param gas: (str) The name of the gas.
        :return: A dictionary containing the ID and attributes for the specified gas.
        :raises AssertionError: If the specified gas is not in the supported gas list.
        """
        if gas not in self.gas_list:
            # Raised explicitly (rather than via `assert`) so the validation is not
            # stripped under `python -O`; the exception type is kept for callers.
            raise AssertionError(
                "{} is not in list, current we support: {}".format(gas, '-'.join(self.gas_list))
            )
        gas_id = self.GAS2ID.get(gas, 0)
        gas_attr = self.GAS2ATTR.get(gas, np.zeros(6))

        return {'gas_id': gas_id, 'gas_attr': gas_attr}

    def read_with_gas(self, cif_path, gas):
        """
        Reads CIF file and gas information, and combines them into a single dictionary.

        :param cif_path: (str) Path to the CIF file.
        :param gas: (str) The name of the gas to be read.
        :return: The dictionary produced by ``coords2unimol_mof`` for the structure,
            updated with the 'gas_id' and 'gas_attr' entries of the specified gas.
        """
        parsed = self.cif_parser(cif_path)
        atoms, coordinates = inner_coords(parsed['atoms'], parsed['coordinates'])
        dd = coords2unimol_mof(atoms, coordinates, self.dictionary, max_atoms=self.max_atoms)
        dd.update(self.gas_parser(gas))

        return dd
5 changes: 0 additions & 5 deletions unimol_tools/unimol_tools/data/datascaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@

from __future__ import absolute_import, division, print_function

import logging
import copy
import os
import pandas as pd
import numpy as np
import csv
from typing import List, Optional
import joblib
from sklearn.preprocessing import (
StandardScaler,
Expand Down
Loading

0 comments on commit 6520ed4

Please sign in to comment.