Skip to content

Commit

Permalink
Merge branch 'dianna_deeprank' of https://github.com/dianna-ai/dianna
Browse files Browse the repository at this point in the history
…into dianna_deeprank
  • Loading branch information
laurasootes committed Oct 23, 2023
2 parents 0f9f9dc + 70e8243 commit c1823d7
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 30 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ venv
venv3

.python-version

# deeprank/chemistry files
*.pdb
*.hdf5
*.cube
*.pckl
3 changes: 3 additions & 0 deletions dianna/deeprank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
if importlib.util.find_spec('deeprank') is None:
raise ImportError("Cannot find deeprank, please install manually or reinstall dianna with "
"chemistry support: `pip install dianna[chem]`")

from .feature_importance import get_feature_importance # noqa: E402
from .single_feature import run_single_feature # noqa: E402
69 changes: 47 additions & 22 deletions dianna/deeprank/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,24 @@
plt.switch_backend(backend) # because deeprank changes the matplotlib backend on import


@click.command()
@click.option('--model', 'model_path', help='Path to deeprank model state', required=True)
@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset', required=True)
@click.option('--sample', 'sample_name', help='Name of sample within dataset to analyse', required=True)
@click.option('--target_class', type=int,
help='Which class to show output for [Default: real target class of input sample]')
@click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output')
@click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset')
def get_feature_importance(model_path, dataset_path, sample_name, topn, target_class, normalize):
def get_feature_importance(model_path, dataset_path, sample_name, topn=10,
target_class=None, normalize=False, plot=False):
"""Run the model, masking the features completely one by one. Then check how much the model output changes.
The feature which, when masked, induces the largest change in the model output is deemed most important.
Args:
model_path (str): Path to deeprank model state (tar file)
dataset_path (str): Path to deeprank dataset (hdf5 file)
sample_name (str): Name of sample within dataset to analyse
topn (int): Show this many features in the output (default 10)
target_class (int): Which class to show output for (default real class of sample)
normalize (bool): Normalize the dataset (default false)
plot (bool): Plot the results (default False)
Returns:
dictionary of feature name: absolute change in model output
"""
dataset = utils.load_dataset(model_path, dataset_path, normalize)

# Convert sample name to index
Expand All @@ -32,7 +41,11 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c

# select target class
if target_class is None:
target_class = int(dataset[sample_index]['target'])
try:
target_class = int(dataset[sample_index]['target'])
except KeyError:
raise KeyError("Field 'target' not found in sample. Provide target_class manually if this is "
"not a training sample")
print(f'Target class: {target_class}')

feature_names = utils.get_feature_names(dataset_path)
Expand All @@ -45,6 +58,7 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c

orig_model_output = run_model([sample])[0][target_class]

