Skip to content

Commit

Permalink
More simplification, fix sign error
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 15, 2024
1 parent 551ff3c commit d20253d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 23 deletions.
31 changes: 9 additions & 22 deletions high_order_layers_torch/Basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
return out_prod


class Basis: # Should this be a module? No stored data.
def __init__(self, n: int, basis: Callable[[Tensor, int], Tensor]):
super().__init__()
class Basis:
# TODO: Is this the same as above? No! It is not!
def __init__(self, n: int, basis: Callable[[Tensor, int], float]):
self.n = n
self.basis = basis

Expand All @@ -461,28 +461,15 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
Returns:
- result: size[batch, output]
"""
batch_size, input_size = x.shape
_, _, output_size, basis_size = w.shape

# Create a tensor to hold the basis functions, shape: [batch, input, n]
basis_functions = torch.empty((batch_size, input_size, self.n), device=x.device)

# Calculate all basis functions for each j in parallel
mat = []
for j in range(self.n):
basis_functions[:, :, j] = self.basis(x, j)

# Reshape basis functions to match the required shape for einsum
# basis_functions: [batch, input, n] -> [n, batch, input]
basis_functions = basis_functions.permute(2, 0, 1)
basis_j = self.basis(x, j)
mat.append(basis_j)
mat = torch.stack(mat)

# Perform the einsum operation to calculate the result
# einsum equation explanation:
# - 'ijk' corresponds to basis_functions with shape [n, batch, input]
# - 'jkli' corresponds to w with shape [batch, input, output, basis]
# - 'jl' corresponds to the output with shape [batch, output]
result = torch.einsum("ijk,jkli->jl", basis_functions, w)
out_sum = torch.einsum("ijk,jkli->jl", mat, w)

return result
return out_sum


class BasisProd:
Expand Down
2 changes: 1 addition & 1 deletion high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def chebyshevLobatto(n: int):
if n == 1:
return torch.tensor([0.0])

return torch.cos(torch.pi * torch.arange(n) / (n - 1))
return -torch.cos(torch.pi * torch.arange(n) / (n - 1))


class FourierBasis:
Expand Down

0 comments on commit d20253d

Please sign in to comment.