From 62330c6172802c0fd124dcf608aac64d1ab5a3ea Mon Sep 17 00:00:00 2001 From: jloveric Date: Sat, 30 Dec 2023 09:51:42 -0800 Subject: [PATCH] Add reshape apply to input projection --- language_interpolation/state_space_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/language_interpolation/state_space_network.py b/language_interpolation/state_space_network.py index 5d0528f..d0929f7 100644 --- a/language_interpolation/state_space_network.py +++ b/language_interpolation/state_space_network.py @@ -181,7 +181,7 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m (b, l, d) = x.shape - x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) + x_and_res = reshape_apply(x, self.in_proj) # shape (b, l, 2 * d_in) (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1) x = rearrange(x, 'b l d_in -> b d_in l')