From 6bfbc81cdc90540adbf65905afc53b6cabd461c8 Mon Sep 17 00:00:00 2001 From: "Addison A." <70176208+a-alveyblanc@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:28:57 -0500 Subject: [PATCH] Use `TensorProductQuadrature` instead of `Quadrature` in `TensorProductElementGroupBase` (#436) * create TensorProductQuadrature by default instead of Quadrature in TP element groups * use all 1d bases and 1d nodes to create quadrature * fix ruff complaints * unpack in loop header Co-authored-by: Alex Fikl * small change * clean ups; check that nodes match --------- Co-authored-by: Alex Fikl --- meshmode/discretization/poly_element.py | 30 ++++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/meshmode/discretization/poly_element.py b/meshmode/discretization/poly_element.py index 9b830e65..f6524306 100644 --- a/meshmode/discretization/poly_element.py +++ b/meshmode/discretization/poly_element.py @@ -525,12 +525,30 @@ def basis_obj(self): @memoize_method def quadrature_rule(self): - basis = self._basis - nodes = self._nodes - mass_matrix = mp.mass_matrix(basis, nodes) - weights = np.dot(mass_matrix, - np.ones(len(basis.functions))) - return mp.Quadrature(nodes, weights, exact_to=self.order) + from modepy.tools import reshape_array_for_tensor_product_space + + quads = [] + + if self.dim != 1: + nodes_tp = reshape_array_for_tensor_product_space(self.space, + self._nodes) + else: + nodes_tp = self._nodes + + for idim, (nodes, basis) in enumerate(zip(nodes_tp, self._basis.bases)): + # get current dimension's nodes from fastest varying axis + nodes = nodes[*(0,)*idim, :, *(0,)*(self.dim-idim-1)] + + nodes_1d = nodes.reshape(1, -1) + mass_matrix = mp.mass_matrix(basis, nodes_1d) + weights = np.dot(mass_matrix, np.ones(len(basis.functions))) + + quads.append(mp.Quadrature(nodes_1d, weights, exact_to=self.order)) + + tp_quad = mp.TensorProductQuadrature(quads) + assert np.allclose(tp_quad.nodes, self._nodes) + + return tp_quad @property @memoize_method