From c14a471ecceb0957d82a46d0b3fadff6563203a3 Mon Sep 17 00:00:00 2001 From: jloveric Date: Sat, 22 Jun 2024 21:10:14 -0700 Subject: [PATCH] Cleanup --- high_order_layers_torch/PolynomialLayers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/high_order_layers_torch/PolynomialLayers.py b/high_order_layers_torch/PolynomialLayers.py index a214d29..860e129 100644 --- a/high_order_layers_torch/PolynomialLayers.py +++ b/high_order_layers_torch/PolynomialLayers.py @@ -388,10 +388,9 @@ def __init__( # Calculate total number of weights needed self.weights_per_segment = math.prod(self.n) self.total_segments = math.prod(self.segments) # per block - total_weights = ( - in_features * out_features * self.total_segments * self.weights_per_segment - ) - + + # Ahh! This is actually the discontinuous case as we haven't + # accounted for neighboring nodes self.w = nn.Parameter( torch.empty( in_features, @@ -502,10 +501,10 @@ def _reshape_weights(self, weight_indices): ) # Select weights for each input point - selected_weights = self.w[weight_indices] + #selected_weights = self.w[weight_indices] # Reshape to [out_features, num_inputs, batch_size, weights_per_segment] - reshaped_weights = selected_weights.view( + reshaped_weights = self.w[weight_indices].view( batch_size, num_inputs, self.out_features, self.weights_per_segment )