-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'dianna_deeprank' of https://github.com/dianna-ai/dianna …
…into dianna_deeprank
- Loading branch information
Showing
7 changed files
with
210 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,9 @@ venv | |
venv3 | ||
|
||
.python-version | ||
|
||
# deeprank/chemistry files | ||
*.pdb | ||
*.hdf5 | ||
*.cube | ||
*.pckl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy as np | ||
|
||
from dianna.methods.rise import RISEImage | ||
from dianna.utils import MultiInputWrapper | ||
from dianna.deeprank.model import ModelRunner | ||
from dianna.deeprank import utils | ||
|
||
|
||
def run_single_feature(model_path, dataset_path, sample_name, feature_name, normalize, batch_size=32, **rise_kwargs): | ||
"""Run DIANNA on a single feature of a deeprank sample | ||
This feature is spatially masked to find which parts of the protein contribute most to the model output | ||
The other features are not masked and remain present in the model input | ||
Args: | ||
model_path (str): Path to deeprank model state (tar file) | ||
dataset_path (str): Path to deeprank dataset (hdf5 file) | ||
sample_name (str): Name of sample within dataset to analyse | ||
feature_name (str): Name of feature to run DIANNA on | ||
normalize (bool): Normalize the dataset | ||
batch_size (int): Model batch size when running DIANNA explainer (default 32) | ||
Any further keyword arguments are given to the initializer of the DIANNA explainer | ||
If not given, this function sets default values | ||
""" | ||
dataset = utils.load_dataset(model_path, dataset_path, normalize) | ||
|
||
# Convert sample name to index | ||
sample_index = None | ||
for idx, item in enumerate(dataset): | ||
if item['mol'][1] == sample_name: | ||
sample_index = idx | ||
break | ||
if sample_index is None: | ||
raise ValueError(f'Could not find sample {sample_name} in dataset') | ||
|
||
sample = dataset[sample_index] | ||
sample_mol_data = sample['mol'] | ||
sample_features = sample['feature'] | ||
|
||
# Convert feature name to index | ||
feature_names = utils.get_feature_names(dataset_path) | ||
try: | ||
feature_idx = np.where(feature_names == feature_name)[0][0] | ||
except IndexError: | ||
raise ValueError(f'Could not find feature {feature_name} in features') | ||
|
||
run_model = ModelRunner(model_path, sample_features.shape, normalize) | ||
|
||
# dianna requires a channels axis, put this in first position as required by MultiInputWrapper | ||
dianna_input = sample_features[feature_idx][None, ...] | ||
axis_labels = ('channels', 'x', 'y', 'z') | ||
static_input = np.delete(sample_features, [feature_idx], axis=0) | ||
|
||
# preprocess function to reconstruct full input from a single feature | ||
reconstructor = MultiInputWrapper(static_input, feature_idx) | ||
|
||
# set defaults for RISE if values are not given | ||
rise_settings = {'n_masks': 512, 'feature_res': 8, 'p_keep': .6} | ||
rise_settings.update(rise_kwargs) | ||
|
||
rise = RISEImage(axis_labels=axis_labels, preprocess_function=reconstructor, **rise_settings) | ||
|
||
gen = utils.generate_interface_center(sample_mol_data) | ||
|
||
heatmaps = rise.explain(run_model, dianna_input, labels=(0, 1), center_generator=gen, | ||
batch_size=batch_size) | ||
|
||
return heatmaps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters