From 8ee5ce194a67d9ac07d52efa245ae3f836cf794b Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Thu, 19 Oct 2023 13:33:52 +0200 Subject: [PATCH 1/6] Add chemistry related files to gitignore --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 2bdab8de..ccad7e85 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,9 @@ venv venv3 .python-version + +# deeprank/chemistry files +*.pdb +*.hdf5 +*.cube +*.pckl From 2a013fe46e8bb8e8500318681419c42e616c8369 Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Thu, 19 Oct 2023 13:50:46 +0200 Subject: [PATCH 2/6] Add docstring for get_feature_importance, add initial empty function to run dianna with single feature --- dianna/deeprank/__init__.py | 3 ++ dianna/deeprank/feature_importance.py | 45 ++++++++++++++++++--------- dianna/deeprank/single_feature.py | 2 ++ 3 files changed, 36 insertions(+), 14 deletions(-) create mode 100644 dianna/deeprank/single_feature.py diff --git a/dianna/deeprank/__init__.py b/dianna/deeprank/__init__.py index ad2c073a..3b2eae4c 100644 --- a/dianna/deeprank/__init__.py +++ b/dianna/deeprank/__init__.py @@ -4,3 +4,6 @@ if importlib.util.find_spec('deeprank') is None: raise ImportError("Cannot find deeprank, please install manually or reinstall dianna with " "chemistry support: `pip install dianna[chem]`") + +from .feature_importance import get_feature_importance +from .single_feature import run_single_feature diff --git a/dianna/deeprank/feature_importance.py b/dianna/deeprank/feature_importance.py index 28c8be89..1f232a4a 100644 --- a/dianna/deeprank/feature_importance.py +++ b/dianna/deeprank/feature_importance.py @@ -10,14 +10,31 @@ @click.command() -@click.option('--model', 'model_path', help='Path to deeprank model state', required=True) -@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset', required=True) +@click.option('--model', 'model_path', help='Path to deeprank model state (tar file)', required=True) +@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset (hdf5 file)', required=True) @click.option('--sample', 'sample_name', help='Name of sample within dataset to analyse', required=True) @click.option('--target_class', type=int, help='Which class to show output for [Default: real target class of input sample]') @click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output') @click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset') -def get_feature_importance(model_path, dataset_path, sample_name, topn, target_class, normalize): +def get_feature_importance(model_path, dataset_path, sample_name, topn, target_class, normalize, plot=True): + """ + Run the model, masking the features completely one by one. Then check how much the model output changes. + The feature which, when masked, induces the largest change in the model output is deemed most important + + Args: + model_path (str): Path to deeprank model state (tar file) + dataset_path (str): Path to deeprank dataset (hdf5 file) + sample_name (str): Name of sample within dataset to analyse + topn (int): Show this many features in the output + target_class (int): Which class to show output for + normalize (bool): Normalize the dataset + plot (bool): Plot the results (default True) + + Returns: + dictionary of feature name: absolute change in model output + + """ dataset = utils.load_dataset(model_path, dataset_path, normalize) # Convert sample name to index @@ -58,16 +75,16 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c output_table = dict(zip(feature_names, output)) - fig, ax = plt.subplots() - ax.plot(output[order[-topn:]], range(topn), ls='', marker='o') - ax.set_xlabel(f'Change in model output for class {target_class}') - ax.set_ylabel('Masked feature') - for i in range(topn): - ax.axhline(i, c='k', ls='--', alpha=.2) - ax.set_yticks(range(topn), feature_names[order[-topn:]]) - fig.suptitle(f'Feature importance for {sample_name}') - fig.tight_layout() - - plt.show() + if plot: + fig, ax = plt.subplots() + ax.plot(output[order[-topn:]], range(topn), ls='', marker='o') + ax.set_xlabel(f'Change in model output for class {target_class}') + ax.set_ylabel('Masked feature') + for i in range(topn): + ax.axhline(i, c='k', ls='--', alpha=.2) + ax.set_yticks(range(topn), feature_names[order[-topn:]]) + fig.suptitle(f'Feature importance for {sample_name}') + fig.tight_layout() + plt.show() return output_table diff --git a/dianna/deeprank/single_feature.py b/dianna/deeprank/single_feature.py new file mode 100644 index 00000000..8d280f43 --- /dev/null +++ b/dianna/deeprank/single_feature.py @@ -0,0 +1,2 @@ +def run_single_feature(): + pass \ No newline at end of file From ec3bfbc955726f8cc0669030b2262d158a111c67 Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Thu, 19 Oct 2023 15:06:53 +0200 Subject: [PATCH 3/6] Implement function to run DIANNA on deeprank model with single feature --- dianna/deeprank/__init__.py | 4 +- dianna/deeprank/feature_importance.py | 5 +- dianna/deeprank/single_feature.py | 70 +++++++++++++++++++++++- dianna/deeprank/utils.py | 77 +++++++++++++++++++++++++++ 4 files changed, 149 insertions(+), 7 deletions(-) diff --git a/dianna/deeprank/__init__.py b/dianna/deeprank/__init__.py index 3b2eae4c..fa709625 100644 --- a/dianna/deeprank/__init__.py +++ b/dianna/deeprank/__init__.py @@ -5,5 +5,5 @@ raise ImportError("Cannot find deeprank, please install manually or reinstall dianna with " "chemistry support: `pip install dianna[chem]`") -from .feature_importance import get_feature_importance -from .single_feature import run_single_feature +from .feature_importance import get_feature_importance # noqa: E402 +from .single_feature import run_single_feature # noqa: E402 diff --git a/dianna/deeprank/feature_importance.py b/dianna/deeprank/feature_importance.py index 1f232a4a..0d673c3a 100644 --- a/dianna/deeprank/feature_importance.py +++ b/dianna/deeprank/feature_importance.py @@ -18,9 +18,8 @@ @click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output') @click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset') def get_feature_importance(model_path, dataset_path, sample_name, topn, target_class, normalize, plot=True): - """ - Run the model, masking the features completely one by one. Then check how much the model output changes. - The feature which, when masked, induces the largest change in the model output is deemed most important + """Run the model, masking the features completely one by one. Then check how much the model output changes. + The feature which, when masked, induces the largest change in the model output is deemed most important. Args: model_path (str): Path to deeprank model state (tar file) diff --git a/dianna/deeprank/single_feature.py b/dianna/deeprank/single_feature.py index 8d280f43..5f48e467 100644 --- a/dianna/deeprank/single_feature.py +++ b/dianna/deeprank/single_feature.py @@ -1,2 +1,68 @@ -def run_single_feature(): - pass \ No newline at end of file +import numpy as np + +from dianna.methods.rise import RISEImage +from dianna.utils import MultiInputWrapper +from dianna.deeprank.model import ModelRunner +from dianna.deeprank import utils + + +def run_single_feature(model_path, dataset_path, sample_name, feature_name, normalize, batch_size=32, **rise_kwargs): + """Run DIANNA on a single feature of a deeprank sample + This feature is spatially masked to find which parts of the protein contribute most to the model output + The other features are not masked and remain present in the model input + + Args: + model_path (str): Path to deeprank model state (tar file) + dataset_path (str): Path to deeprank dataset (hdf5 file) + sample_name (str): Name of sample within dataset to analyse + feature_name (str): Name of feature to run DIANNA on + normalize (bool): Normalize the dataset + batch_size (int): Model batch size when running DIANNA explainer (default 32) + + Any further keyword arguments are given to the initializer of the DIANNA explainer + If not given, this function sets default values + """ + dataset = utils.load_dataset(model_path, dataset_path, normalize) + + # Convert sample name to index + sample_index = None + for idx, item in enumerate(dataset): + if item['mol'][1] == sample_name: + sample_index = idx + break + if sample_index is None: + raise ValueError(f'Could not find sample {sample_name} in dataset') + + sample = dataset[sample_index] + sample_mol_data = sample['mol'] + sample_features = sample['feature'] + + # Convert feature name to index + feature_names = utils.get_feature_names(dataset_path) + try: + feature_idx = np.where(feature_names == feature_name)[0][0] + except IndexError: + raise ValueError(f'Could not find feature {feature_name} in features') + + run_model = ModelRunner(model_path, sample_features.shape, normalize) + + # dianna requires a channels axis, put this in first position as required by MultiInputWrapper + dianna_input = sample_features[feature_idx][None, ...] + axis_labels = ('channels', 'x', 'y', 'z') + static_input = np.delete(sample_features, [feature_idx], axis=0) + + # preprocess function to reconstruct full input from a single feature + reconstructor = MultiInputWrapper(static_input, feature_idx) + + # set defaults for RISE if values are not given + rise_settings = {'n_masks': 512, 'feature_res': 8, 'p_keep': .6} + rise_settings.update(rise_kwargs) + + rise = RISEImage(axis_labels=axis_labels, preprocess_function=reconstructor, **rise_settings) + + gen = utils.generate_interface_center(sample_mol_data) + + heatmaps = rise.explain(run_model, dianna_input, labels=(0, 1), center_generator=gen, + batch_size=batch_size) + + return heatmaps diff --git a/dianna/deeprank/utils.py b/dianna/deeprank/utils.py index 9505141a..236e67e3 100644 --- a/dianna/deeprank/utils.py +++ b/dianna/deeprank/utils.py @@ -1,6 +1,7 @@ from deeprank.learn import DataSet import h5py import numpy as np +from pdb2sql import interface import torch @@ -68,3 +69,79 @@ def get_feature_names(hdf5_path): raise KeyError('Failed to extract feature names from dataset') from err return feature_names + + +def generate_interface_center(mol_data): + """Yield a the indices of a new center draw from a line + between two contact residues + + Args: + mol_data (tuple): hdf5 file name, molecule name + """ + + def get_grid_index(point, grid_points): + """Get grid indices of a point from its xyz coordinates + + Args: + point (np.ndarray): xyz coordinates of the point + grid_points (tuple): (xgrid, ygrid, zgrid) + + Returns: + list: indices of the point in the grid + """ + index = [] + for pt_coord, grid_coord in zip(point, grid_points): + index.append(np.argmin(np.abs(grid_coord - pt_coord))) + return index + + def get_next_point(db, res): + """generate the xyz coordinate of a random center + + Args: + db (pdb2sql.interface): an interface instance created from the molecule + res (dict): a dictionar of interface residue obtained via .get_contact_residues() + + Returns: + np.ndarray: xyz coordinate of the new center + """ + resA, resB = res[chains[0]], res[chains[1]] + nresA, nresB = len(resA), len(resB) + + rA = resA[np.random.randint(0, nresA)] + rB = resB[np.random.randint(0, nresB)] + + posA = np.array(db.get('x,y,z', chainID=rA[0], resSeq=rA[1])).mean(axis=0) + posB = np.array(db.get('x,y,z', chainID=rB[0], resSeq=rB[1])).mean(axis=0) + return posA + np.random.rand(3)*(posB-posA) + + # get the hdf5 filename and molecule name + filename, molname = mol_data + if isinstance(filename, (tuple, list)): + filename = filename[0] + + if isinstance(molname, (tuple, list)): + molname = molname[0] + + # get data from the hdf5 file + with h5py.File(filename, 'r') as f5: + mol = f5[molname]['complex'][()] + gridx = f5[molname]['grid_points']['x'][()] + gridy = f5[molname]['grid_points']['y'][()] + gridz = f5[molname]['grid_points']['z'][()] + + # assemble grid data + grid_points = (gridx, gridy, gridz) + + # create the interfance and identify contact residues + db = interface(mol) + chains = db.get_chains() + res = db.get_contact_residues(chain1=chains[0], chain2=chains[1]) + + # get the first center + xyz_center = get_next_point(db, res) + yield get_grid_index(xyz_center, grid_points) + + # get all other centers + while True: + xyz_center = get_next_point(db, res) + yield get_grid_index(xyz_center, grid_points) \ No newline at end of file From e07efe34d9fd70aa7b025dc01bb2db9875cc8a1d Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Thu, 19 Oct 2023 15:42:35 +0200 Subject: [PATCH 4/6] Make center_generator work for arbitrary number of dimensions --- dianna/methods/rise.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dianna/methods/rise.py b/dianna/methods/rise.py index 02c7ecdd..92cea6d1 100644 --- a/dianna/methods/rise.py +++ b/dianna/methods/rise.py @@ -244,9 +244,7 @@ def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks, center_g @staticmethod def _default_center_generator(input_size): while True: - yield (np.random.randint(0, input_size[0]), - np.random.randint(0, input_size[1]), - np.random.randint(0, input_size[2])) + yield np.random.randint(0, input_size) @staticmethod def _map_indices(input_idx, feature_res): @@ -262,7 +260,7 @@ def _generate_masks(self, input_size, p_keep, n_masks, center_generator): """Generates a set of random masks to mask the input data. Args: - input_size (int): Size of a single sample of input data, for images without the channel axis. + input_size (list): Size of a single sample of input data, for images without the channel axis. Returns: The generated masks (np.ndarray) @@ -273,16 +271,20 @@ def _generate_masks(self, input_size, p_keep, n_masks, center_generator): cell_size = np.ceil(np.array(input_size) / self.feature_res) up_size = (self.feature_res + 1) * cell_size - grid = np.random.choice(a=(True, False), size=(n_masks, self.feature_res, self.feature_res, self.feature_res), + grid = np.random.choice(a=(True, False), size=(n_masks, *(self.feature_res, ) * len(input_size)), p=(p_keep, 1 - p_keep)) grid = grid.astype('float32') masks = np.empty((n_masks, *input_size), dtype=np.float32) + print(f'{masks.shape=}') for i in range(n_masks): - (x, y, z) = self._map_indices(next(center_generator), self.feature_res) + coords = self._map_indices(next(center_generator), self.feature_res) + selection = [] + for start, size in zip(coords, input_size): + selection.append(slice(start, start+size)) # Linear upsampling and cropping - masks[i, ...] = _upscale(grid[i], up_size)[x:x + input_size[0], y:y + input_size[1], z:z + input_size[2]] + masks[i, ...] = _upscale(grid[i], up_size)[tuple(selection)] masks = masks.reshape(-1, *input_size, 1) return masks From 4b7504cfde82b4686a94306d8a0342220eb61664 Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Thu, 19 Oct 2023 15:51:10 +0200 Subject: [PATCH 5/6] Remove debug print --- dianna/methods/rise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dianna/methods/rise.py b/dianna/methods/rise.py index 92cea6d1..35a4f479 100644 --- a/dianna/methods/rise.py +++ b/dianna/methods/rise.py @@ -277,7 +277,6 @@ def _generate_masks(self, input_size, p_keep, n_masks, center_generator): masks = np.empty((n_masks, *input_size), dtype=np.float32) - print(f'{masks.shape=}') for i in range(n_masks): coords = self._map_indices(next(center_generator), self.feature_res) selection = [] From 70e82438730737feebc8c0192b25d2044e412a22 Mon Sep 17 00:00:00 2001 From: Leon Oostrum Date: Mon, 23 Oct 2023 10:34:07 +0200 Subject: [PATCH 6/6] Separate get_feature_importance CLI from normal python function as it wouldn't work properly with args vs keyword args; add some comments and more checks as well --- dianna/deeprank/feature_importance.py | 39 ++++++++++++++++----------- setup.cfg | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/dianna/deeprank/feature_importance.py b/dianna/deeprank/feature_importance.py index 0d673c3a..6ebca0b7 100644 --- a/dianna/deeprank/feature_importance.py +++ b/dianna/deeprank/feature_importance.py @@ -9,15 +9,8 @@ plt.switch_backend(backend) # because deeprank changes the matplotlib backend on import -@click.command() -@click.option('--model', 'model_path', help='Path to deeprank model state (tar file)', required=True) -@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset (hdf5 file)', required=True) -@click.option('--sample', 'sample_name', help='Name of sample within dataset to analyse', required=True) -@click.option('--target_class', type=int, - help='Which class to show output for [Default: real target class of input sample]') -@click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output') -@click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset') -def get_feature_importance(model_path, dataset_path, sample_name, topn, target_class, normalize, plot=True): +def get_feature_importance(model_path, dataset_path, sample_name, topn=10, + target_class=None, normalize=False, plot=False): """Run the model, masking the features completely one by one. Then check how much the model output changes. The feature which, when masked, induces the largest change in the model output is deemed most important. @@ -25,10 +18,10 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c model_path (str): Path to deeprank model state (tar file) dataset_path (str): Path to deeprank dataset (hdf5 file) sample_name (str): Name of sample within dataset to analyse - topn (int): Show this many features in the output - target_class (int): Which class to show output for - normalize (bool): Normalize the dataset - plot (bool): Plot the results (default True) + topn (int): Show this many features in the output (default 10) + target_class (int): Which class to show output for (default real class of sample) + normalize (bool): Normalize the dataset (default false) + plot (bool): Plot the results (default False) Returns: dictionary of feature name: absolute change in model output @@ -48,7 +41,11 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c # select target class if target_class is None: - target_class = int(dataset[sample_index]['target']) + try: + target_class = int(dataset[sample_index]['target']) + except KeyError: + raise KeyError("Field 'target' not found in sample. Provide target_class manually if this is " + "not a training sample") print(f'Target class: {target_class}') feature_names = utils.get_feature_names(dataset_path) @@ -61,6 +58,7 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c orig_model_output = run_model([sample])[0][target_class] + # Create input where the nth input has the nth feature set to all zeroes inputs = np.zeros(shape=(sample.shape[0], *sample.shape)) inputs[:] = sample for idx in range(nfeature): @@ -69,7 +67,6 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c output = run_model(inputs)[:, target_class] # transform into abs distance from original model output output = np.abs(output - orig_model_output) - order = np.argsort(output) output_table = dict(zip(feature_names, output)) @@ -87,3 +84,15 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c plt.show() return output_table + + +@click.command() +@click.option('--model', 'model_path', help='Path to deeprank model state (tar file)', required=True) +@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset (hdf5 file)', required=True) +@click.option('--sample', 'sample_name', help='Name of sample within dataset to analyse', required=True) +@click.option('--target_class', type=int, default=None, + help='Which class to show output for [Default: real target class of input sample]') +@click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output') +@click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset') +def get_feature_importance_cli(*args, **kwargs): + get_feature_importance(*args, **kwargs, plot=True) diff --git a/setup.cfg b/setup.cfg index 5f8bdd71..4920ad61 100644 --- a/setup.cfg +++ b/setup.cfg @@ -102,7 +102,7 @@ chem = [options.entry_points] console_scripts = - dianna-deeprank-feature-importance = dianna.deeprank.feature_importance:get_feature_importance + dianna-deeprank-feature-importance = dianna.deeprank.feature_importance:get_feature_importance_cli [options.packages.find] include = dianna, dianna.*