Speeding up the ND model
jloveric committed Jun 19, 2024
1 parent fc8b886 commit f6ad178
Showing 3 changed files with 108 additions and 79 deletions.
69 changes: 34 additions & 35 deletions high_order_layers_torch/Basis.py
@@ -408,43 +408,42 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
return out_sum


class BasisFlatND:
"""
Single N dimensional element.
"""

def __init__(
self,
n: int,
basis: Callable[[Tensor, list[int]], float],
dimensions: int,
**kwargs
):
self.n = n
self.basis = basis
self.dimensions = dimensions
a = torch.arange(n)
self.indexes = (
torch.stack(torch.meshgrid([a] * dimensions))
.reshape(dimensions, -1)
.T.long()
)
self.num_basis = basis.num_basis
# class BasisFlatND:

# def __init__(
# self,
# n: int,
# basis: Callable[[Tensor, list], float],
# dimensions: int,
# **kwargs
# ):
# self.n = n
# self.basis = basis
# self.dimensions = dimensions
# a = torch.arange(n)
# self.indexes = (
# torch.stack(torch.meshgrid([a] * dimensions))
# .reshape(dimensions, -1)
# .T.long()
# )
# self.num_basis = basis.num_basis

# def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
# """
# :param x: size[batch, input, dimension]
# :param w: size[output, input, basis]
# :returns: size[batch, output]
# """
# basis = []
# for index in self.indexes:
# basis_j = self.basis(x, index=index)
# basis.append(basis_j)
# basis = torch.stack(basis)
# out_sum = torch.einsum("ijk,lki->jl", basis, w)

# return out_sum

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
"""
:param x: size[batch, input, dimension]
:param w: size[output, input, basis]
:returns: size[batch, output]
"""
basis = []
for index in self.indexes:
basis_j = self.basis(x, index=index)
basis.append(basis_j)
basis = torch.stack(basis)
out_sum = torch.einsum("ijk,lki->jl", basis, w)

return out_sum


class BasisFlatProd:
111 changes: 69 additions & 42 deletions high_order_layers_torch/LagrangePolynomial.py
@@ -22,6 +22,64 @@ def chebyshevLobatto(n: int):
return -torch.cos(torch.pi * torch.arange(n) / (n - 1))
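
# For example, chebyshevLobatto(3) yields nodes of approximately [-1, 0, 1]
# on the interval [-1, 1] (the middle value is zero only up to float eps).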


class LagrangeBasisND:
"""
    Single N-dimensional element with Lagrange basis interpolation.
"""
def __init__(
self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
):
self.n = n
self.dimensions = dimensions
self.X = (length / 2.0) * chebyshevLobatto(n).to(device)
self.device = device
self.denominators = self._compute_denominators()
self.num_basis = int(math.pow(n, dimensions))
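        # one basis function per node of the tensor-product grid: n^dimensions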

a = torch.arange(n)
self.indexes = (
torch.stack(torch.meshgrid([a] * dimensions, indexing="ij"))
.reshape(dimensions, -1)
.T.long().to(self.device)
)

def _compute_denominators(self):
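        # Lagrange basis: l_j(x) = prod over k != j of (x - x_k) / (x_j - x_k).
        # Precompute the pairwise differences x_j - x_k; the j == k entries
        # are replaced by 1 below so those factors drop out of the product.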
X_diff = self.X.unsqueeze(0) - self.X.unsqueeze(1) # [n, n]
denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
return denom

def _compute_basis(self, x, indexes):
"""
Computes the basis values for all index combinations.
:param x: [batch, inputs, dimensions]
:param indexes: [num_basis, dimensions]
:returns: basis values [num_basis, batch, inputs]
"""
x_diff = x.unsqueeze(-1) - self.X # [batch, inputs, dimensions, n]
mask = (indexes.unsqueeze(1).unsqueeze(2).unsqueeze(4) != torch.arange(self.n, device=self.device).view(1, 1, 1, 1, self.n))
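        # mask: [num_basis, 1, 1, dimensions, n] - True for every node k that
        # differs from the basis index in that dimension, i.e. exactly the
        # factors that appear in the Lagrange product.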
denominators = self.denominators[indexes] # [num_basis, dimensions, n]

b = torch.where(mask, x_diff.unsqueeze(0) / denominators.unsqueeze(1).unsqueeze(2), torch.tensor(1.0, device=self.device))
#print('b.shape', b.shape)
r = torch.prod(b, dim=-1) # [num_basis, batch, inputs, dimensions]

return r.prod(dim=-1) # [num_basis, batch, inputs]

def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""
Interpolates the input using the Lagrange basis.
:param x: size[batch, inputs, dimensions]
:param w: size[output, inputs, num_basis]
:returns: size[batch, output]
"""
basis = self._compute_basis(x, self.indexes) # [num_basis, batch, inputs]
        #print('basis.shape', basis.shape, 'w.shape', w.shape)
out_sum = torch.einsum("ibk,oki->bo", basis, w) # [batch, output]
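        # i.e. out[b, o] = sum over inputs k and basis functions i of
        # basis[i, b, k] * w[o, k, i]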

return out_sum
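
A minimal usage sketch for LagrangeBasisND (illustrative only, not part of this commit; shapes follow the docstrings above):

    import torch
    from high_order_layers_torch.LagrangePolynomial import LagrangeBasisND

    basis = LagrangeBasisND(n=4, dimensions=2)  # 4^2 = 16 basis functions
    x = torch.rand(8, 3, 2) * 2 - 1             # [batch=8, inputs=3, dimensions=2]
    w = torch.rand(5, 3, basis.num_basis)       # [output=5, inputs=3, num_basis=16]
    y = basis.interpolate(x, w)                 # [batch=8, output=5]
    assert y.shape == (8, 5)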



class FourierBasis:
def __init__(self, length: float):
"""
@@ -75,47 +133,6 @@ def __call__(self, x, j: int):
return ans


class LagrangeBasisND:

def __init__(
self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
):
self.n = n
self.dimensions = dimensions
self.X = (length / 2.0) * chebyshevLobatto(n).to(device)
self.device = device
self.denominators = self._compute_denominators()
self.num_basis = int(math.pow(n, dimensions))

def _compute_denominators(self):

X_diff = self.X.unsqueeze(0) - self.X.unsqueeze(1) # [n, n]
denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
return denom

def __call__(self, x, index: list):
"""
        TODO: I believe we can make this even more efficient by
        calculating all bases at once instead of one at a time,
        treating the whole thing as a cartesian product; still, this
        is pretty fast - O(n)*O(dims) per call. The x_diff computation
        is redundant since it is the same for every basis. This function
        is called O(n^dims) times, so the total cost is O(dims*n^(dims+1)).
:param x: [batch, inputs, dimensions]
:param index : [dimensions]
:returns: basis value [batch, inputs]
"""
x_diff = x.unsqueeze(-1) - self.X # [batch, inputs, dimensions, n]
indices = torch.tensor(index, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
mask = torch.arange(self.n, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(0) != indices
denominators = self.denominators[index] # [dimensions, n]

b = torch.where(mask, x_diff / denominators, torch.tensor(1.0, device=self.device))
r = torch.prod(b, dim=-1) # [batch, inputs, dimensions]

return r.prod(dim=-1)


class LagrangeBasis1:
"""
TODO: Degenerate case, test this and see if it works with everything else.
@@ -181,7 +198,7 @@ class LagrangePolyFlat(BasisFlat):
def __init__(self, n: int, length: float = 2.0, **kwargs):
super().__init__(n, get_lagrange_basis(n, length), **kwargs)


"""
class LagrangePolyFlatND(BasisFlatND):
def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
super().__init__(
Expand All @@ -190,6 +207,16 @@ def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
dimensions=dimensions,
**kwargs
)
"""

class LagrangePolyFlatND(LagrangeBasisND):
def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
super().__init__(
n,
length=length,
dimensions=dimensions,
**kwargs
)


class LagrangePolyFlatProd(BasisFlatProd):
7 changes: 5 additions & 2 deletions tests/test_layers.py
@@ -20,7 +20,7 @@
from high_order_layers_torch.networks import *
from high_order_layers_torch.PolynomialLayers import *
import torch
from high_order_layers_torch.Basis import BasisFlatND
#from high_order_layers_torch.Basis import BasisFlatND

torch.set_default_device(device="cpu")

@@ -39,6 +39,9 @@ def test_variable_dimension_input(n, in_features, out_features, segments):
layer(a)
"""

"""
These have both been combined into the new LagrangeBasisND so
the computation is faster.
def test_basis_nd() :
dimensions = 3
n=5
@@ -75,7 +78,7 @@ def test_lagrange_basis(dimensions):
print("res2", res)
assert res[1] == 1
assert torch.abs(res[0]) < 1e-12

"""

def test_nodes():
ans = chebyshevLobatto(20)
