Skip to content

Commit

Permalink
Add dimension to my normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 5, 2023
1 parent 0ca5d9c commit cbbfd77
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
5 changes: 3 additions & 2 deletions high_order_layers_torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class MaxAbsNormalization(nn.Module):
Normalization for the 1D case (MLP)
"""

def __init__(self, eps: float = 1e-6):
def __init__(self, eps: float = 1e-6, dim: int = 1):
    """Initialize the max-abs normalization layer.

    Args:
        eps: Small constant added to the max-abs denominator so the
            division in ``forward`` never divides by zero.
        dim: Dimension along which the maximum absolute value is
            computed (default 1, the feature dimension for an MLP).
    """
    super().__init__()
    self._eps = eps
    self._dim = dim

def forward(self, x):
return max_abs_normalization(x, eps=self._eps)
return max_abs_normalization(x, eps=self._eps, dim=self._dim)


class MaxCenterNormalization(nn.Module):
Expand Down
14 changes: 7 additions & 7 deletions high_order_layers_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from torch import Tensor


def max_abs(x: Tensor):
return torch.max(x.abs(), dim=1, keepdim=True)[0]
def max_abs(x: Tensor, dim: int = 1):
    """Return the maximum absolute value of ``x`` along ``dim``.

    The reduced dimension is kept (``keepdim=True``) so the result
    broadcasts against ``x`` for normalization.
    """
    magnitudes = x.abs()
    largest, _ = torch.max(magnitudes, dim=dim, keepdim=True)
    return largest


def max_abs_normalization(x: Tensor, eps: float = 1e-6):
return x / (max_abs(x) + eps)
def max_abs_normalization(x: Tensor, eps: float = 1e-6, dim: int = 1):
    """Scale ``x`` so its largest absolute value along ``dim`` is ~1.

    Args:
        x: Input tensor.
        eps: Small constant added to the denominator to avoid division
            by zero when a slice of ``x`` is all zeros.
        dim: Dimension along which the max absolute value is taken.
    """
    denom = max_abs(x, dim=dim) + eps
    return x / denom


def max_center_normalization(x: Tensor, eps: float = 1e-6):
max_x = torch.max(x, dim=1, keepdim=True)[0]
min_x = torch.min(x, dim=1, keepdim=True)[0]
def max_center_normalization(x: Tensor, eps: float = 1e-6, dim:int=1):
max_x = torch.max(x, dim=dim, keepdim=True)[0]
min_x = torch.min(x, dim=dim, keepdim=True)[0]

midrange = 0.5 * (max_x + min_x)
mag = max_x - midrange
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "high-order-layers-torch"
version = "2.1.0"
version = "2.2.0"
description = "High order layers in pytorch"
authors = ["jloverich <john.loverich@gmail.com>"]
license = "MIT"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def test_max_abs_layers():
assert torch.all(torch.eq(ans[1][0], torch.tensor([0.5, 0.0625, 0.0625])))
assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625])))

layer = MaxAbsNormalization(eps=0.0, dim=2)
ans = layer(x)
assert torch.all(torch.eq(ans[0][0], torch.tensor([1, 0.5, 0.5])))
assert torch.all(torch.eq(ans[0][1], torch.tensor([1, 0.25, 0.25])))
assert torch.all(torch.eq(ans[1][0], torch.tensor([1, 0.125, 0.125])))
assert torch.all(torch.eq(ans[1][1], torch.tensor([1, 0.0625, 0.0625])))



def test_max_center_layers():
x = torch.tensor([[1, 0.5, 0.5], [2, 0.5, 0.5]])
Expand Down

0 comments on commit cbbfd77

Please sign in to comment.