From 00a3cbe8b808ef4d0340cd4effbf9e0378b74102 Mon Sep 17 00:00:00 2001
From: jloveric
Date: Thu, 28 Dec 2023 09:16:47 -0800
Subject: [PATCH] Changes to get Mamba working

---
 examples/high_order_interpolation.py          |  1 +
 .../lightning_datamodule.py                   |  9 +++++++++
 language_interpolation/single_text_dataset.py | 23 +++++++++++++++----
 language_interpolation/state_space_network.py |  9 +++++----
 4 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/examples/high_order_interpolation.py b/examples/high_order_interpolation.py
index c4281c3..cf41cb3 100644
--- a/examples/high_order_interpolation.py
+++ b/examples/high_order_interpolation.py
@@ -73,6 +73,7 @@ def run_language_interpolation(cfg: DictConfig):
             test_filenames=cfg.data.test.filenames,
             max_size=cfg.data.max_size,
             repeats=cfg.data.repeats,
+            as_index=(cfg.net.model_type == "mamba"),
         )
     else:
         datamodule = GutenbergDataModule(
diff --git a/language_interpolation/lightning_datamodule.py b/language_interpolation/lightning_datamodule.py
index 70897c2..02728e4 100644
--- a/language_interpolation/lightning_datamodule.py
+++ b/language_interpolation/lightning_datamodule.py
@@ -274,6 +274,7 @@ def __init__(
         add_channel_dimension: bool = False,
         transforms: Callable[[Tensor], Tensor] = None,
         repeats: int = 1,
+        as_index: bool = False,
     ):
         """
         Data module for this type of transformer
@@ -302,6 +303,7 @@
         self._add_channel_dimension = add_channel_dimension
         self._transforms = transforms
         self._repeats = repeats
+        self._as_index = as_index

     def normalize(self, data):
         return (data - 64 + 0.5) / 64.0
@@ -316,6 +318,13 @@ def collate_fn(self, batch) -> tuple[Tensor, Tensor, list[int]]:
         final_targets = torch.stack([sample[0][this_size][0] for sample in batch])
         final_indexes = [sample[1] for sample in batch]

+        if self._as_index:
+            # Embedding-based models (e.g. Mamba) consume raw indexes, so skip normalization.
+            return (
+                final_features,
+                final_targets,
+                final_indexes,
+            )
         return self.normalize(final_features), final_targets, final_indexes

diff --git a/language_interpolation/single_text_dataset.py b/language_interpolation/single_text_dataset.py
index 63136c8..2cd396e 100644
--- a/language_interpolation/single_text_dataset.py
+++ b/language_interpolation/single_text_dataset.py
@@ -407,6 +407,7 @@
         transforms: Callable[[Tensor], Tensor] = None,
         transformer: bool = False,
         embedding_size: int = None,
+        as_index: bool = False,
     ):
         """
         Args :
@@ -421,6 +422,7 @@
            add_channel_dimension: For convnets we need to add a channel dimension to the data
            transformer: Whether it should be formatted for a (high order) transformer or not
            embedding_size: Size of the embedding if a transformer is being used.
+           as_index: Inputs should be indexes instead of floats
         """

         list_features, list_targets = dataset_sequential(
@@ -447,6 +449,7 @@
         self.valid_ids = list(range(0, len(list_features)))
         self._transformer = transformer
         self._embedding_size = embedding_size
+        self._as_index = as_index

     def __len__(self):
         return len(self.valid_ids)
@@ -483,11 +486,18 @@ def group(self, idx) -> Tensor:
         if self.transforms is not None:
             inputs = self.transforms(inputs)

-        return (
-            self.normalize(inputs).reshape(inputs.shape[0], -1, self._embedding_size),
-            self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
-            idx,
-        )
+        if self._as_index:
+            # Return raw token indexes for embedding-based models.
+            return (
+                inputs.reshape(inputs.shape[0], -1, self._embedding_size),
+                self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
+                idx,
+            )
+        return (
+            self.normalize(inputs).reshape(inputs.shape[0], -1, self._embedding_size),
+            self.output[index].reshape(self.output.shape[0], -1, self._embedding_size),
+            idx,
+        )

     def __getitem__(self, idx) -> Tensor:
         if self._transformer is True:
@@ -524,6 +534,7 @@ def __init__(
         transforms: Callable[[Tensor], Tensor] = None,
         embedding_size: int = None,
         repeats: int = 1,
+        as_index: bool = True,
     ):
         """
         Args :
@@ -538,6 +549,7 @@
            processed.
            add_channel_dimension: For convnets we need to add a channel dimension to the data
            embedding_size: Size of the embedding if a transformer is being used.
+           as_index: Inputs should be indexes instead of floats
         """

         list_features, list_targets = dataset_sequential(
@@ -573,6 +585,7 @@

         self.data_size = len(self.inputs) - self._max_characters
         self._repeats = repeats
+        self._as_index = as_index

     def __len__(self):
         return int((len(self.inputs) - self._max_characters) * self._repeats)
diff --git a/language_interpolation/state_space_network.py b/language_interpolation/state_space_network.py
index 19d6c87..c5c75d1 100644
--- a/language_interpolation/state_space_network.py
+++ b/language_interpolation/state_space_network.py
@@ -86,9 +86,9 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss
         """
-        print('input_ids', input_ids)
-
-        x = self.embedding(input_ids)
-
+        # Collapse the trailing feature dimension so the embedding receives
+        # token indexes of shape (batch, seq_len).
+        reshaped = input_ids.reshape(input_ids.shape[0], input_ids.shape[1] * input_ids.shape[2])
+        x = self.embedding(reshaped)
         for layer in self.layers:
             x = layer(x)
@@ -176,6 +176,7 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m
             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
         """
+
         (b, l, d) = x.shape
         x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
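Note (not part of the patch): a minimal sketch of the data flow this change enables, assuming the shapes implied by the diff — with as_index=True the dataset yields raw integer byte indexes of shape (batch, seq_len, 1), and the model flattens the trailing dimension before the embedding lookup. The vocabulary size (128) and embedding width (64) below are illustrative, not values from this repository.

    import torch
    from torch import nn

    # With as_index=True the collate_fn returns raw integer indexes
    # (no normalization), which is what nn.Embedding expects.
    input_ids = torch.randint(0, 128, (4, 16, 1))  # (batch, seq_len, 1)

    embedding = nn.Embedding(num_embeddings=128, embedding_dim=64)

    # Mirrors the reshape added to state_space_network.py: collapse the
    # trailing dimension so the embedding sees (batch, seq_len).
    reshaped = input_ids.reshape(input_ids.shape[0], input_ids.shape[1] * input_ids.shape[2])
    x = embedding(reshaped)
    print(x.shape)  # torch.Size([4, 16, 64])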