From cbbfd77b7e50fe3ed20721fdee1f65b9c9495fd7 Mon Sep 17 00:00:00 2001 From: jloveric Date: Mon, 4 Dec 2023 20:09:25 -0800 Subject: [PATCH] Add dimension to my normalization --- high_order_layers_torch/layers.py | 5 +++-- high_order_layers_torch/utils.py | 14 +++++++------- pyproject.toml | 2 +- tests/test_layers.py | 8 ++++++++ 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/high_order_layers_torch/layers.py b/high_order_layers_torch/layers.py index 2b41f29..c20d737 100644 --- a/high_order_layers_torch/layers.py +++ b/high_order_layers_torch/layers.py @@ -20,12 +20,13 @@ class MaxAbsNormalization(nn.Module): Normalization for the 1D case (MLP) """ - def __init__(self, eps: float = 1e-6): + def __init__(self, eps: float = 1e-6, dim: int = 1): super().__init__() self._eps = eps + self._dim = dim def forward(self, x): - return max_abs_normalization(x, eps=self._eps) + return max_abs_normalization(x, eps=self._eps, dim=self._dim) class MaxCenterNormalization(nn.Module): diff --git a/high_order_layers_torch/utils.py b/high_order_layers_torch/utils.py index 0227f2a..d00ea5f 100644 --- a/high_order_layers_torch/utils.py +++ b/high_order_layers_torch/utils.py @@ -5,17 +5,17 @@ from torch import Tensor -def max_abs(x: Tensor): - return torch.max(x.abs(), dim=1, keepdim=True)[0] +def max_abs(x: Tensor, dim: int=1): + return torch.max(x.abs(), dim=dim, keepdim=True)[0] -def max_abs_normalization(x: Tensor, eps: float = 1e-6): - return x / (max_abs(x) + eps) +def max_abs_normalization(x: Tensor, eps: float = 1e-6, dim:int=1): + return x / (max_abs(x, dim=dim) + eps) -def max_center_normalization(x: Tensor, eps: float = 1e-6): - max_x = torch.max(x, dim=1, keepdim=True)[0] - min_x = torch.min(x, dim=1, keepdim=True)[0] +def max_center_normalization(x: Tensor, eps: float = 1e-6, dim:int=1): + max_x = torch.max(x, dim=dim, keepdim=True)[0] + min_x = torch.min(x, dim=dim, keepdim=True)[0] midrange = 0.5 * (max_x + min_x) mag = max_x - midrange diff --git a/pyproject.toml b/pyproject.toml index b15c96a..505639b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "high-order-layers-torch" -version = "2.1.0" +version = "2.2.0" description = "High order layers in pytorch" authors = ["jloverich "] license = "MIT" diff --git a/tests/test_layers.py b/tests/test_layers.py index 3d0496d..a29e079 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -101,6 +101,14 @@ def test_max_abs_layers(): assert torch.all(torch.eq(ans[1][0], torch.tensor([0.5, 0.0625, 0.0625]))) assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625]))) + layer = MaxAbsNormalization(eps=0.0, dim=2) + ans = layer(x) + assert torch.all(torch.eq(ans[0][0], torch.tensor([1, 0.5, 0.5]))) + assert torch.all(torch.eq(ans[0][1], torch.tensor([1, 0.25, 0.25]))) + assert torch.all(torch.eq(ans[1][0], torch.tensor([1, 0.125, 0.125]))) + assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625]))) + + def test_max_center_layers(): x = torch.tensor([[1, 0.5, 0.5], [2, 0.5, 0.5]])