diff --git a/dlordinal/layers/clm.py b/dlordinal/layers/clm.py
index acbe4f5..1b87952 100644
--- a/dlordinal/layers/clm.py
+++ b/dlordinal/layers/clm.py
@@ -1,3 +1,4 @@
+import warnings
 from math import sqrt
 
 import torch
@@ -17,6 +18,8 @@ class CLM(Module):
         The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``.
     min_distance : float, default=0.0
         The minimum distance between thresholds
+    clip_warning : bool, default=True
+        Whether to print the clipping value warning or not.
 
     Example
     ---------
@@ -41,11 +44,14 @@ class CLM(Module):
            [0.5734, 0.0062, 0.0423, 0.0392, 0.3389]], grad_fn=)
     """
 
-    def __init__(self, num_classes, link_function, min_distance=0.0, **kwargs):
+    def __init__(
+        self, num_classes, link_function, min_distance=0.0, clip_warning=True, **kwargs
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.link_function = link_function
         self.min_distance = min_distance
+        self.clip_warning = clip_warning
         self.dist = torch.distributions.Normal(0, 1)
 
         self.thresholds_b = torch.nn.Parameter(data=torch.Tensor(1), requires_grad=True)
@@ -60,6 +66,8 @@ def __init__(self, num_classes, link_function, min_distance=0.0, **kwargs):
             sqrt(1.0 / (self.num_classes - 2)),
         )
 
+        self.clip_warning_shown = False
+
     def _convert_thresholds(self, b, a, min_distance):
         a = a**2
         a = a + min_distance
@@ -91,7 +99,17 @@ def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):
             0,
             1,
         )
+
         z3 = a - b
+        if torch.any(z3 > 10) or torch.any(z3 < -10):
+            if self.clip_warning and not self.clip_warning_shown:
+                warnings.warn(
+                    "The output value of the CLM layer is out of the range [-10, 10]."
+                    " Clipping value prior to applying the link function for numerical"
+                    " stability."
+                )
+            z3 = torch.clip(a - b, -10, 10)
+            self.clip_warning_shown = True
 
         if self.link_function == "probit":
             a3T = self.dist.cdf(z3)
diff --git a/dlordinal/layers/tests/test_clm.py b/dlordinal/layers/tests/test_clm.py
index 0266e0d..f179096 100644
--- a/dlordinal/layers/tests/test_clm.py
+++ b/dlordinal/layers/tests/test_clm.py
@@ -1,5 +1,7 @@
-import torch
+import warnings
 
+import pytest
+import torch
 from dlordinal.layers import CLM
 
 
@@ -61,3 +63,41 @@ def test_clm_cloglog():
     assert isinstance(output, torch.Tensor)
 
     assert clm.link_function == "cloglog"
+
+
+def test_clm_clip():
+    input_shape = 12
+    num_classes = 6
+    link_function = "cloglog"
+    min_distance = 0.0
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=True,
+    )
+    input_data = torch.rand(8, input_shape) * 100
+    with pytest.warns(Warning, match="Clipping"):
+        clm(input_data)
+
+    warnings.filterwarnings("error")
+    clm(input_data)
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=False,
+    )
+    clm(input_data)
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=True,
+    )
+    input_data = torch.rand(8, input_shape) * 0.1
+    clm(input_data)
+    warnings.resetwarnings()
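
For reference, a minimal usage sketch of the new `clip_warning` flag. This is not part of the patch; the `num_classes=6` / `(8, 12)` input and the `* 100` scaling are simply borrowed from `test_clm_clip` above.

```python
# Sketch only: exercises the clip_warning behaviour added in this diff.
import torch

from dlordinal.layers import CLM

# Default behaviour: warn (once) when a - b leaves [-10, 10] before clipping.
clm = CLM(num_classes=6, link_function="cloglog", min_distance=0.0, clip_warning=True)
large_input = torch.rand(8, 12) * 100  # large projections trigger the clipping path
probs = clm(large_input)  # emits the "Clipping value ..." warning on the first call
probs = clm(large_input)  # silent: clip_warning_shown suppresses repeat warnings

# Opt out of the warning entirely; values are still clipped for numerical stability.
quiet_clm = CLM(
    num_classes=6, link_function="cloglog", min_distance=0.0, clip_warning=False
)
probs = quiet_clm(large_input)
```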