
Commit 02ef5f6
Fix test
jloveric committed Jun 23, 2024
1 parent eb63b09 · commit 02ef5f6
Showing 2 changed files with 27 additions and 19 deletions.
high_order_layers_torch/PolynomialLayers.py (27 additions, 9 deletions)
@@ -381,14 +381,25 @@ def __init__(
         self._length = length
         self._half = 0.5 * length
 
-        self.lagrange_basis = LagrangeBasisPiecewiseND(self.n, length=length, device=device)
+        self.lagrange_basis = LagrangeBasisPiecewiseND(
+            self.n, length=length, device=device
+        )
 
         # Calculate total number of weights needed
         self.weights_per_segment = math.prod(self.n)
-        self.total_segments = math.prod(self.segments) # per block
-        total_weights = in_features*out_features * self.total_segments * self.weights_per_segment
+        self.total_segments = math.prod(self.segments)  # per block
+        total_weights = (
+            in_features * out_features * self.total_segments * self.weights_per_segment
+        )
 
-        self.w = nn.Parameter(torch.empty(total_weights, device=device))
+        self.w = nn.Parameter(
+            torch.empty(
+                in_features,
+                out_features,
+                self.total_segments * self.weights_per_segment,
+                device=device,
+            )
+        )
 
         if initialize == "constant_random":
             self._constant_random_initialization(weight_magnitude)
@@ -400,12 +411,14 @@ def __init__(
     def _constant_random_initialization(self, weight_magnitude):
         # TODO: verify this.
         segment_values = (
-            torch.rand(self.out_features, self.total_segments, device=self.device)
+            torch.rand(self.in_features, self.out_features, device=self.device)
             * 2
             * weight_magnitude
             - weight_magnitude
         ) / self.in_features
-        self.w.data = segment_values.repeat_interleave(self.weights_per_segment)
+        self.w.data = segment_values.repeat_interleave(
+            self.total_segments * self.weights_per_segment
+        )
 
     def which_segment(self, x: torch.Tensor) -> torch.Tensor:
         return (
@@ -415,7 +428,10 @@ def which_segment(self, x: torch.Tensor) -> torch.Tensor:
                 * torch.tensor(self.segments, device=self.device)
             )
             .long()
-            .clamp(torch.tensor(0, device=self.device), torch.tensor(self.segments, device=self.device) - 1)
+            .clamp(
+                torch.tensor(0, device=self.device),
+                torch.tensor(self.segments, device=self.device) - 1,
+            )
         )
 
     def x_local(self, x_global: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
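Aside for reviewers: the reformatted clamp folds points sitting exactly on the upper domain edge back into the last valid segment. A standalone numeric check of that mapping (the toy domain and scalar segment count below are assumptions for illustration, not the layer's attributes):

import torch

length = 2.0                  # domain [-1, 1]
segments = torch.tensor([4])  # 4 segments in one dimension
half = 0.5 * length

x = torch.tensor([[-1.0], [0.3], [1.0]])  # includes both domain endpoints

# Same arithmetic as which_segment: shift, scale, truncate, clamp.
idx = ((x + half) / length * segments).long().clamp(
    torch.tensor(0), segments - 1
)
print(idx.squeeze(-1))  # tensor([0, 2, 3]); x = 1.0 clamps from raw index 4 to 3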
@@ -481,7 +497,9 @@ def _get_weight_indices(self, segment_indices):
     def _reshape_weights(self, weight_indices):
         # Reshape weights based on weight indices
         batch_size, num_inputs, _ = weight_indices.shape
-        weight_indices = weight_indices.unsqueeze(2).expand(-1, -1, self.out_features, -1)
+        weight_indices = weight_indices.unsqueeze(2).expand(
+            -1, -1, self.out_features, -1
+        )
 
         # Select weights for each input point
         selected_weights = self.w[weight_indices]
@@ -490,7 +508,7 @@ def _reshape_weights(self, weight_indices):
         reshaped_weights = selected_weights.view(
             batch_size, num_inputs, self.out_features, self.weights_per_segment
         )
-
+        # Permute to get the desired shape [out_features, num_inputs, batch_size, weights_per_segment]
         return reshaped_weights
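Aside for reviewers: the reworked initialization draws one random constant per (in_feature, out_feature) pair and repeats it across every segment and basis slot of that pair. A minimal sketch of that effect with toy sizes (all sizes and names below are illustrative, not the module's attributes):

import torch

in_features, out_features = 2, 3
total_segments, weights_per_segment = 4, 5
weight_magnitude = 1.0

# One constant per (in, out) pair in [-weight_magnitude, weight_magnitude],
# normalized by fan-in, as in _constant_random_initialization.
segment_values = (
    torch.rand(in_features, out_features) * 2 * weight_magnitude - weight_magnitude
) / in_features

# repeat_interleave flattens the tensor and repeats each constant once per slot.
flat = segment_values.repeat_interleave(total_segments * weights_per_segment)

# Viewed in the parameter's 3-D shape, every slot of a given (in, out)
# pair holds the same constant.
w = flat.view(in_features, out_features, total_segments * weights_per_segment)
assert bool(torch.all(w[0, 1] == w[0, 1, 0]))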
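Aside for reviewers: _reshape_weights leans on PyTorch advanced indexing, where indexing a flat weight vector with an integer tensor returns a result shaped like the index tensor. A toy version of that pattern (the flat table and sizes are assumptions for illustration, not the commit's exact layout):

import torch

batch_size, num_inputs, out_features, weights_per_segment = 2, 3, 4, 5
total_weights = 1000  # arbitrary flat-table size for the sketch

w = torch.randn(total_weights)  # flat weight table

# One flat index per (batch, input, basis-weight) triple.
weight_indices = torch.randint(
    0, total_weights, (batch_size, num_inputs, weights_per_segment)
)

# Broadcast the same indices across the output dimension, then gather:
# the gathered result takes the shape of the index tensor.
expanded = weight_indices.unsqueeze(2).expand(-1, -1, out_features, -1)
selected = w[expanded]
reshaped = selected.view(batch_size, num_inputs, out_features, weights_per_segment)
assert reshaped.shape == (2, 3, 4, 5)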
high_order_layers_torch/modules.py (0 additions, 10 deletions)
@@ -2,7 +2,6 @@
 
 import torch
 import torch.optim as optim
-import torch_optimizer as alt_optim
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
 from torch import Tensor
@@ -53,15 +52,6 @@ def test_step(self, batch, batch_idx):
         return self.eval_step(batch, "test")
 
     def configure_optimizers(self):
-        if self.cfg.optimizer.name == "adahessian":
-            return alt_optim.Adahessian(
-                self.parameters(),
-                lr=self.cfg.optimizer.lr,
-                betas=self.cfg.optimizer.betas,
-                eps=self.cfg.optimizer.eps,
-                weight_decay=self.cfg.optimizer.weight_decay,
-                hessian_power=self.cfg.optimizer.hessian_power,
-            )
 
         if self.cfg.optimizer.name == "adam" :
             optimizer = optim.Adam(
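Aside for reviewers: with torch_optimizer dropped, configure_optimizers dispatches only on optimizers built into torch.optim. A minimal sketch of the remaining dispatch (the repo's method goes on to configure the optimizer further; any config fields beyond name and lr are assumptions):

import torch.optim as optim

def configure_optimizers(self):
    # Only built-in optimizers remain; the "adahessian" branch and its
    # torch_optimizer dependency were removed in this commit.
    if self.cfg.optimizer.name == "adam":
        return optim.Adam(self.parameters(), lr=self.cfg.optimizer.lr)
    raise ValueError(f"Unsupported optimizer: {self.cfg.optimizer.name}")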
