
Just some notes
jloveric committed Dec 31, 2023
1 parent 62330c6 commit cff2172
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion language_interpolation/state_space_network.py
@@ -142,6 +142,8 @@ def __init__(self, args: ModelArgs):

self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

# The kernel size is ~4, and this is a depthwise
# convolution because groups = k * in_channels with k = 1
self.conv1d = nn.Conv1d(
in_channels=args.d_inner,
out_channels=args.d_inner,
@@ -185,7 +187,12 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

x = rearrange(x, 'b l d_in -> b d_in l')
x = self.conv1d(x)[:, :, :l]
# Depthwise convolution.
# Why use a convolution here rather than an MLP that operates on the
# channels? Probably because it requires fewer parameters. At any
# rate, this appears to be causal: no information flows from future
# time steps into the current one.
x = self.conv1d(x)[:, :, :l]  # am I missing something? This should always be size l.
x = rearrange(x, 'b d_in l -> b l d_in')

x = F.silu(x)
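A note on the `[:, :, :l]` slice the diff comments ask about: the reference Mamba implementation constructs its `nn.Conv1d` with `padding = d_conv - 1`, so the raw convolution output has length `l + k - 1` and the slice trims it back to `l`; the leading padding is what makes the result causal. A minimal pure-Python sketch of the same idea (illustrative names and sizes, not code from this repository):

```python
# Illustrative sketch (not from this repo): a depthwise causal 1-D
# convolution in pure Python, mimicking what
# nn.Conv1d(..., groups=d_inner, padding=k - 1) followed by [:, :, :l] computes.

def depthwise_causal_conv(x, weights, biases):
    """x: list of channels, each a list of l values.
    weights: one length-k kernel per channel; biases: one float per channel.
    Left-padding each channel with k - 1 zeros means output[t] only sees
    inputs at positions <= t, which is exactly what makes it causal."""
    out = []
    for channel, w, b in zip(x, weights, biases):
        k = len(w)
        padded = [0.0] * (k - 1) + channel  # causal (left-only) padding
        out.append([
            sum(w[j] * padded[t + j] for j in range(k)) + b
            for t in range(len(channel))    # output length stays l
        ])
    return out

# Parameter counts back up the "fewer parameters" comment in the diff:
d_inner, k = 8, 4
depthwise_params = d_inner * k + d_inner            # one k-tap filter per channel
full_conv_params = d_inner * d_inner * k + d_inner  # groups=1 mixes all channels
```

Since `nn.Conv1d` pads both sides, its raw output has `l + k - 1` positions; slicing to `:l` keeps the causal outputs and drops the `k - 1` trailing ones, so the sliced result is always length `l`.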
