From f4bfa50fb678ed3ba6ac29a5e5c031dc01dd9072 Mon Sep 17 00:00:00 2001 From: jloveric Date: Mon, 17 Jun 2024 18:04:56 -0700 Subject: [PATCH] Convert to N dimensional basis! --- tests/test_layers.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 9efcea0..d7f42be 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -38,25 +38,28 @@ def test_variable_dimension_input(n, in_features, out_features, segments): layer(a) """ -@pytest.mark.parametrize("dimensions", [1,2,3,4]) -def test_lagrange_basis(dimensions) : + +@pytest.mark.parametrize("dimensions", [1, 2, 3, 4]) +def test_lagrange_basis(dimensions): lb = LagrangeBasisND(n=5, dimensions=dimensions) - x=torch.tensor([[-1]*dimensions,[1]*dimensions]) - - res = lb(x=x, index=[0]*dimensions) - print('res1', res) - assert res[0]==1 - assert torch.abs(res[1])<1e-12 - - res = lb(x=x, index=[4]*dimensions) - print('res2', res) - assert res[1]==1 - assert torch.abs(res[0])<1e-12 + x = torch.tensor([[-1] * dimensions, [1] * dimensions]) + + res = lb(x=x, index=[0] * dimensions) + print("res1", res) + assert res[0] == 1 + assert torch.abs(res[1]) < 1e-12 + + res = lb(x=x, index=[4] * dimensions) + print("res2", res) + assert res[1] == 1 + assert torch.abs(res[0]) < 1e-12 + def test_nodes(): ans = chebyshevLobatto(20) assert ans.shape[0] == 20 + def test_polynomial(): poly = LagrangePoly(5) # Just use the points as the actual values