Skip to content

Commit

Permalink
Add polynomial nd layers and one example
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 19, 2024
1 parent 68fac25 commit 78f6ba7
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 56 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ With piecewise continuous.
![piecewise continuous polynomial](plots/xor_continuous.png)
With polynomial using similar number of parameters.
![polynomial](plots/xor_polynomial.png)

Using 2D polynomial "layer" this is just a single input and single output. The polynomial 2d is a link that takes in 2 variables (a 2d vector) and in this
case outputs a single value. It's the cartesian product of the basis function of a single polynomial function, but in the x and y direction. Using 5 point polynomials would have 25 basis functions in 2d whereas it has only 5 basis functions in 1d.
```
python3 examples/xor.py layer_type=polynomial_2d optimizer.lr=0.01 optimizer=sophia epochs=100
```
![polynomial 2d](plots/xor_polynomial_2d.png)
## MNIST (convolutional)

```python
Expand Down
8 changes: 4 additions & 4 deletions config/xor.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
layer_type : continuous # discontinuous, polynomial
layer_type : continuous # discontinuous, polynomial, polynomial_2d
segments: 2
epochs: 40
optimizer: sophia #lion, adam
lr: 0.1
batch_size: 32
batch_size: 32
defaults:
- optimizer: sophia, #lion, adam
125 changes: 83 additions & 42 deletions examples/xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@


class XorDataset(Dataset):
def __init__(self, transform=None):
def __init__(self, transform=None, nd:bool=False):
x = (2.0 * torch.rand(1000) - 1.0).view(-1, 1)
y = (2.0 * torch.rand(1000) - 1.0).view(-1, 1)
z = torch.where(x * y > 0, -0.5 + 0 * x, 0.5 + 0 * x)

self.data = torch.cat([x, y], dim=1)
if nd is True :
self.data = self.data.unsqueeze(1)

self.z = z
print(self.data.shape)

Expand All @@ -58,7 +61,7 @@ def __init__(
optimizer: str = "sophia",
lr: float = 0.01,
batch_size: int = 32,
device='cpu',
device="cpu",
):
"""
Simple network consisting of 2 input and 1 output
Expand All @@ -69,29 +72,44 @@ def __init__(
self.optimizer = optimizer
self.lr = lr
self.batch_size = batch_size

self.layer1 = high_order_fc_layers(
layer_type=layer_type,
n=n,
in_features=2,
out_features=2,
segments=segments,
alpha=linear_part,
intialization="constant_random",
device=device,
)
self.layer2 = high_order_fc_layers(
layer_type=layer_type,
n=n,
in_features=2,
out_features=1,
segments=segments,
alpha=linear_part,
initialization="constant_random",
)
self.layer_type = layer_type

if layer_type == "polynomial_2d":
layer1 = high_order_fc_layers(
layer_type=layer_type,
n=n,
in_features=2,
out_features=1,
segments=segments,
alpha=linear_part,
intialization="constant_random",
device=device,
)
self.model = nn.Sequential(*[layer1])
else:
layer1 = high_order_fc_layers(
layer_type=layer_type,
n=n,
in_features=2,
out_features=2,
segments=segments,
alpha=linear_part,
intialization="constant_random",
device=device,
)
layer2 = high_order_fc_layers(
layer_type=layer_type,
n=n,
in_features=2,
out_features=1,
segments=segments,
alpha=linear_part,
initialization="constant_random",
)
self.model = nn.Sequential(*[layer1, layer2])

def forward(self, x):
out1 = self.layer2(self.layer1(x))
out1 = self.model(x)
return out1

def training_step(self, batch, batch_idx):
Expand All @@ -100,6 +118,10 @@ def training_step(self, batch, batch_idx):
return {"loss": F.mse_loss(y_hat, y)}

def train_dataloader(self):

if self.layer_type == "polynomial_2d" :
return DataLoader(XorDataset(nd=True), batch_size=self.batch_size)

return DataLoader(XorDataset(), batch_size=self.batch_size)

def configure_optimizers(self):
Expand All @@ -109,23 +131,28 @@ def configure_optimizers(self):
return SophiaG(self.parameters(), lr=self.lr, rho=0.035)
elif self.optimizer == "adam":
return torch.optim.Adam(self.parameters(), lr=self.lr)
else :
raise ValueError(f"optimizer must be lion, sophia or adam, got {self.optimizer}")
else:
raise ValueError(
f"optimizer must be lion, sophia or adam, got {self.optimizer}"
)


model_set_p = [
{"name": f"Polynomial {i+1}", "order": i + 1, "layer": "polynomial"}
for i in range(1, 9, 2)
{"name": f"Polynomial {i+1}", "n": i, "layer": "polynomial"}
for i in range(2, 9, 2)
]
model_set_c = [
{"name": f"Continuous {i+1}", "order": i + 1, "layer": "continuous"}
for i in range(1, 5)
{"name": f"Continuous {i+1}", "n": i, "layer": "continuous"}
for i in range(2, 6)
]
model_set_d = [
{"name": f"Discontinuous {i+1}", "order": i + 1, "layer": "discontinuous"}
for i in range(1, 5)
{"name": f"Discontinuous {i+1}", "n": i, "layer": "discontinuous"}
for i in range(1, 6)
]
model_set_2d = [
{"name": f"Polynomial 2D {i+1}", "n": i, "layer": "polynomial_2d"}
for i in range(2, 10, 2)
]


def plot_approximation(
model_set,
Expand All @@ -137,36 +164,49 @@ def plot_approximation(
optimizer="sophia",
lr=0.01,
batch_size=32,
device='cpu'
device="cpu",
):
global xTest

pred_set = []
for i in range(0, len(model_set)):
layer_type = model_set[i]["layer"]

trainer = Trainer(max_epochs=epochs, accelerator=device)
model = NDFunctionApproximation(
n=model_set[i]["order"],
n=model_set[i]["n"],
segments=segments,
layer_type=model_set[i]["layer"],
layer_type=layer_type,
linear_part=linear_part,
optimizer=optimizer,
lr=lr,
batch_size=batch_size,
device=device,
)

trainer.fit(model)
predictions = model(xTest.view(xTest.size(0), -1))
if layer_type == "polynomial_2d" :
thisTest = xTest.reshape(xTest.size(0),1, -1)
print('xtest.shape', thisTest.shape)
predictions = model(thisTest)
else :
thisTest = xTest.reshape(xTest.size(0), -1)

predictions = model(thisTest)
pred_set.append(predictions)
if plot is True:
ans = predictions.flatten().data.numpy()
xTest = xTest.reshape(xTest.size(0),-1)
plt.subplot(2, 2, i + 1)
plt.scatter(
xTest.data.numpy()[:, 0],
xTest.data.numpy()[:, 1],
c=predictions.flatten().data.numpy(),
)
if model_set[i]["layer"] != "polynomial":
plt.title(f"{model_set[i]['name']} with {segments} segments.")
if model_set[i]["layer"] not in [ "polynomial", "polynomial_2d"]:
plt.title(f"{model_set[i]['name']} with {segments} segments")
else:
plt.title(f"{model_set[i]['name']}.")
plt.title(f"{model_set[i]['name']}")

return pred_set

Expand All @@ -178,6 +218,7 @@ def run(cfg: DictConfig):
"continuous": model_set_c,
"discontinuous": model_set_d,
"polynomial": model_set_p,
"polynomial_2d" : model_set_2d,
}

plot_approximation(
Expand All @@ -186,9 +227,9 @@ def run(cfg: DictConfig):
epochs=cfg.epochs,
linear_part=0,
plot=True,
optimizer=cfg.optimizer,
lr=cfg.lr,
batch_size=cfg.batch_size
optimizer=cfg.optimizer.name,
lr=cfg.optimizer.lr,
batch_size=cfg.batch_size,
)
plt.show()

Expand Down
8 changes: 4 additions & 4 deletions high_order_layers_torch/Basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ class BasisFlat:
def __init__(self, n: int, basis: Callable[[Tensor, int], float]):
self.n = n
self.basis = basis
self.num_basis = basis.num_basis or n

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -420,21 +421,20 @@ def __init__(
self.dimensions = dimensions
a = torch.arange(n)
self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T.long()
self.num_basis = basis.num_basis

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
"""
:param x: size[batch, input, dimension]
:param w: size[input, output, basis]
:param w: size[output, input, basis]
:returns: size[batch, output]
"""

basis = []
for index in self.indexes:
basis_j = self.basis(x, index=index)
basis.append(basis_j)
basis = torch.stack(basis)

out_sum = torch.einsum("ijk,kli->jl", basis, w)
out_sum = torch.einsum("ijk,lki->jl", basis, w)

return out_sum

Expand Down
3 changes: 3 additions & 0 deletions high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, length: float):
of 1 means there is periodicity 1
"""
self.length = length
self.num_basis = None # Apparently defined elsewhere? How does this work!

def __call__(self, x: Tensor, j: int):
"""
Expand All @@ -55,6 +56,7 @@ def __init__(self, n: int, length: float = 2.0):
self.n = n
self.X = (length / 2.0) * chebyshevLobatto(n)
self.denominators = self._compute_denominators()
self.num_basis = n

def _compute_denominators(self):
denom = torch.ones((self.n, self.n), dtype=torch.float32)
Expand Down Expand Up @@ -119,6 +121,7 @@ class LagrangeBasis1:
def __init__(self, length: float = 2.0):
self.n = 1
self.X = torch.tensor([0.0])
self.num_basis=1

def __call__(self, x, j: int):
b = torch.ones_like(x)
Expand Down
Loading

0 comments on commit 78f6ba7

Please sign in to comment.