Skip to content

Commit

Permalink
Add reshape apply to input projection
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 30, 2023
1 parent 80161b1 commit 62330c6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion language_interpolation/state_space_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m

(b, l, d) = x.shape

x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
x_and_res = reshape_apply(x, self.in_proj) # shape (b, l, 2 * d_in)
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

x = rearrange(x, 'b l d_in -> b d_in l')
Expand Down

0 comments on commit 62330c6

Please sign in to comment.