diff --git a/high_order_layers_torch/layers.py b/high_order_layers_torch/layers.py
index 2b41f29..c20d737 100644
--- a/high_order_layers_torch/layers.py
+++ b/high_order_layers_torch/layers.py
@@ -20,12 +20,13 @@ class MaxAbsNormalization(nn.Module):
     Normalization for the 1D case (MLP)
     """
 
-    def __init__(self, eps: float = 1e-6):
+    def __init__(self, eps: float = 1e-6, dim: int = 1):
         super().__init__()
         self._eps = eps
+        self._dim = dim
 
     def forward(self, x):
-        return max_abs_normalization(x, eps=self._eps)
+        return max_abs_normalization(x, eps=self._eps, dim=self._dim)
 
 
 class MaxCenterNormalization(nn.Module):
diff --git a/high_order_layers_torch/utils.py b/high_order_layers_torch/utils.py
index 0227f2a..d00ea5f 100644
--- a/high_order_layers_torch/utils.py
+++ b/high_order_layers_torch/utils.py
@@ -5,17 +5,17 @@ from torch import Tensor
 
 
-def max_abs(x: Tensor):
-    return torch.max(x.abs(), dim=1, keepdim=True)[0]
+def max_abs(x: Tensor, dim: int = 1):
+    return torch.max(x.abs(), dim=dim, keepdim=True)[0]
 
 
-def max_abs_normalization(x: Tensor, eps: float = 1e-6):
-    return x / (max_abs(x) + eps)
+def max_abs_normalization(x: Tensor, eps: float = 1e-6, dim: int = 1):
+    return x / (max_abs(x, dim=dim) + eps)
 
 
-def max_center_normalization(x: Tensor, eps: float = 1e-6):
-    max_x = torch.max(x, dim=1, keepdim=True)[0]
-    min_x = torch.min(x, dim=1, keepdim=True)[0]
+def max_center_normalization(x: Tensor, eps: float = 1e-6, dim: int = 1):
+    max_x = torch.max(x, dim=dim, keepdim=True)[0]
+    min_x = torch.min(x, dim=dim, keepdim=True)[0]
     midrange = 0.5 * (max_x + min_x)
     mag = max_x - midrange
diff --git a/pyproject.toml b/pyproject.toml
index b15c96a..505639b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "high-order-layers-torch"
-version = "2.1.0"
+version = "2.2.0"
 description = "High order layers in pytorch"
 authors = ["jloverich "]
 license = "MIT"
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 3d0496d..a29e079 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -101,6 +101,14 @@ def test_max_abs_layers():
     assert torch.all(torch.eq(ans[1][0], torch.tensor([0.5, 0.0625, 0.0625])))
     assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625])))
 
+    layer = MaxAbsNormalization(eps=0.0, dim=2)
+    ans = layer(x)
+    assert torch.all(torch.eq(ans[0][0], torch.tensor([1, 0.5, 0.5])))
+    assert torch.all(torch.eq(ans[0][1], torch.tensor([1, 0.25, 0.25])))
+    assert torch.all(torch.eq(ans[1][0], torch.tensor([1, 0.125, 0.125])))
+    assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625])))
+
+
 def test_max_center_layers():
     x = torch.tensor([[1, 0.5, 0.5], [2, 0.5, 0.5]])
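
For context, a minimal usage sketch of the new `dim` argument, assuming the layer is imported from high_order_layers_torch.layers as in the patched module; the 3D input below is illustrative, chosen to be consistent with the expected values in the new test case, and is not taken from the repository itself.

import torch

from high_order_layers_torch.layers import MaxAbsNormalization

# Illustrative input: batch of 2 samples, 2 rows each, 3 features per row.
x = torch.tensor(
    [
        [[1.0, 0.5, 0.5], [2.0, 0.5, 0.5]],
        [[4.0, 0.5, 0.5], [8.0, 0.5, 0.5]],
    ]
)

# dim=2 divides each feature vector by its own max absolute value;
# the default dim=1 keeps the previous 2D (MLP) behavior.
layer = MaxAbsNormalization(eps=0.0, dim=2)
y = layer(x)

print(y[1][1])  # tensor([1.0000, 0.0625, 0.0625]) since max |x[1][1]| is 8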