Test Simple Layer Rules

- added simple module fixtures for simple, element-wise, parameter-less
  modules like ReLU
- added simple data fixtures for the simple layers
- made batch-size a pytest CLI option, defaulting to 4
- added tests for rules that only apply to simple layers
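
With this change, the batch size used by the data fixtures can be overridden from the pytest
command line, e.g. pytest tests/ --batchsize 8 (a hypothetical invocation; any positive
integer should work, and omitting the option keeps the default of 4).
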
chr5tphr committed Feb 9, 2022
1 parent 7997930 commit ab8b510
Showing 3 changed files with 126 additions and 31 deletions.
48 changes: 46 additions & 2 deletions tests/conftest.py
@@ -9,6 +9,21 @@
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d


def pytest_addoption(parser):
'''Add options to pytest.'''
    parser.addoption(
        '--batchsize',
        type=int,
        default=4,
        help='Batch-size for generated samples.'
    )


def pytest_generate_tests(metafunc):
'''Generate test fixture values based on CLI options.'''
if 'batchsize' in metafunc.fixturenames:
metafunc.parametrize('batchsize', [metafunc.config.getoption('batchsize')], scope='session')


def prodict(**kwargs):
'''Create a dictionary with values which are the cartesian product of the input keyword arguments.'''
return [dict(zip(kwargs, val)) for val in product(*kwargs.values())]
@@ -30,6 +45,23 @@ def rng(request):
return torch.manual_seed(request.param)


@pytest.fixture(
scope='session',
params=[
(torch.nn.ReLU, {}),
(torch.nn.Softmax, dict(dim=1)),
(torch.nn.Tanh, {}),
(torch.nn.Sigmoid, {}),
(torch.nn.Softplus, dict(beta=1)),
],
ids=lambda param: param[0].__name__
)
def module_simple(rng, request):
'''Fixture for simple modules.'''
module_type, kwargs = request.param
return module_type(**kwargs).to(torch.float64).eval()


@pytest.fixture(
scope='session',
params=[
@@ -83,9 +115,9 @@ def module_batchnorm(module_linear):


@pytest.fixture(scope='session')
def data_input(rng, module_linear):
def data_linear(rng, batchsize, module_linear):
'''Fixture to create data for a linear module, given an RNG.'''
shape = (4,)
shape = (batchsize,)
setups = [
(Conv1d, 1, 1),
(ConvTranspose1d, 0, 1),
@@ -102,3 +134,15 @@ def data_input(rng, module_linear):
shape += (module_linear.weight.shape[dim],) + (4,) * ndims

return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)


@pytest.fixture(scope='session', params=[
(16,),
(4,),
(4, 4),
(4, 4, 4),
])
def data_simple(request, rng, batchsize):
    '''Fixture to create data for a simple module, given an RNG.'''
shape = (batchsize,) + request.param
return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
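
A minimal sketch of how the new fixtures combine (the test function below is hypothetical and
only illustrates the parametrization; the modules and shapes are taken from the fixture
parameters above). With the default batch size of 4, data_simple takes the shapes (4, 16),
(4, 4), (4, 4, 4) and (4, 4, 4, 4), while module_simple cycles through ReLU, Softmax, Tanh,
Sigmoid and Softplus:

def test_simple_shapes(module_simple, data_simple):
    '''Hypothetical test: every simple module acts element-wise (Softmax along dim 1),
    so the output shape equals the input shape.'''
    output = module_simple(data_simple)
    assert output.shape == data_simple.shape
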
8 changes: 4 additions & 4 deletions tests/test_canonizers.py
@@ -4,20 +4,20 @@
from zennit.canonizers import SequentialMergeBatchNorm


def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_input):
def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_linear):
'''Test whether the output of the merged batchnorm is consistent with its original output.'''
output_linear_before = module_linear(data_input)
output_linear_before = module_linear(data_linear)
output_batchnorm_before = module_batchnorm(output_linear_before)
canonizer = SequentialMergeBatchNorm()

try:
canonizer.register((module_linear,), module_batchnorm)
output_linear_canonizer = module_linear(data_input)
output_linear_canonizer = module_linear(data_linear)
output_batchnorm_canonizer = module_batchnorm(output_linear_canonizer)
finally:
canonizer.remove()

output_linear_after = module_linear(data_input)
output_linear_after = module_linear(data_linear)
output_batchnorm_after = module_batchnorm(output_linear_after)

assert all(torch.allclose(left, right, atol=1e-5) for left, right in [
101 changes: 76 additions & 25 deletions tests/test_rules.py
@@ -7,6 +7,7 @@
import pytest
import torch
from zennit.rules import Epsilon, ZPlus, AlphaBeta, Gamma, ZBox, Norm, WSquare, Flat
from zennit.rules import Pass, ReLUDeconvNet, ReLUGuidedBackprop


def stabilize(input, epsilon=1e-6):
@@ -22,14 +22,15 @@ def as_matrix(module_linear, input, output):
return weight, bias


RULEPAIRS = []
RULES_LINEAR = []
RULES_SIMPLE = []


def replicates(replicated_func, **kwargs):
def replicates(target_list, replicated_func, **kwargs):
'''Decorator to indicate a replication of a function for testing.'''
def wrapper(func):
'''Append to ``RULEPAIRS`` as partial, given ``kwargs``.'''
RULEPAIRS.append(
        '''Append to ``target_list`` as partial, given ``kwargs``.'''
target_list.append(
pytest.param(
(partial(replicated_func, **kwargs), partial(func, **kwargs)),
id=replicated_func.__name__
@@ -68,16 +70,31 @@ def wrapped(module_linear, input, output, **kwargs):
return wrapped


@replicates(Epsilon, epsilon=1e-6)
@replicates(Epsilon, epsilon=1.0)
@replicates(Norm)
def with_grad(func):
'''Decorator to wrap function such that the gradient is computed and passed to the function instead of module.'''
@wraps(func)
def wrapped(module, input, output, **kwargs):
'''Get gradient and pass along input, output and keyword arguments to func.'''
gradient, = torch.autograd.grad(module(input), input, output)
return func(
gradient,
input,
output,
**kwargs
)
return wrapped


@replicates(RULES_LINEAR, Epsilon, epsilon=1e-6)
@replicates(RULES_LINEAR, Epsilon, epsilon=1.0)
@replicates(RULES_LINEAR, Norm)
@matrix_form
def rule_epsilon(weight, bias, input, relevance, epsilon=1e-6):
'''Replicates the Epsilon rule.'''
return input * ((relevance / stabilize(input @ weight.t() + bias, epsilon)) @ weight)


@replicates(ZPlus)
@replicates(RULES_LINEAR, ZPlus)
@matrix_form
def rule_zplus(weight, bias, input, relevance):
'''Replicates the ZPlus rule.'''
@@ -90,8 +107,8 @@ def rule_zplus(weight, bias, input, relevance):
return xplus * (rfac @ wplus) + xminus * (rfac @ wminus)


@replicates(Gamma, gamma=0.25)
@replicates(Gamma, gamma=0.5)
@replicates(RULES_LINEAR, Gamma, gamma=0.25)
@replicates(RULES_LINEAR, Gamma, gamma=0.5)
@matrix_form
def rule_gamma(weight, bias, input, relevance, gamma):
'''Replicates the Gamma rule.'''
@@ -100,8 +117,8 @@ def rule_gamma(weight, bias, input, relevance, gamma):
return input * ((relevance / stabilize(input @ wgamma.t() + bgamma)) @ wgamma)


@replicates(AlphaBeta, alpha=2.0, beta=1.0)
@replicates(AlphaBeta, alpha=1.0, beta=0.0)
@replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0)
@replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0)
@matrix_form
def rule_alpha_beta(weight, bias, input, relevance, alpha, beta):
'''Replicates the AlphaBeta rule.'''
@@ -118,7 +135,7 @@ def rule_alpha_beta(weight, bias, input, relevance, alpha, beta):
return alpha * result_alpha - beta * result_beta


@replicates(ZBox, low=-3.0, high=3.0)
@replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0)
@matrix_form
def rule_zbox(weight, bias, input, relevance, low, high):
'''Replicates the ZBox rule.'''
@@ -131,7 +148,7 @@ def rule_zbox(weight, bias, input, relevance, low, high):
return input * (rfac @ weight) - low * (rfac @ wplus) - high * (rfac @ wminus)


@replicates(WSquare)
@replicates(RULES_LINEAR, WSquare)
@matrix_form
def rule_wsquare(weight, bias, input, relevance):
'''Replicates the WSquare rule.'''
@@ -141,7 +158,7 @@ def rule_wsquare(weight, bias, input, relevance):
return rfac @ wsquare


@replicates(Flat)
@replicates(RULES_LINEAR, Flat)
@flat_module_params
@matrix_form
def rule_flat(wflat, bias, input, relevance):
@@ -151,25 +168,59 @@ def rule_flat(wflat, bias, input, relevance):
return rfac @ wflat


@pytest.fixture(scope='session', params=RULEPAIRS)
def rule_pair(request):
'''Fixture to supply ``RULEPAIRS``.'''
@replicates(RULES_SIMPLE, Pass)
def rule_pass(module, input, relevance):
'''Replicates the Pass rule.'''
return relevance


@replicates(RULES_SIMPLE, ReLUDeconvNet)
def rule_relu_deconvnet(module, input, relevance):
'''Replicates the ReLUDeconvNet rule.'''
return relevance.clamp(min=0)


@replicates(RULES_SIMPLE, ReLUGuidedBackprop)
@with_grad
def rule_relu_guidedbackprop(gradient, input, relevance):
'''Replicates the ReLUGuidedBackprop rule.'''
return gradient * (relevance > 0.)


@pytest.fixture(scope='session', params=RULES_LINEAR)
def rule_pair_linear(request):
'''Fixture to supply ``RULES_LINEAR``.'''
return request.param


def test_linear_rule(module_linear, data_input, rule_pair):
'''Test whether replicated and original implementations of rules for linear layers agree.'''
@pytest.fixture(scope='session', params=RULES_SIMPLE)
def rule_pair_simple(request):
'''Fixture to supply ``RULES_SIMPLE``.'''
return request.param


def compare_rule_pair(module, data, rule_pair):
'''Compare rules with their replicated versions.'''
rule_hook, rule_replicated = rule_pair

input = data_input.clone().requires_grad_()
handle = rule_hook().register(module_linear)
input = data.clone().requires_grad_()
handle = rule_hook().register(module)
try:
output = module_linear(input)
output = module(input)
relevance_hook, = torch.autograd.grad(output, input, grad_outputs=output)
finally:
handle.remove()

with torch.no_grad():
relevance_replicated = rule_replicated(module_linear, input, output)
relevance_replicated = rule_replicated(module, input, output)

assert torch.allclose(relevance_hook, relevance_replicated, atol=1e-5)


def test_linear_rule(module_linear, data_linear, rule_pair_linear):
'''Test whether replicated and original implementations of rules for linear layers agree.'''
compare_rule_pair(module_linear, data_linear, rule_pair_linear)


def test_simple_rule(module_simple, data_simple, rule_pair_simple):
'''Test whether replicated and original implementations of rules for simple layers agree.'''
compare_rule_pair(module_simple, data_simple, rule_pair_simple)
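
The new simple-layer tests can be selected by name, e.g. pytest tests/test_rules.py -k test_simple_rule
(a hypothetical invocation using pytest's -k name filter); the linear-layer tests still run via
test_linear_rule.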
