diff --git a/language_interpolation/dual_convolutional_network.py b/language_interpolation/dual_convolutional_network.py index 7176d04..0ac40c2 100644 --- a/language_interpolation/dual_convolutional_network.py +++ b/language_interpolation/dual_convolutional_network.py @@ -69,7 +69,7 @@ def forward(self, x: Tensor): if val.shape[1] % 2 == 1: # Add padding to the end, hope this doesn't bust anything val = torch.cat( - [val, torch.zeros(val.shape[0], 1, val.shape[2])], dim=1, device=self.device + [val, torch.zeros(val.shape[0], 1, val.shape[2], device=self.device)], dim=1 ) valshape = val.shape diff --git a/tests/test_dual_convolution.py b/tests/test_dual_convolution.py index 3ff7917..5783df0 100644 --- a/tests/test_dual_convolution.py +++ b/tests/test_dual_convolution.py @@ -1,11 +1,36 @@ import pytest -from language_interpolation.dual_convolutional_network import DualConvolutionalLayer +from language_interpolation.dual_convolutional_network import ( + DualConvolutionalLayer, + DualConvolutionNetwork, +) import torch def test_dual_convolution(): - net = DualConvolutionalLayer(n=3, in_width=1, out_width=10, hidden_layers=2, hidden_width=10, in_segments=128, segments=5) - x = torch.rand(10, 15, 1) # character level + net = DualConvolutionalLayer( + n=3, + in_width=1, + out_width=10, + hidden_layers=2, + hidden_width=10, + in_segments=128, + segments=5, + ) + x = torch.rand(10, 15, 1) # character level res = net(x) - print('res', res.shape) \ No newline at end of file + print("res", res.shape) + + net = DualConvolutionNetwork( + n=3, + in_width=1, + out_width=10, + embedding_dimension=5, + hidden_layers=2, + hidden_width=10, + in_segments=128, + segments=5, + device="cpu", + ) + res = net(x) + print('res', res)