Skip to content

Commit

Permalink
Merge pull request #674 from dianna-ai/shap_tabular
Browse files Browse the repository at this point in the history
645 Implement SHAP for tabular data
  • Loading branch information
Yang authored Jan 11, 2024
2 parents b4793b2 + 080d33b commit 1fb3a2a
Show file tree
Hide file tree
Showing 7 changed files with 835 additions and 4 deletions.
84 changes: 84 additions & 0 deletions dianna/methods/kernelshap_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import shap
from shap import KernelExplainer
from dianna import utils


class KERNELSHAPTabular:
    """Wrapper around the SHAP Kernel explainer for tabular data."""

    def __init__(
        self,
        training_data: np.ndarray,
        mode: str = "classification",
        feature_names: Optional[List[str]] = None,
        training_data_kmeans: Optional[int] = None,
    ) -> None:
        """Initializer of KERNELSHAPTabular.

        Training data must be provided for the explainer to estimate the expected
        values.

        More information can be found in the API guide:
        https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Arguments:
            training_data (np.ndarray): training data, which should be a numpy 2d array
            mode (str, optional): "classification" or "regression"
            feature_names (list(str), optional): list of names corresponding to the
                columns in the training data.
            training_data_kmeans (int, optional): if given, summarize the whole
                training set with this many weighted kmeans centroids. Summarizing
                the background data substantially speeds up KernelSHAP.
        """
        if training_data_kmeans:
            # shap.kmeans returns a DenseData summary usable as background data.
            self.training_data = shap.kmeans(training_data, training_data_kmeans)
        else:
            self.training_data = training_data
        self.feature_names = feature_names
        self.mode = mode
        # Created lazily in explain(); declared here for type checkers.
        self.explainer: KernelExplainer

    def explain(
        self,
        model_or_function: Union[str, callable],
        input_tabular: np.ndarray,
        link: str = "identity",
        **kwargs,
    ) -> np.ndarray:
        """Run the KernelSHAP explainer.

        Args:
            model_or_function (callable or str): The function that runs the model
                to be explained or the path to a ONNX model on disk.
            input_tabular (np.ndarray): Data to be explained.
            link (str): A generalized linear model link to connect the feature
                importance values to the model. Must be either "identity" or "logit".
            kwargs: Passed on to KernelExplainer and/or shap_values; only the
                keywords each callable accepts are forwarded to it. See:
                https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Returns:
            explanation: An Explanation object containing the KernelExplainer
            explanations for each class.
        """
        init_instance_kwargs = utils.get_kwargs_applicable_to_function(
            KernelExplainer, kwargs
        )
        # Pass `link` by keyword: recent shap versions place `feature_names`
        # before `link` in KernelExplainer's signature, so passing `link`
        # positionally can bind it to the wrong parameter.
        self.explainer = KernelExplainer(
            model_or_function, self.training_data, link=link, **init_instance_kwargs
        )

        explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
            self.explainer.shap_values, kwargs
        )

        saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)

        if self.mode == 'regression':
            # shap_values wraps the single regression output in a list;
            # unwrap so callers get one value per feature.
            return saliency[0]

        return saliency
4 changes: 2 additions & 2 deletions dianna/methods/lime_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def explain(
**explain_instance_kwargs,
)

if self.mode == "regression":
if self.mode == 'regression':
local_exp = sorted(explanation.local_exp[1])
saliency = [i[1] for i in local_exp]

elif self.mode == "classification":
elif self.mode == 'classification':
# extract scores from lime explainer
saliency = []
for i in range(self.top_labels):
Expand Down
35 changes: 35 additions & 0 deletions tests/methods/test_shap_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Test LIME tabular method."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.kernelshap_tabular import KERNELSHAPTabular
from tests.utils import run_model


class LIMEOnTabular(TestCase):
    """Suite of KernelSHAP tests for the tabular case.

    NOTE(review): the class name says LIME but these tests exercise
    KernelSHAP (likely copied from the LIME test suite) -- consider
    renaming to SHAPOnTabular; the name is kept here so external test
    selections by name keep working.
    """

    def test_shap_tabular_classification_correct_output_shape(self):
        """Test whether the output of the explainer has the correct shape."""
        training_data = np.random.random((10, 2))
        input_data = np.random.random(2)
        feature_names = ["feature_1", "feature_2"]
        explainer = KERNELSHAPTabular(
            training_data,
            mode='classification',
            feature_names=feature_names,
        )
        exp = explainer.explain(
            run_model,
            input_data,
        )
        # Classification yields one saliency array per class,
        # each with one value per feature.
        assert len(exp[0]) == len(feature_names)

    def test_shap_tabular_regression_correct_output_shape(self):
        """Test whether the output of the explainer has the correct length."""
        training_data = np.random.random((10, 2))
        input_data = np.random.random(2)
        feature_names = ["feature_1", "feature_2"]
        exp = dianna.explain_tabular(
            run_model,
            input_tabular=input_data,
            method='kernelshap',
            mode='regression',
            training_data=training_data,
            training_data_kmeans=2,
            feature_names=feature_names,
        )
        # Regression yields a single saliency vector, one value per feature.
        assert len(exp) == len(feature_names)
Loading

0 comments on commit 1fb3a2a

Please sign in to comment.