Skip to content

Commit

Permalink
Experimenting with a new network
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 4, 2024
1 parent d02d866 commit 530a262
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
57 changes: 54 additions & 3 deletions language_interpolation/dual_convolutional_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,62 @@
Convolutional network that also shares
depth wise convolutions
"""

import torch
from high_order_layers_torch.networks import HighOrderMLP
from torch import Tensor


class DualConvolutionalNetwork(torch.nn.Module):
def __init__(self):
def __init__(
self,
n: str,
out_width: int,
hidden_layers: int,
hidden_width: int,
in_segments: int = None,
segments: int = None,
device: str = "cpu",
):
super().__init__()

def other():
pass
self.input_layer = HighOrderMLP(
layer_type="continuous",
n=n,
in_width=2,
in_segments=in_segments,
out_width=out_width,
hidden_layers=0,
hidden_width=1,
device=device,
out_segments=segments,
hidden_segments=segments,
)
self.equal_layers = HighOrderMLP(
layer_type="continuous",
n=n,
in_width=out_width,
out_width=out_width,
hidden_layers=hidden_layers,
hidden_width=hidden_width,
device=device,
in_segments=segments,
out_segments=segments,
hidden_segments=segments,
)

def forward(self, x: Tensor):
"""
x has shape [B, L (sequence length), dimension]
"""

val = self.input_layer(x)
while val.shape[1] > 1:

if val.shape[1] % 2 == 1:
# Add padding to the end, hope this doesn't bust anything
val = torch.cat([val, torch.zeros(1, val.shape[1], val.shape[2])])

val = self.equal_layers(val)
return val

11 changes: 11 additions & 0 deletions tests/test_dual_convolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from language_interpolation.dual_convolutional_network import DualConvolutionalNetwork
import torch


def test_dual_convolution():
net = DualConvolutionalNetwork(n=3, out_width=10, hidden_layers=2, hidden_width=10, in_segments=128, segments=5)
x = torch.rand(10, 20, 15)
res = net(x)
print('res', res)

0 comments on commit 530a262

Please sign in to comment.