Skip to content

Commit

Permalink
Add high order definition to yaml
Browse files (browse the repository at this point in the history)
  • Loading branch information
jloveric committed Jan 2, 2024
1 parent 005142a commit 110235b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions config/net/mamba.yaml
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,3 +12,9 @@ conv_bias: True
bias: False

model_type: mamba

# High order networks
layer_type : continuous
n : 3
segments: 2
hidden_layers: 0
4 changes: 4 additions & 0 deletions language_interpolation/networks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -874,6 +874,10 @@ def select_network(cfg: DictConfig, device: str = None):
pad_vocab_size_multiple=cfg.net.pad_vocab_size_multiple,
conv_bias=cfg.net.conv_bias,
bias=cfg.net.bias,
layer_type=cfg.layer_type,
n=cfg.net.n,
segments=cfg.net.segments,
hidden_layers=cfg.net.hidden_layers,
)
else:
raise ValueError(
Expand Down
16 changes: 16 additions & 0 deletions language_interpolation/state_space_network.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(
expand: int = 2,
dt_rank: Union[str, int] = "auto",
pad_vocab_size_multiple: int = 8,
layer_type: str = "linear", # Regular Linear layer
n: int = 2,
segments: int = 2,
hidden_layers: int = 0,
):
"""Full Mamba model."""
super().__init__()
Expand All @@ -71,6 +75,10 @@ def __init__(
dt_rank=dt_rank,
conv_bias=conv_bias,
bias=bias,
layer_type=layer_type,
segments=segments,
n=n,
hidden_layers=hidden_layers
)
for _ in range(n_layer)
]
Expand Down Expand Up @@ -122,6 +130,10 @@ def __init__(
dt_rank: int,
conv_bias: bool,
bias: bool,
layer_type: str = "continuous", # Regular Linear layer
n: int = 2,
segments: int = 2,
hidden_layers: int = 0,
):
"""Simple block wrapping Mamba block with normalization and residual connection."""
super().__init__()
Expand All @@ -133,6 +145,10 @@ def __init__(
dt_rank=dt_rank,
conv_bias=conv_bias,
bias=bias,
layer_type=layer_type,
segments=segments,
n=n,
hidden_layers=hidden_layers
)
self.norm = RMSNorm(d_model)

Expand Down

0 comments on commit 110235b

Please sign in to comment.