Adding reshape_apply to work with different models
jloveric committed Dec 30, 2023
1 parent 2dfbcec commit 80161b1
Showing 2 changed files with 21 additions and 4 deletions.
6 changes: 3 additions & 3 deletions language_interpolation/state_space_network.py
@@ -31,7 +31,7 @@
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from typing import Union

+from language_interpolation.utils import reshape_apply

# TODO: I don't like this approach to setting data inputs
@dataclass
@@ -194,8 +194,8 @@ class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/m

        y = y * F.silu(res)

-        output = self.out_proj(y)
+        # output = self.out_proj(y)
+        output = reshape_apply(y, self.out_proj)
        return output


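A minimal sketch, not part of the commit, of why the rewritten out_proj call is a safe drop-in here: for a plain nn.Linear, which already maps over leading dimensions, reshape_apply returns the same result as calling the layer directly. Shapes follow the (batch, seq_len, d_inner) convention used in the Mamba block above.

import torch
import torch.nn as nn

from language_interpolation.utils import reshape_apply

out_proj = nn.Linear(32, 8)
y = torch.randn(4, 16, 32)  # (batch, seq_len, d_inner)

# nn.Linear broadcasts over leading dimensions, so wrapping it in
# reshape_apply gives an identical result at this call site.
assert torch.allclose(out_proj(y), reshape_apply(y, out_proj))
print(reshape_apply(y, out_proj).shape)  # torch.Size([4, 16, 8])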
19 changes: 18 additions & 1 deletion language_interpolation/utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Any

from high_order_layers_torch.layers import *
from pytorch_lightning import Callback
@@ -17,6 +17,23 @@

logger = logging.getLogger(__name__)

+def reshape_apply(x: Tensor, layer: Any) -> Tensor:
+    """
+    TODO: Move this to high_order_layers_torch.
+    nn.Linear works on arbitrarily shaped tensors, but the high order
+    layers only accept 2D input, so this flattens the leading dimensions
+    of x, applies the layer, and restores them; it works in both cases.
+    Args:
+        x: The tensor to reshape and feed to the layer.
+        layer: The layer to apply.
+    Returns:
+        The layer output with the leading dimensions of x restored.
+    """
+    shape = x.shape
+    last = shape[-1]
+    first = int(torch.prod(torch.tensor(shape[:-1])))
+    flatout: Tensor = layer(x.view(first, last))
+    return flatout.view(*shape[:-1], -1)

def create_gutenberg_cache(parent_directory: str):
"""
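A short usage sketch of reshape_apply, not part of the commit. TwoDOnly is a hypothetical stand-in for a high order layer; all that matters for the example is that it requires 2D (batch, features) input, which is the case the helper exists to handle.

import torch
import torch.nn as nn

from language_interpolation.utils import reshape_apply


class TwoDOnly(nn.Module):
    # Hypothetical layer that rejects anything but 2D input.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2, "only (batch, features) input is supported"
        return self.linear(x)


layer = TwoDOnly(32, 8)
x = torch.randn(4, 16, 32)     # calling layer(x) directly would fail
out = reshape_apply(x, layer)  # flattens to (64, 32), restores (4, 16, 8)
print(out.shape)               # torch.Size([4, 16, 8])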
