Changes to get mamba working
jloveric committed Dec 28, 2023
1 parent f606c89 commit 00a3cbe
Showing 4 changed files with 32 additions and 8 deletions.
1 change: 1 addition & 0 deletions examples/high_order_interpolation.py
@@ -73,6 +73,7 @@ def run_language_interpolation(cfg: DictConfig):
test_filenames=cfg.data.test.filenames,
max_size=cfg.data.max_size,
repeats=cfg.data.repeats,
as_index=cfg.net.model_type == "mamba",
)
else:
datamodule = GutenbergDataModule(
8 changes: 8 additions & 0 deletions language_interpolation/lightning_datamodule.py
@@ -274,6 +274,7 @@ def __init__(
add_channel_dimension: bool = False,
transforms: Callable[[Tensor], Tensor] = None,
repeats: int = 1,
as_index: bool = False,
):
"""
Data module for this type of transformer
@@ -302,6 +303,7 @@ def __init__(
self._add_channel_dimension = add_channel_dimension
self._transforms = transforms
self._repeats = repeats
self._as_index = as_index

def normalize(self, data):
return (data - 64 + 0.5) / 64.0
@@ -316,6 +318,12 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
final_targets = torch.stack([sample[0][this_size][0] for sample in batch])

final_indexes = [sample[1] for sample in batch]
if self._as_index:
    return (
        final_features,
        final_targets,
        final_indexes,
    )

return self.normalize(final_features), final_targets, final_indexes

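For illustration, a minimal sketch (hypothetical toy values, not part of this commit) of what the two collate modes produce: with as_index left False the byte ids go through normalize() and land roughly in (-1, 1), while as_index=True returns the raw integer ids so a downstream embedding layer, as in the Mamba model, can index into its table.

import torch

# Hypothetical toy batch: byte-level token ids in [0, 127], shape (batch=2, seq=4)
features = torch.tensor([[72, 101, 108, 108], [32, 87, 111, 114]])

# as_index=False path: same arithmetic as the datamodule's normalize()
normalized = (features - 64 + 0.5) / 64.0  # byte 0 -> -0.9921875, byte 127 -> 0.9921875

# as_index=True path: integer ids are returned untouched for an embedding lookup
embedding = torch.nn.Embedding(num_embeddings=128, embedding_dim=16)
embedded = embedding(features)  # shape (2, 4, 16)

print(normalized.min().item(), embedded.shape)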
23 changes: 18 additions & 5 deletions language_interpolation/single_text_dataset.py
@@ -407,6 +407,7 @@ def __init__(
transforms: Callable[[Tensor], Tensor] = None,
transformer: bool = False,
embedding_size: int = None,
as_index: bool = False,
):
"""
Args :
@@ -421,6 +422,7 @@ def __init__(
add_channel_dimension: For convnets we need to add a channel dimension to the data
transformer: Whether it should be formatted for a (high order) transformer or not
embedding_size: Size of the embedding if a transformer is being used.
as_index: If True, inputs are returned as raw integer indexes instead of normalized floats
"""

list_features, list_targets = dataset_sequential(
@@ -447,6 +449,7 @@ def __init__(
self.valid_ids = list(range(0, len(list_features)))
self._transformer = transformer
self._embedding_size = embedding_size
self._as_index = as_index

def __len__(self):
return len(self.valid_ids)
@@ -483,11 +486,18 @@ def group(self, idx) -> Tensor:
if self.transforms is not None:
inputs = self.transforms(inputs)

return (
self.normalize(inputs).reshape(inputs.shape[0], -1, self._embedding_size),
self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
idx,
)
if not self._as_index:
    return (
        self.normalize(inputs).reshape(inputs.shape[0], -1, self._embedding_size),
        self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
        idx,
    )
else:
    return (
        inputs.reshape(inputs.shape[0], -1, self._embedding_size),
        self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
        idx,
    )

def __getitem__(self, idx) -> Tensor:
if self._transformer is True:
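A rough sketch (made-up shapes; the real tensor sizes depend on max_characters and embedding_size, which are not shown in this hunk) of what the new branch in group() changes: both branches reshape the character block into (samples, tokens, embedding_size), but the index branch keeps the integer dtype that an nn.Embedding lookup requires, instead of converting to normalized floats.

import torch

embedding_size = 4
inputs = torch.randint(0, 128, (3, 16))  # made-up sizes: 3 samples of 16 characters

# Index branch: reshape only, dtype stays int64, still a valid embedding input
as_index_features = inputs.reshape(inputs.shape[0], -1, embedding_size)
assert as_index_features.dtype == torch.int64

# Float branch: normalize first (same formula as the datamodule, for illustration), then reshape
float_features = ((inputs - 64 + 0.5) / 64.0).reshape(inputs.shape[0], -1, embedding_size)
assert float_features.dtype == torch.float32

print(as_index_features.shape, float_features.shape)  # both torch.Size([3, 4, 4])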
@@ -524,6 +534,7 @@ def __init__(
transforms: Callable[[Tensor], Tensor] = None,
embedding_size: int = None,
repeats: int = 1,
as_index: bool = True,
):
"""
Args :
@@ -538,6 +549,7 @@ def __init__(
processed.
add_channel_dimension: For convnets we need to add a channel dimension to the data
embedding_size: Size of the embedding if a transformer is being used.
as_index: If True, inputs are returned as integer indexes rather than normalized floats
"""

list_features, list_targets = dataset_sequential(
@@ -573,6 +585,7 @@ def __init__(

self.data_size = len(self.inputs) - self._max_characters
self._repeats = repeats
self._as_index = as_index

def __len__(self):
return int((len(self.inputs) - self._max_characters) * self._repeats)
8 changes: 5 additions & 3 deletions language_interpolation/state_space_network.py
@@ -86,9 +86,9 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss
"""
print('input_ids', input_ids)

x = self.embedding(input_ids)

# Collapse the trailing character dimension so the embedding receives a 2-D (batch, sequence) index tensor
reshaped = input_ids.reshape(input_ids.shape[0], input_ids.shape[1] * input_ids.shape[2])
x = self.embedding(reshaped)
print('x.shape after', x.shape)
for layer in self.layers:
x = layer(x)

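A small sketch (hypothetical dimensions, not from the repository) of what the reshape buys: a 3-D batch of integer ids is flattened to (batch, sequence) so nn.Embedding can map each id to a d_model vector, producing the (batch, length, d_model) layout the Mamba layers expect.

import torch
import torch.nn as nn

b, l, d = 2, 8, 1  # hypothetical batch, sequence length, characters per position
input_ids = torch.randint(0, 128, (b, l, d))  # integer ids, as produced when as_index=True

embedding = nn.Embedding(num_embeddings=128, embedding_dim=64)

# Same collapse as in the commit: (b, l, d) -> (b, l * d) before the embedding lookup
reshaped = input_ids.reshape(input_ids.shape[0], input_ids.shape[1] * input_ids.shape[2])
x = embedding(reshaped)  # shape (b, l * d, 64), then fed through the stack of Mamba blocks

print(x.shape)  # torch.Size([2, 8, 64])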
@@ -127,6 +127,7 @@ def forward(self, x):
[Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
"""
print('residual block', x.shape)
output = self.mixer(self.norm(x)) + x

return output
@@ -176,6 +177,7 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""

(b, l, d) = x.shape

x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
