Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 23, 2024
1 parent 02ef5f6 commit c14a471
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions high_order_layers_torch/PolynomialLayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,9 @@ def __init__(
# Calculate total number of weights needed
self.weights_per_segment = math.prod(self.n)
self.total_segments = math.prod(self.segments) # per block
total_weights = (
in_features * out_features * self.total_segments * self.weights_per_segment
)


# Ahh! This is actually the discontinuous case as we haven't
# accounted for neighboring nodes
self.w = nn.Parameter(
torch.empty(
in_features,
Expand Down Expand Up @@ -502,10 +501,10 @@ def _reshape_weights(self, weight_indices):
)

# Select weights for each input point
selected_weights = self.w[weight_indices]
#selected_weights = self.w[weight_indices]

# Reshape to [out_features, num_inputs, batch_size, weights_per_segment]
reshaped_weights = selected_weights.view(
reshaped_weights = self.w[weight_indices].view(
batch_size, num_inputs, self.out_features, self.weights_per_segment
)

Expand Down

0 comments on commit c14a471

Please sign in to comment.