Skip to content

Commit

Permalink
LagrangeBasis2d working
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 18, 2024
1 parent 47a0a86 commit e1d9a48
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
28 changes: 17 additions & 11 deletions high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ def __call__(self, x, j: int):
return ans


class LagrangeBasisND:
"""
TODO: NOT IMPLEMENTED
N Dimensional version of the lagrange polynomial basis
"""
class LagrangeBasis2D:

def __init__(self, n: int, length: float = 2.0, dimensions: int = 2):
self.n = n
Expand All @@ -93,14 +89,24 @@ def _compute_denominators(self):
denom[j, m] = self.X[j] - self.X[m]
return denom

def __call__(self, x, j: int):
return NotImplementedError
x_diff = x.unsqueeze(-1) - self.X # Ensure broadcasting
def __call__(self, x, j: int, k: int):
x_diff = x.unsqueeze(-1) - self.X

b = torch.where(
torch.arange(self.n) != j, x_diff / self.denominators[j], torch.tensor(1.0)
torch.arange(self.n) != j,
x_diff[:, 0, :] / self.denominators[j],
torch.tensor(1.0),
)
ans = torch.prod(b, dim=-1)
return ans
c = torch.where(
torch.arange(self.n) != k,
x_diff[:, 1, :] / self.denominators[k],
torch.tensor(1.0),
)

r1 = torch.prod(b, dim=-1)
r2 = torch.prod(c, dim=-1)

return r1 * r2


class LagrangeBasis1:
Expand Down
17 changes: 15 additions & 2 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from high_order_layers_torch.FunctionalConvolution import *
from high_order_layers_torch.LagrangePolynomial import *
from high_order_layers_torch.LagrangePolynomial import LagrangePoly
from high_order_layers_torch.LagrangePolynomial import LagrangePoly, LagrangeBasis2D
import torch.nn.functional as F
from high_order_layers_torch.layers import (
L2Normalization,
Expand Down Expand Up @@ -38,11 +38,24 @@ def test_variable_dimension_input(n, in_features, out_features, segments):
layer(a)
"""

def test_lagrange_basis() :
lb = LagrangeBasis2D(n=5, dimensions=2)
x=torch.tensor([[-1,-1],[1,1]])

res = lb(x=x, j=0, k=0)
print('res1', res)
assert res[0]==1
assert torch.abs(res[1])<1e-12

res = lb(x=x, j=4, k=4)
print('res2', res)
assert res[1]==1
assert torch.abs(res[0])<1e-12

def test_nodes():
ans = chebyshevLobatto(20)
assert ans.shape[0] == 20


def test_polynomial():
poly = LagrangePoly(5)
# Just use the points as the actual values
Expand Down

0 comments on commit e1d9a48

Please sign in to comment.