Skip to content

Commit

Permalink
Working on a second low order mlp using input embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jan 3, 2024
1 parent b43c05a commit 638d05c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/high_order_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_language_interpolation(cfg: DictConfig):
repeats=cfg.data.repeats,
as_index = False
)
elif cfg.net.model_type in ["mamba"] :
elif cfg.net.model_type in ["mamba","low_order_mlp"] :
datamodule = MambaDataModule(
characters_per_feature=cfg.data.characters_per_feature,
max_features=cfg.data.max_features,
Expand Down
37 changes: 33 additions & 4 deletions language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,37 @@ def __init__(
def select_network(cfg: DictConfig, device: str = None):
normalization = select_normalization(cfg.net.normalize)

if cfg.net.model_type == "high_order_input":
if cfg.net.model_type == "low_order_mlp":
"""
Only the input layer is high order, the rest
of the layers are standard linear+relu and normalization.
"""

layer_list = []
input_layer = torch.nn.Embedding(
num_embeddings=128,
embedding_dim=cfg.net.hidden.width,
device=cfg.accelerator,
)

layer_list.append(input_layer)

if normalization is not None:
layer_list.append(normalization())

lower_layers = LowOrderMLP(
in_width=cfg.net.hidden.width,
out_width=cfg.net.output.width,
hidden_width=cfg.net.hidden.width,
hidden_layers=cfg.net.hidden.layers - 1,
non_linearity=torch.nn.ReLU(),
normalization=normalization,
#device=cfg.accelerator,
)
layer_list.append(lower_layers)

model = torch.nn.Sequential(*layer_list)
elif cfg.net.model_type == "high_order_input":
"""
Only the input layer is high order, the rest
of the layers are standard linear+relu and normalization.
Expand All @@ -769,14 +799,13 @@ def select_network(cfg: DictConfig, device: str = None):
layer_list.append(normalization())

lower_layers = LowOrderMLP(
layer_type=cfg.net.layer_type,
in_width=cfg.net.hidden.width,
out_width=cfg.output.width,
out_width=cfg.net.output.width,
hidden_width=cfg.net.hidden.width,
hidden_layers=cfg.net.hidden.layers - 1,
non_linearity=torch.nn.ReLU(),
normalization=normalization,
device=cfg.accelerator,
#device=cfg.accelerator,
)
layer_list.append(lower_layers)

Expand Down

0 comments on commit 638d05c

Please sign in to comment.