From 0e6c86cd3e6554389edeac28b2ad3f6d22d5130d Mon Sep 17 00:00:00 2001 From: jloveric Date: Thu, 28 Dec 2023 18:21:40 -0800 Subject: [PATCH] Hack to get mamba running, but I'm gonna need a new loader --- language_interpolation/networks.py | 1 + language_interpolation/state_space_network.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py index 95706ba..27b6c64 100644 --- a/language_interpolation/networks.py +++ b/language_interpolation/networks.py @@ -121,6 +121,7 @@ class ClassificationMixin: def eval_step(self, batch: Tensor, name: str): x, y, idx = batch y_hat = self(x) + print('y_hat.shape',y_hat.shape, 'y shape', y.shape) loss = self.loss(y_hat, y.flatten()) diff = torch.argmax(y_hat, dim=1) - y.flatten() diff --git a/language_interpolation/state_space_network.py b/language_interpolation/state_space_network.py index c5c75d1..db65b77 100644 --- a/language_interpolation/state_space_network.py +++ b/language_interpolation/state_space_network.py @@ -85,7 +85,7 @@ def forward(self, input_ids): class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173 """ - print('input_ids', input_ids) + print('input_ids.shape', input_ids.shape) reshaped = input_ids.reshape(input_ids.shape[0], input_ids.shape[1]*input_ids.shape[2]) x = self.embedding(reshaped) print('x.shape after', x.shape) @@ -94,8 +94,9 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss x = self.norm_f(x) logits = self.lm_head(x) + reduced = logits[:,-1,:].reshape(logits.shape[0], logits.shape[2]) - return logits + return reduced #logits class ResidualBlock(nn.Module):