Merge pull request #70 from RafaAyGar/main
[BUG] Fixed CLM numerical instability
victormvy authored Jun 13, 2024
2 parents 1527b2b + 3e26aa4 commit 406f9c1
Showing 2 changed files with 60 additions and 2 deletions.
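
For context on the fix: with the ``cloglog`` link, cumulative probabilities are computed as F(z) = 1 - exp(-exp(z)) on z = a - b, and for strongly negative z this underflows to exactly 0, which turns into -inf as soon as a log-loss is applied. A minimal sketch of the failure mode in plain PyTorch (not part of the commit):

```python
import torch

# Sketch (not from the commit): the cloglog link F(z) = 1 - exp(-exp(z))
# underflows to exactly 0 once z is strongly negative, and log(0) = -inf
# then poisons any NLL-style loss.
z = torch.tensor([-40.0, -10.0, 0.0])
probs = 1.0 - torch.exp(-torch.exp(z))
print(probs)             # first entry is exactly 0.0; at z = -10 it is ~4.5e-05
print(torch.log(probs))  # log of the first entry is -inf
```

Clipping z to [-10, 10], as the change below does, bounds the cumulative probabilities away from these degenerate values.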
20 changes: 19 additions & 1 deletion dlordinal/layers/clm.py
@@ -1,3 +1,4 @@
+import warnings
 from math import sqrt
 
 import torch
@@ -17,6 +18,8 @@ class CLM(Module):
         The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``.
     min_distance : float, default=0.0
         The minimum distance between thresholds
+    clip_warning : bool, default=True
+        Whether to print the clipping value warning or not.
 
     Example
     ---------
@@ -41,11 +44,14 @@ class CLM(Module):
                 [0.5734, 0.0062, 0.0423, 0.0392, 0.3389]], grad_fn=<CopySlices>)
     """
 
-    def __init__(self, num_classes, link_function, min_distance=0.0, **kwargs):
+    def __init__(
+        self, num_classes, link_function, min_distance=0.0, clip_warning=True, **kwargs
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.link_function = link_function
         self.min_distance = min_distance
+        self.clip_warning = clip_warning
         self.dist = torch.distributions.Normal(0, 1)
 
         self.thresholds_b = torch.nn.Parameter(data=torch.Tensor(1), requires_grad=True)
@@ -60,6 +66,8 @@ def __init__(self, num_classes, link_function, min_distance=0.0, **kwargs):
             sqrt(1.0 / (self.num_classes - 2)),
         )
 
+        self.clip_warning_shown = False
+
     def _convert_thresholds(self, b, a, min_distance):
         a = a**2
         a = a + min_distance
@@ -91,7 +99,17 @@ def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):
             0,
             1,
         )
 
         z3 = a - b
+        if torch.any(z3 > 10) or torch.any(z3 < -10):
+            if self.clip_warning and not self.clip_warning_shown:
+                warnings.warn(
+                    "The output value of the CLM layer is out of the range [-10, 10]."
+                    " Clipping value prior to applying the link function for numerical"
+                    " stability."
+                )
+            z3 = torch.clip(a - b, -10, 10)
+            self.clip_warning_shown = True
+
         if self.link_function == "probit":
             a3T = self.dist.cdf(z3)
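
As a usage note (a sketch based on the new docstring and on the test below, not code from the commit): the flag only controls the warning; the clipping itself always happens once ``a - b`` leaves ``[-10, 10]``.

```python
import torch
from dlordinal.layers import CLM

# Shapes mirror test_clm_clip below: batch of 8, input size 12, 6 classes.
clm = CLM(num_classes=6, link_function="cloglog", min_distance=0.0,
          clip_warning=True)

x = torch.rand(8, 12) * 100  # large projections push a - b outside [-10, 10]
probas = clm(x)              # warns once on the first clipped forward pass

quiet = CLM(num_classes=6, link_function="cloglog", clip_warning=False)
probas = quiet(x)            # same clipping, no warning
```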
42 changes: 41 additions & 1 deletion dlordinal/layers/tests/test_clm.py
@@ -1,5 +1,7 @@
-import torch
+import warnings
 
+import pytest
+import torch
 from dlordinal.layers import CLM
 
 
@@ -61,3 +63,41 @@ def test_clm_cloglog():
 
     assert isinstance(output, torch.Tensor)
     assert clm.link_function == "cloglog"
+
+
+def test_clm_clip():
+    input_shape = 12
+    num_classes = 6
+    link_function = "cloglog"
+    min_distance = 0.0
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=True,
+    )
+    input_data = torch.rand(8, input_shape) * 100
+    with pytest.warns(Warning, match="Clipping"):
+        clm(input_data)
+
+    warnings.filterwarnings("error")
+    clm(input_data)
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=False,
+    )
+    clm(input_data)
+
+    clm = CLM(
+        num_classes=num_classes,
+        link_function=link_function,
+        min_distance=min_distance,
+        clip_warning=True,
+    )
+    input_data = torch.rand(8, input_shape) * 0.1
+    clm(input_data)
+    warnings.resetwarnings()
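
The middle of ``test_clm_clip`` leans on a once-only pattern: after the first warning, ``clip_warning_shown`` suppresses repeats, so escalating warnings to errors via ``warnings.filterwarnings("error")`` and calling the layer again proves the warning fired exactly once. A standalone sketch of that pattern (the ``Once`` class is hypothetical, not from the commit):

```python
import warnings

import pytest


class Once:
    """Warn on the first poke only, like CLM's clip_warning_shown flag."""

    def __init__(self):
        self.shown = False

    def poke(self):
        if not self.shown:
            warnings.warn("Clipping")
            self.shown = True


once = Once()
with pytest.warns(Warning, match="Clipping"):
    once.poke()          # first call warns
warnings.filterwarnings("error")
once.poke()              # would raise if it warned again; it does not
warnings.resetwarnings()
```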
