Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 5, 2024
1 parent 87e2c83 commit cdea5a5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion language_interpolation/dual_convolutional_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def forward(self, x: Tensor):
if val.shape[1] % 2 == 1:
# Add padding to the end, hope this doesn't bust anything
val = torch.cat(
[val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1, device=self.device
[val, torch.zeros(val.shape[0], 1, val.shape[2], device=self.device)], dim=1
)

valshape = val.shape
Expand Down
33 changes: 29 additions & 4 deletions tests/test_dual_convolution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
import pytest

from language_interpolation.dual_convolutional_network import DualConvolutionalLayer
from language_interpolation.dual_convolutional_network import (
DualConvolutionalLayer,
DualConvolutionNetwork,
)
import torch


def test_dual_convolution():
net = DualConvolutionalLayer(n=3, in_width=1, out_width=10, hidden_layers=2, hidden_width=10, in_segments=128, segments=5)
x = torch.rand(10, 15, 1) # character level
net = DualConvolutionalLayer(
n=3,
in_width=1,
out_width=10,
hidden_layers=2,
hidden_width=10,
in_segments=128,
segments=5,
)
x = torch.rand(10, 15, 1) # character level
res = net(x)
print('res', res.shape)
print("res", res.shape)

net = DualConvolutionNetwork(
n=3,
in_width=1,
out_width=10,
embedding_dimension=5,
hidden_layers=2,
hidden_width=10,
in_segments=128,
segments=5,
device="cpu",
)
res = net(x)
print('res', res)

0 comments on commit cdea5a5

Please sign in to comment.