Skip to content

Commit

Permalink
accept image fmaps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 10, 2023
1 parent 9f712b8 commit 66966be
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soft-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'Soft MoE - Pytorch',
author = 'Phil Wang',
Expand Down
12 changes: 11 additions & 1 deletion soft_moe_pytorch/soft_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F
from torch import nn, einsum, Tensor

from einops import rearrange
from einops import rearrange, pack, unpack

# helper functions

Expand Down Expand Up @@ -75,6 +75,12 @@ def forward(self, x):
d - feature dimension
"""

is_image = x.ndim == 4

if is_image:
x = rearrange(x, 'b d h w -> b h w d')
x, ps = pack([x], 'b * d')

# following Algorithm 1, with the normalization they proposed, but with scaling of both (the now popular rmsnorm + gamma)

x = self.norm(x)
Expand Down Expand Up @@ -104,4 +110,8 @@ def forward(self, x):
out = rearrange(out, 'e b s d -> b (e s) d')
out = einsum('b s d, b n s -> b n d', out, combine_weights)

if is_image:
out, = unpack(out, ps, 'b * d')
out = rearrange(out, 'b h w d -> b d h w')

return out

0 comments on commit 66966be

Please sign in to comment.