diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py index 20465a3..631c5ac 100644 --- a/language_interpolation/networks.py +++ b/language_interpolation/networks.py @@ -11,7 +11,11 @@ HighOrderTailFocusNetwork, ) from high_order_layers_torch.positional_embeddings import ClassicSinusoidalEmbedding -from high_order_layers_torch.layers import MaxAbsNormalizationLast, high_order_fc_layers +from high_order_layers_torch.layers import ( + MaxAbsNormalizationLast, + high_order_fc_layers, + MaxCenterNormalizationLast, +) from high_order_layers_torch.networks import initialize_network_polynomial_layers from torchmetrics import Accuracy from torch import Tensor @@ -51,6 +55,8 @@ def forward(self, x): if normalizer == "maxabs": normalization = MaxAbsNormalizationLast + elif normalizer == "maxcenter": + normalization = MaxCenterNormalizationLast elif normalizer == "layer": normalization = LazyLayerNormLastDim elif normalizer == "none": diff --git a/pyproject.toml b/pyproject.toml index ae88fde..a0d728b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ langchain = "^0.0.334" hydra-core = "^1.3.2" tensorboard = "^2.15.1" lion-pytorch = "^0.1.2" -high-order-layers-torch = "^2.4.0" +high-order-layers-torch = "^2.4.1" [tool.poetry.group.dev.dependencies] black = "^23.11.0"