# Create input where the nth input has the nth feature set to all zeroes
inputs = np.zeros(shape=(sample.shape[0], *sample.shape))
inputs[:] = sample
for idx in range(nfeature):
Expand All @@ -53,21 +67,32 @@ def get_feature_importance(model_path, dataset_path, sample_name, topn, target_c
output = run_model(inputs)[:, target_class]
# transform into abs distance from original model output
output = np.abs(output - orig_model_output)

order = np.argsort(output)

output_table = dict(zip(feature_names, output))

fig, ax = plt.subplots()
ax.plot(output[order[-topn:]], range(topn), ls='', marker='o')
ax.set_xlabel(f'Change in model output for class {target_class}')
ax.set_ylabel('Masked feature')
for i in range(topn):
ax.axhline(i, c='k', ls='--', alpha=.2)
ax.set_yticks(range(topn), feature_names[order[-topn:]])
fig.suptitle(f'Feature importance for {sample_name}')
fig.tight_layout()

plt.show()
if plot:
fig, ax = plt.subplots()
ax.plot(output[order[-topn:]], range(topn), ls='', marker='o')
ax.set_xlabel(f'Change in model output for class {target_class}')
ax.set_ylabel('Masked feature')
for i in range(topn):
ax.axhline(i, c='k', ls='--', alpha=.2)
ax.set_yticks(range(topn), feature_names[order[-topn:]])
fig.suptitle(f'Feature importance for {sample_name}')
fig.tight_layout()
plt.show()

return output_table


@click.command()
@click.option('--model', 'model_path', help='Path to deeprank model state (tar file)', required=True)
@click.option('--dataset', 'dataset_path', help='Path to deeprank dataset (hdf5 file)', required=True)
@click.option('--sample', 'sample_name', help='Name of sample within dataset to analyse', required=True)
@click.option('--target_class', type=int, default=None,
help='Which class to show output for [Default: real target class of input sample]')
@click.option('--topn', type=int, default=10, show_default=True, help='Show this many features in the output')
@click.option('--normalize', is_flag=True, show_default=True, default=False, help='Normalize the dataset')
def get_feature_importance_cli(*args, **kwargs):
get_feature_importance(*args, **kwargs, plot=True)
68 changes: 68 additions & 0 deletions dianna/deeprank/single_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np

from dianna.methods.rise import RISEImage
from dianna.utils import MultiInputWrapper
from dianna.deeprank.model import ModelRunner
from dianna.deeprank import utils


def run_single_feature(model_path, dataset_path, sample_name, feature_name, normalize, batch_size=32, **rise_kwargs):
"""Run DIANNA on a single feature of a deeprank sample
This feature is spatially masked to find which parts of the protein contribute most to the model output
The other features are not masked and remain present in the model input
Args:
model_path (str): Path to deeprank model state (tar file)
dataset_path (str): Path to deeprank dataset (hdf5 file)
sample_name (str): Name of sample within dataset to analyse
feature_name (str): Name of feature to run DIANNA on
normalize (bool): Normalize the dataset
batch_size (int): Model batch size when running DIANNA explainer (default 32)
Any further keyword arguments are given to the initializer of the DIANNA explainer
If not given, this function sets default values
"""
dataset = utils.load_dataset(model_path, dataset_path, normalize)

# Convert sample name to index
sample_index = None
for idx, item in enumerate(dataset):
if item['mol'][1] == sample_name:
sample_index = idx
break
if sample_index is None:
raise ValueError(f'Could not find sample {sample_name} in dataset')

sample = dataset[sample_index]
sample_mol_data = sample['mol']
sample_features = sample['feature']

# Convert feature name to index
feature_names = utils.get_feature_names(dataset_path)
try:
feature_idx = np.where(feature_names == feature_name)[0][0]
except IndexError:
raise ValueError(f'Could not find feature {feature_name} in features')

run_model = ModelRunner(model_path, sample_features.shape, normalize)

# dianna requires a channels axis, put this in first position as required by MultiInputWrapper
dianna_input = sample_features[feature_idx][None, ...]
axis_labels = ('channels', 'x', 'y', 'z')
static_input = np.delete(sample_features, [feature_idx], axis=0)

# preprocess function to reconstruct full input from a single feature
reconstructor = MultiInputWrapper(static_input, feature_idx)

# set defaults for RISE if values are not given
rise_settings = {'n_masks': 512, 'feature_res': 8, 'p_keep': .6}
rise_settings.update(rise_kwargs)

rise = RISEImage(axis_labels=axis_labels, preprocess_function=reconstructor, **rise_settings)

gen = utils.generate_interface_center(sample_mol_data)

heatmaps = rise.explain(run_model, dianna_input, labels=(0, 1), center_generator=gen,
batch_size=batch_size)

return heatmaps
77 changes: 77 additions & 0 deletions dianna/deeprank/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from deeprank.learn import DataSet
import h5py
import numpy as np
from pdb2sql import interface
import torch


Expand Down Expand Up @@ -68,3 +69,79 @@ def get_feature_names(hdf5_path):
raise KeyError('Failed to extract feature names from dataset') from err

return feature_names


def generate_interface_center(mol_data):
"""Yield a the indices of a new center draw from a line
between two contact residues
Args:
mol_data (tuple): hdf5 file name, molecule name
"""

def get_grid_index(point, grid_points):
"""Get grid indices of a point from its xyz coordinates
Args:
point (np.ndarray): xyz coordinates of the point
grid_points (tuple): (xgrid, ygrid, zgrid)
Returns:
list: indices of the point in the grid
"""
index = []
for pt_coord, grid_coord in zip(point, grid_points):
index.append(np.argmin(np.abs(grid_coord - pt_coord)))
return index

def get_next_point(db, res):
"""generate the xyz coordinate of a random center
Args:
db (pdb2sql.interface): an interface instance created from the molecule
res (dict): a dictionar of interface residue obtained via .get_contact_residues()
Returns:
np.ndarray: xyz coordinate of the new center
"""
resA, resB = res[chains[0]], res[chains[1]]
nresA, nresB = len(resA), len(resB)

rA = resA[np.random.randint(0, nresA)]
rB = resB[np.random.randint(0, nresB)]

posA = np.array(db.get('x,y,z', chainID=rA[0], resSeq=rA[1])).mean(axis=0)
posB = np.array(db.get('x,y,z', chainID=rB[0], resSeq=rB[1])).mean(axis=0)
return posA + np.random.rand(3)*(posB-posA)

# get the hdf5 filename and molecule name
filename, molname = mol_data
if isinstance(filename, (tuple, list)):
filename = filename[0]

if isinstance(molname, (tuple, list)):
molname = molname[0]

# get data from the hdf5 file
with h5py.File(filename, 'r') as f5:
mol = f5[molname]['complex'][()]
gridx = f5[molname]['grid_points']['x'][()]
gridy = f5[molname]['grid_points']['y'][()]
gridz = f5[molname]['grid_points']['z'][()]

# assemble grid data
grid_points = (gridx, gridy, gridz)

# create the interfance and identify contact residues
db = interface(mol)
chains = db.get_chains()
res = db.get_contact_residues(chain1=chains[0], chain2=chains[1])

# get the first center
xyz_center = get_next_point(db, res)
yield get_grid_index(xyz_center, grid_points)

# get all other centers
while True:
xyz_center = get_next_point(db, res)
yield get_grid_index(xyz_center, grid_points)
15 changes: 8 additions & 7 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,7 @@ def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks, center_g
@staticmethod
def _default_center_generator(input_size):
while True:
yield (np.random.randint(0, input_size[0]),
np.random.randint(0, input_size[1]),
np.random.randint(0, input_size[2]))
yield np.random.randint(0, input_size)

@staticmethod
def _map_indices(input_idx, feature_res):
Expand All @@ -262,7 +260,7 @@ def _generate_masks(self, input_size, p_keep, n_masks, center_generator):
"""Generates a set of random masks to mask the input data.
Args:
input_size (int): Size of a single sample of input data, for images without the channel axis.
input_size (list): Size of a single sample of input data, for images without the channel axis.
Returns:
The generated masks (np.ndarray)
Expand All @@ -273,16 +271,19 @@ def _generate_masks(self, input_size, p_keep, n_masks, center_generator):
cell_size = np.ceil(np.array(input_size) / self.feature_res)
up_size = (self.feature_res + 1) * cell_size

grid = np.random.choice(a=(True, False), size=(n_masks, self.feature_res, self.feature_res, self.feature_res),
grid = np.random.choice(a=(True, False), size=(n_masks, *(self.feature_res, ) * len(input_size)),
p=(p_keep, 1 - p_keep))
grid = grid.astype('float32')

masks = np.empty((n_masks, *input_size), dtype=np.float32)

for i in range(n_masks):
(x, y, z) = self._map_indices(next(center_generator), self.feature_res)
coords = self._map_indices(next(center_generator), self.feature_res)
selection = []
for start, size in zip(coords, input_size):
selection.append(slice(start, start+size))
# Linear upsampling and cropping
masks[i, ...] = _upscale(grid[i], up_size)[x:x + input_size[0], y:y + input_size[1], z:z + input_size[2]]
masks[i, ...] = _upscale(grid[i], up_size)[tuple(selection)]
masks = masks.reshape(-1, *input_size, 1)
return masks

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ chem =

[options.entry_points]
console_scripts =
dianna-deeprank-feature-importance = dianna.deeprank.feature_importance:get_feature_importance
dianna-deeprank-feature-importance = dianna.deeprank.feature_importance:get_feature_importance_cli

[options.packages.find]
include = dianna, dianna.*
Expand Down

0 comments on commit c1823d7

Please sign in to comment.