Skip to content

Commit

Permalink
added kwargs tests and code to raise warning for extra kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
cpranav93 committed Jan 18, 2024
1 parent 79e2831 commit ce779fd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
7 changes: 7 additions & 0 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import importlib
import logging
from . import utils
import warnings

logging.getLogger(__name__).addHandler(logging.NullHandler())

Expand Down Expand Up @@ -77,6 +78,10 @@ def explain_image(model_or_function, input_image, method, labels, **kwargs):
explain_image_kwargs = utils.get_kwargs_applicable_to_function(
explainer.explain, kwargs
)
for key in explain_image_kwargs.keys():
kwargs.pop(key)
if kwargs:
warnings.warn(message = f'Please note the following kwargs are not being used: {kwargs}')
return explainer.explain(
model_or_function, input_image, labels, **explain_image_kwargs
)
Expand Down Expand Up @@ -154,4 +159,6 @@ def _get_explainer(method, kwargs, modality):
method_kwargs = utils.get_kwargs_applicable_to_function(
method_class.__init__, kwargs
)
for key in method_kwargs.keys():
kwargs.pop(key)
return method_class(**method_kwargs)
60 changes: 60 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unittest import TestCase
import pytest
import dianna
from tests.test_onnx_runner import generate_data
import numpy as np

class initialize_method(TestCase):

def test_lime_image_correct_kwargs(self):
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1)[0].astype(np.float32)
axis_labels = ('channels', 'y', 'x')
labels = [1]

dianna.explain_image(model_filename,
input_data,
method='LIME',
labels=labels,
kernel=None,
kernel_width=25,
verbose=False,
feature_selection='auto',
random_state=None,
axis_labels=axis_labels,
preprocess_function=None,
top_labels=None,
num_features=10,
num_samples=5000,
return_masks=True,
positive_only=False,
hide_rest=True,
)

def test_lime_image_extra_kwarg(self):
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1)[0].astype(np.float32)
axis_labels = ('channels', 'y', 'x')
labels = [1]

with self.assertWarns(Warning):
dianna.explain_image(model_filename,
input_data,
method='LIME',
labels=labels,
kernel=None,
kernel_width=25,
verbose=False,
feature_selection='auto',
random_state=None,
axis_labels=axis_labels,
preprocess_function=None,
top_labels=None,
num_features=10,
num_samples=5000,
return_masks=True,
positive_only=False,
hide_rest=True,
extra_kwarg=None
)

0 comments on commit ce779fd

Please sign in to comment.