diff --git a/language_interpolation/state_space_network.py b/language_interpolation/state_space_network.py
index 4608918..5d0528f 100644
--- a/language_interpolation/state_space_network.py
+++ b/language_interpolation/state_space_network.py
@@ -31,7 +31,7 @@ from dataclasses import dataclass
 from einops import rearrange, repeat, einsum
 from typing import Union
 
-
+from language_interpolation.utils import reshape_apply
 
 # TODO: I don't like this approach to setting data inputs
 @dataclass
@@ -194,8 +194,8 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m
 
         y = y * F.silu(res)
 
-        output = self.out_proj(y)
-
+        # output = self.out_proj(y)
+        output = reshape_apply(y, self.out_proj)
 
         return output
 
diff --git a/language_interpolation/utils.py b/language_interpolation/utils.py
index a427c15..67d94c2 100644
--- a/language_interpolation/utils.py
+++ b/language_interpolation/utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Any
 from high_order_layers_torch.layers import *
 from pytorch_lightning import Callback
@@ -17,6 +17,23 @@
 logger = logging.getLogger(__name__)
 
 
+def reshape_apply(x: Tensor, layer: Any) -> Tensor:
+    """
+    TODO: Move this to high-order-layers-torch.
+    nn.Linear works on arbitrarily shaped tensors, but the high order
+    layers only accept 2D input, so flatten every leading dimension,
+    apply the layer, and then restore the original leading dimensions.
+    Args:
+        x: The tensor that needs to be reshaped.
+        layer: The layer to apply.
+    Returns:
+        The layer output, reshaped to match the leading dimensions of x.
+    """
+    shape = x.shape
+    last = shape[-1]
+    # Collapse all leading dimensions into a single batch dimension.
+    flatout: Tensor = layer(x.view(-1, last))
+    return flatout.view(*shape[:-1], -1)
 def create_gutenberg_cache(parent_directory: str):
     """
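
For context, a minimal usage sketch of the new helper (not part of the diff; the shapes and the nn.Linear stand-in layer are illustrative assumptions):

    # reshape_apply flattens every leading dimension before calling the layer,
    # then restores them on the output, so layers that only accept 2D input
    # (such as the high order layers) can be applied to (batch, seq, feature) tensors.
    import torch
    from torch import nn
    from language_interpolation.utils import reshape_apply

    layer = nn.Linear(16, 32)        # stand-in for a high order layer
    x = torch.randn(4, 10, 16)       # (batch, sequence, features)
    y = reshape_apply(x, layer)      # layer is applied internally on shape (40, 16)
    print(y.shape)                   # torch.Size([4, 10, 32])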