Skip to content

Commit

Permalink
Removing model dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jan 1, 2024
1 parent f531fc9 commit f9b2fc8
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 114 deletions.
24 changes: 20 additions & 4 deletions language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,21 @@ def eval_step(self, batch: Tensor, name: str):
self.log(f"{name}_acc", accuracy, prog_bar=True)
return loss


class MambaClassificationMixin:
def eval_step(self, batch: Tensor, name: str):
x, y, idx = batch
y_hat = self(x)
loss = self.loss(y_hat.reshape(y.shape[0]*y.shape[1],-1), y.flatten())
loss = self.loss(y_hat.reshape(y.shape[0] * y.shape[1], -1), y.flatten())

diff = torch.argmax(y_hat, dim=2, keepdim=True) - y
diff = torch.argmax(y_hat, dim=2, keepdim=True) - y
accuracy = torch.where(diff == 0, 1, 0).sum() / torch.numel(diff)

self.log(f"{name}_loss", loss, prog_bar=True)
self.log(f"{name}_acc", accuracy, prog_bar=True)
return loss


class RegressionMixin:
def eval_step(self, batch: Tensor, name: str):
x, y, idx = batch
Expand Down Expand Up @@ -873,7 +875,18 @@ def select_network(cfg: DictConfig, device: str = None):
conv_bias=cfg.net.conv_bias,
bias=cfg.net.bias,
)
model = Mamba(args=model_args)
model = Mamba(
d_model=cfg.net.d_model,
n_layer=cfg.net.n_layer,
vocab_size=cfg.net.vocab_size,
d_state=cfg.net.d_state,
expand=cfg.net.expand,
dt_rank=cfg.net.dt_rank,
d_conv=cfg.net.d_conv,
pad_vocab_size_multiple=cfg.net.pad_vocab_size_multiple,
conv_bias=cfg.net.conv_bias,
bias=cfg.net.bias,
)
else:
raise ValueError(
f"Unrecognized model_type {cfg.model_type} should be high_order, high_order_input or high_order_conv!"
Expand Down Expand Up @@ -907,7 +920,10 @@ def __init__(self, cfg: DictConfig):
self.loss = torch.nn.CrossEntropyLoss()
self.accuracy = Accuracy(top_k=1, task="multiclass", num_classes=128)

class MambaASCIIPredictionNet(MambaClassificationMixin, PredictionNetMixin, LightningModule):

class MambaASCIIPredictionNet(
MambaClassificationMixin, PredictionNetMixin, LightningModule
):
def __init__(self, cfg: DictConfig):
super().__init__()
self.save_hyperparameters(cfg)
Expand Down
Loading

0 comments on commit f9b2fc8

Please sign in to comment.