From 638d05c90a29f18c3b4ee5bcb049d507aa4d48f8 Mon Sep 17 00:00:00 2001
From: jloveric
Date: Tue, 2 Jan 2024 17:09:08 -0800
Subject: [PATCH] Working on a second low order MLP using input embeddings

---
 examples/high_order_interpolation.py |  2 +-
 language_interpolation/networks.py   | 37 +++++++++++++++++++++++++---
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/examples/high_order_interpolation.py b/examples/high_order_interpolation.py
index 99ab352..4b75caa 100644
--- a/examples/high_order_interpolation.py
+++ b/examples/high_order_interpolation.py
@@ -74,7 +74,7 @@ def run_language_interpolation(cfg: DictConfig):
             repeats=cfg.data.repeats,
             as_index = False
         )
-    elif cfg.net.model_type in ["mamba"] :
+    elif cfg.net.model_type in ["mamba", "low_order_mlp"]:
         datamodule = MambaDataModule(
             characters_per_feature=cfg.data.characters_per_feature,
             max_features=cfg.data.max_features,
diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py
index 2e5ba67..e0eb8fc 100644
--- a/language_interpolation/networks.py
+++ b/language_interpolation/networks.py
@@ -749,7 +749,37 @@ def __init__(
 def select_network(cfg: DictConfig, device: str = None):
     normalization = select_normalization(cfg.net.normalize)
 
-    if cfg.net.model_type == "high_order_input":
+    if cfg.net.model_type == "low_order_mlp":
+        """
+        Only the input layer is an embedding, the rest
+        of the layers are standard linear+relu and normalization.
+        """
+
+        layer_list = []
+        input_layer = torch.nn.Embedding(
+            num_embeddings=128,
+            embedding_dim=cfg.net.hidden.width,
+            device=cfg.accelerator,
+        )
+
+        layer_list.append(input_layer)
+
+        if normalization is not None:
+            layer_list.append(normalization())
+
+        lower_layers = LowOrderMLP(
+            in_width=cfg.net.hidden.width,
+            out_width=cfg.net.output.width,
+            hidden_width=cfg.net.hidden.width,
+            hidden_layers=cfg.net.hidden.layers - 1,
+            non_linearity=torch.nn.ReLU(),
+            normalization=normalization,
+            # device=cfg.accelerator,
+        )
+        layer_list.append(lower_layers)
+
+        model = torch.nn.Sequential(*layer_list)
+    elif cfg.net.model_type == "high_order_input":
         """
         Only the input layer is high order, the rest
         of the layers are standard linear+relu and normalization.
@@ -769,14 +799,13 @@ def select_network(cfg: DictConfig, device: str = None):
         layer_list.append(normalization())
 
         lower_layers = LowOrderMLP(
-            layer_type=cfg.net.layer_type,
             in_width=cfg.net.hidden.width,
-            out_width=cfg.output.width,
+            out_width=cfg.net.output.width,
             hidden_width=cfg.net.hidden.width,
             hidden_layers=cfg.net.hidden.layers - 1,
             non_linearity=torch.nn.ReLU(),
             normalization=normalization,
-            device=cfg.accelerator,
+            # device=cfg.accelerator,
         )
         layer_list.append(lower_layers)
 
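For reference, below is a minimal runnable sketch of the network the new "low_order_mlp" branch assembles: an nn.Embedding over 128 (ASCII-range) token ids feeding a plain linear+ReLU stack. The widths and layer count are placeholder values standing in for cfg.net.hidden.width, cfg.net.output.width, and cfg.net.hidden.layers, and the LowOrderMLP class here is a simplified stand-in for the project's class of the same name (the real one also accepts a normalization argument, as used in the patch).

import torch
from torch import nn

# Placeholder values standing in for cfg.net.hidden.width,
# cfg.net.output.width, and cfg.net.hidden.layers (assumptions,
# not values from the patch).
hidden_width = 64
output_width = 128
hidden_layers = 4


class LowOrderMLP(nn.Module):
    # Simplified stand-in for the project's LowOrderMLP: a plain
    # linear+ReLU stack (normalization omitted for brevity).
    def __init__(self, in_width, out_width, hidden_width, hidden_layers, non_linearity):
        super().__init__()
        layers = []
        width = in_width
        for _ in range(hidden_layers):
            layers.extend([nn.Linear(width, hidden_width), non_linearity])
            width = hidden_width
        layers.append(nn.Linear(width, out_width))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# The "low_order_mlp" branch: an embedding over 128 token ids
# (one per ASCII character) followed by the low order MLP.
model = nn.Sequential(
    nn.Embedding(num_embeddings=128, embedding_dim=hidden_width),
    LowOrderMLP(
        in_width=hidden_width,
        out_width=output_width,
        hidden_width=hidden_width,
        hidden_layers=hidden_layers - 1,
        non_linearity=nn.ReLU(),
    ),
)

tokens = torch.randint(0, 128, (2, 16))  # (batch, sequence) of token ids
print(model(tokens).shape)  # torch.Size([2, 16, 128])

This also matches the datamodule change above: "low_order_mlp" shares MambaDataModule with "mamba", so the model presumably receives raw character indices rather than interpolated features, which is why an embedding input layer replaces the high order input layer.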