645 Implement SHAP for tabular data #674

Merged: 10 commits, Jan 11, 2024
84 changes: 84 additions & 0 deletions dianna/methods/kernelshap_tabular.py
@@ -0,0 +1,84 @@
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import shap
from shap import KernelExplainer
from dianna import utils


class KERNELSHAPTabular:
    """Wrapper around the SHAP Kernel explainer for tabular data."""

    def __init__(
        self,
        training_data: np.ndarray,
        mode: str = "classification",
        feature_names: Optional[List[str]] = None,
        training_data_kmeans: Optional[int] = None,
    ) -> None:
        """Initializer of KERNELSHAPTabular.

        Training data must be provided for the explainer to estimate the expected
        values.

        More information can be found in the API guide:
        https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Arguments:
            training_data (np.ndarray): training data, which should be a 2D numpy array
            mode (str, optional): "classification" or "regression"
            feature_names (list(str), optional): list of names corresponding to the columns
                in the training data.
            training_data_kmeans (int, optional): if given, summarize the whole training
                set with this number of weighted kmeans clusters
        """
        if training_data_kmeans:
            self.training_data = shap.kmeans(training_data, training_data_kmeans)
        else:
            self.training_data = training_data
        self.feature_names = feature_names
        self.mode = mode
        self.explainer: KernelExplainer

    def explain(
        self,
        model_or_function: Union[str, callable],
        input_tabular: np.ndarray,
        link: str = "identity",
        **kwargs,
    ) -> np.ndarray:
        """Run the KernelSHAP explainer.

        Args:
            model_or_function (callable or str): The function that runs the model to be explained
                or the path to an ONNX model on disk.
            input_tabular (np.ndarray): Data to be explained.
            link (str): A generalized linear model link to connect the feature importance values
                to the model. Must be either "identity" or "logit".
            kwargs: These parameters are passed on to the KernelExplainer initializer and its
                shap_values method, where applicable.

        Other keyword arguments: see the documentation for KernelExplainer:
        https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py

        Returns:
            explanation: The SHAP values computed by the KernelExplainer, one array of
                feature attributions per output class.
        """
        init_instance_kwargs = utils.get_kwargs_applicable_to_function(
            KernelExplainer, kwargs
        )
        self.explainer = KernelExplainer(
            model_or_function, self.training_data, link, **init_instance_kwargs
        )

        explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
            self.explainer.shap_values, kwargs
        )

        saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)

        if self.mode == 'regression':
            return saliency[0]

        return saliency
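
Below is a minimal usage sketch of the new wrapper for readers of this diff. It is not part of the PR: the RandomForestClassifier, the random training data, and the variable names are illustrative assumptions; any callable (or ONNX model path) that maps a 2D batch of rows to model outputs would work the same way.

# Illustrative sketch only, not part of this PR.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from dianna.methods.kernelshap_tabular import KERNELSHAPTabular

X_train = np.random.random((100, 3))           # 2D background/training data
y_train = np.random.randint(0, 2, size=100)    # binary labels for the dummy model
model = RandomForestClassifier().fit(X_train, y_train)

explainer = KERNELSHAPTabular(
    X_train,
    mode="classification",
    feature_names=["f1", "f2", "f3"],
    training_data_kmeans=5,  # summarize the background set with 5 kmeans clusters
)
shap_values = explainer.explain(model.predict_proba, X_train[0])
print(shap_values[0])  # per-feature attribution for the first class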
4 changes: 2 additions & 2 deletions dianna/methods/lime_tabular.py
@@ -119,11 +119,11 @@ def explain(
             **explain_instance_kwargs,
         )
 
-        if self.mode == "regression":
+        if self.mode == 'regression':
             local_exp = sorted(explanation.local_exp[1])
             saliency = [i[1] for i in local_exp]
 
-        elif self.mode == "classification":
+        elif self.mode == 'classification':
             # extract scores from lime explainer
             saliency = []
             for i in range(self.top_labels):
35 changes: 35 additions & 0 deletions tests/methods/test_shap_tabular.py
@@ -0,0 +1,35 @@
"""Test LIME tabular method."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.kernelshap_tabular import KERNELSHAPTabular
from tests.utils import run_model


class LIMEOnTabular(TestCase):
"""Suite of LIME tests for the tabular case."""

def test_shap_tabular_classification_correct_output_shape(self):
"""Test the output of explainer."""
geek-yang marked this conversation as resolved.
Show resolved Hide resolved
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
explainer = KERNELSHAPTabular(training_data,
mode ='classification',
feature_names=feature_names,)
exp = explainer.explain(
run_model,
input_data,
)
assert len(exp[0]) == len(feature_names)

def test_shap_tabular_regression_correct_output_shape(self):
"""Test the output of explainer."""
geek-yang marked this conversation as resolved.
Show resolved Hide resolved
training_data = np.random.random((10, 2))
input_data = np.random.random(2)
feature_names = ["feature_1", "feature_2"]
exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='kernelshap',
mode ='regression', training_data = training_data,
training_data_kmeans = 2, feature_names=feature_names)

assert len(exp) == len(feature_names)
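
The tests import run_model from tests.utils, which is not part of this diff. Purely for context, a hypothetical stand-in is sketched below; the helper actually shipped in the repository may look different.

# Hypothetical stand-in for tests.utils.run_model (not shown in this diff):
# a deterministic dummy model that maps a batch of 2-feature rows to
# two-class pseudo-probabilities, i.e. the callable shape KernelExplainer expects.
import numpy as np

def run_model(input_data):
    input_data = np.atleast_2d(input_data)         # accept single rows as well
    score = input_data.sum(axis=1, keepdims=True)  # arbitrary per-row score
    prob_class_1 = 1.0 / (1.0 + np.exp(-score))    # squash to (0, 1)
    return np.hstack([1.0 - prob_class_1, prob_class_1])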