From d20253d55adb7c209cab9583bd77a058168d8224 Mon Sep 17 00:00:00 2001
From: John Loverich
Date: Sat, 15 Jun 2024 16:40:50 -0700
Subject: [PATCH] More simplification, fix sign error

---
 high_order_layers_torch/Basis.py              | 31 ++++++-------------
 high_order_layers_torch/LagrangePolynomial.py |  2 +-
 2 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/high_order_layers_torch/Basis.py b/high_order_layers_torch/Basis.py
index 74974fb..6346e17 100644
--- a/high_order_layers_torch/Basis.py
+++ b/high_order_layers_torch/Basis.py
@@ -443,9 +443,9 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
 
         return out_prod
 
 
-class Basis:  # Should this be a module? No stored data.
-    def __init__(self, n: int, basis: Callable[[Tensor, int], Tensor]):
-        super().__init__()
+class Basis:
+    # TODO: Is this the same as above? No! It is not!
+    def __init__(self, n: int, basis: Callable[[Tensor, int], float]):
         self.n = n
         self.basis = basis
@@ -461,28 +461,15 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
         Returns:
         - result: size[batch, output]
         """
-        batch_size, input_size = x.shape
-        _, _, output_size, basis_size = w.shape
-
-        # Create a tensor to hold the basis functions, shape: [batch, input, n]
-        basis_functions = torch.empty((batch_size, input_size, self.n), device=x.device)
-
-        # Calculate all basis functions for each j in parallel
+        mat = []
         for j in range(self.n):
-            basis_functions[:, :, j] = self.basis(x, j)
-
-        # Reshape basis functions to match the required shape for einsum
-        # basis_functions: [batch, input, n] -> [n, batch, input]
-        basis_functions = basis_functions.permute(2, 0, 1)
+            basis_j = self.basis(x, j)
+            mat.append(basis_j)
+        mat = torch.stack(mat)
 
-        # Perform the einsum operation to calculate the result
-        # einsum equation explanation:
-        # - 'ijk' corresponds to basis_functions with shape [n, batch, input]
-        # - 'jkli' corresponds to w with shape [batch, input, output, basis]
-        # - 'jl' corresponds to the output with shape [batch, output]
-        result = torch.einsum("ijk,jkli->jl", basis_functions, w)
+        out_sum = torch.einsum("ijk,jkli->jl", mat, w)
 
-        return result
+        return out_sum
 
 
 class BasisProd:
diff --git a/high_order_layers_torch/LagrangePolynomial.py b/high_order_layers_torch/LagrangePolynomial.py
index a0d61ea..f1a44a8 100644
--- a/high_order_layers_torch/LagrangePolynomial.py
+++ b/high_order_layers_torch/LagrangePolynomial.py
@@ -20,7 +20,7 @@ def chebyshevLobatto(n: int):
     if n == 1:
         return torch.tensor([0.0])
 
-    return torch.cos(torch.pi * torch.arange(n) / (n - 1))
+    return -torch.cos(torch.pi * torch.arange(n) / (n - 1))
 
 
 class FourierBasis:
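
Not part of the patch: a quick shape check for the simplified
Basis.interpolate, as a minimal sketch. The tensor shapes follow the
docstring above; x**j is a hypothetical stand-in for self.basis(x, j),
not the library's actual basis functions.

    import torch

    batch, inputs, outputs, n = 4, 3, 2, 5
    x = torch.rand(batch, inputs)                   # [batch, input]
    w = torch.rand(batch, inputs, outputs, n)       # [batch, input, output, basis]

    # Same structure as the patched interpolate: stack the per-degree basis
    # evaluations into [n, batch, input], then contract over n and input.
    mat = torch.stack([x**j for j in range(n)])     # [n, batch, input]
    out_sum = torch.einsum("ijk,jkli->jl", mat, w)  # [batch, output]
    assert out_sum.shape == (batch, outputs)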
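
And a check of the sign fix, with chebyshevLobatto copied from the patched
file: the leading minus returns the Chebyshev-Lobatto nodes in ascending
order from -1 to 1 instead of descending.

    import torch

    def chebyshevLobatto(n: int):
        if n == 1:
            return torch.tensor([0.0])
        return -torch.cos(torch.pi * torch.arange(n) / (n - 1))

    nodes = chebyshevLobatto(5)                     # -1, -0.707, ~0, 0.707, 1
    assert torch.allclose(nodes[0], torch.tensor(-1.0))
    assert (nodes[1:] > nodes[:-1]).all()           # strictly ascending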