Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass strict=True in zip() where applicable #440

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/moving-geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def source(t, x):
gradx = sum(
num_reference_derivative(discr, (i,), x)
for i in range(discr.dim))
intx = sum(actx.np.sum(xi * wi) for xi, wi in zip(x, discr.quad_weights()))
intx = sum(
actx.np.sum(xi * wi)
for xi, wi in zip(x, discr.quad_weights(), strict=True))

assert gradx is not None
assert intx is not None
Expand Down
8 changes: 5 additions & 3 deletions examples/simple-dg.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def grad(self, vec):
for idim in range(self.volume_discr.dim)]

return make_obj_array([
sum(dref_i*ipder_i for dref_i, ipder_i in zip(dref, ipder[iambient]))
sum(dref_i*ipder_i
for dref_i, ipder_i in zip(dref, ipder[iambient], strict=True))
for iambient in range(self.volume_discr.ambient_dim)])

def div(self, vecs):
Expand Down Expand Up @@ -259,7 +260,7 @@ def inverse_mass(self, vec):
vec_i,
arg_names=("mass_inv_mat", "vec"),
tagged=(FirstAxisIsElementsTag(),)
) for grp, vec_i in zip(discr.groups, vec)
) for grp, vec_i in zip(discr.groups, vec, strict=True)
)
) / actx.thaw(self.vol_jacobian())

Expand Down Expand Up @@ -321,7 +322,8 @@ def face_mass(self, vec):
),
tagged=(FirstAxisIsElementsTag(),))
for afgrp, volgrp, vec_i in zip(all_faces_discr.groups,
vol_discr.groups, vec)
vol_discr.groups, vec,
strict=True)
)
)

Expand Down
2 changes: 1 addition & 1 deletion meshmode/discretization/connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def check_connection(actx: ArrayContext, connection: DirectDiscretizationConnect

assert len(connection.groups) == len(to_discr.groups)

for cgrp, tgrp in zip(connection.groups, to_discr.groups):
for cgrp, tgrp in zip(connection.groups, to_discr.groups, strict=True):
for batch in cgrp.batches:
fgrp = from_discr.groups[batch.from_group_index]

Expand Down
4 changes: 2 additions & 2 deletions meshmode/discretization/connection/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _build_batches(actx, from_bins, to_bins, batch):
def to_device(x):
return actx.freeze(actx.from_numpy(np.asarray(x)))

for ibatch, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins)):
for ibatch, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins, strict=True)):
yield InterpolationBatch(
from_group_index=batch[ibatch].from_group_index,
from_element_indices=to_device(from_bin),
Expand Down Expand Up @@ -248,7 +248,7 @@ def flatten_chained_connection(actx, connection):

# build new groups
groups = []
for igrp, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins)):
for igrp, (from_bin, to_bin) in enumerate(zip(from_bins, to_bins, strict=True)):
groups.append(DiscretizationConnectionElementGroup(
list(_build_batches(actx, from_bin, to_bin,
batch_info[igrp]))))
Expand Down
4 changes: 2 additions & 2 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def group_pick_knl(is_surjective: bool):

group_arrays = []
for i_tgrp, (cgrp, group_pick_info) in enumerate(
zip(self.groups, self._global_point_pick_info(actx))):
zip(self.groups, self._global_point_pick_info(actx), strict=True)):

group_array_contributions = []

Expand Down Expand Up @@ -925,7 +925,7 @@ def knl():
tgt_node_nr_base = 0
mats = []
for i_tgrp, (tgrp, cgrp) in enumerate(
zip(conn.to_discr.groups, conn.groups)):
zip(conn.to_discr.groups, conn.groups, strict=True)):
for i_batch, batch in enumerate(cgrp.batches):
if not len(batch.from_element_indices):
continue
Expand Down
7 changes: 4 additions & 3 deletions meshmode/discretization/connection/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def make_face_restriction(
connection_data = {}

for igrp, (grp, fagrp_list) in enumerate(
zip(discr.groups, discr.mesh.facial_adjacency_groups)):
zip(discr.groups, discr.mesh.facial_adjacency_groups, strict=True)):

mgrp = grp.mesh_el_group

Expand All @@ -252,7 +252,7 @@ def make_face_restriction(
if isinstance(fagrp, InteriorAdjacencyGroup)]
for fagrp in int_grps:
group_boundary_faces.extend(
zip(fagrp.elements, fagrp.element_faces))
zip(fagrp.elements, fagrp.element_faces, strict=True))

elif boundary_tag is FACE_RESTR_ALL:
group_boundary_faces.extend(
Expand All @@ -271,7 +271,8 @@ def make_face_restriction(
group_boundary_faces.extend(
zip(
bdry_grp.elements,
bdry_grp.element_faces))
bdry_grp.element_faces,
strict=True))

# }}}

Expand Down
2 changes: 1 addition & 1 deletion meshmode/discretization/connection/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def vandermonde_matrix(grp):
c_i,
arg_names=("vdm", "coeffs"),
tagged=(FirstAxisIsElementsTag(),))
for grp, c_i in zip(self.to_discr.groups, coefficients)
for grp, c_i in zip(self.to_discr.groups, coefficients, strict=True)
)
)

Expand Down
6 changes: 3 additions & 3 deletions meshmode/discretization/connection/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _build_interpolation_batches_for_group(
assert len(refinement_result) == num_children
# Refined -> interpolates to children
for from_bin, to_bin, child_idx in zip(
from_bins[1:], to_bins[1:], refinement_result):
from_bins[1:], to_bins[1:], refinement_result, strict=True):
from_bin.append(elt_idx)
to_bin.append(child_idx)

Expand All @@ -98,7 +98,7 @@ def _build_interpolation_batches_for_group(
from itertools import chain
for from_bin, to_bin, unit_nodes in zip(
from_bins, to_bins,
chain([fine_unit_nodes], mapped_unit_nodes)):
chain([fine_unit_nodes], mapped_unit_nodes), strict=True):
if not from_bin:
continue
yield InterpolationBatch(
Expand Down Expand Up @@ -149,7 +149,7 @@ def make_refinement_connection(actx, refiner, coarse_discr, group_factory):
groups = []
for group_idx, (coarse_discr_group, fine_discr_group, record) in \
enumerate(zip(coarse_discr.groups, fine_discr.groups,
refiner.group_refinement_records)):
refiner.group_refinement_records, strict=True)):
groups.append(
DiscretizationConnectionElementGroup(
list(_build_interpolation_batches_for_group(
Expand Down
3 changes: 2 additions & 1 deletion meshmode/discretization/connection/same_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def make_same_mesh_connection(actx, to_discr, from_discr):
return IdentityDiscretizationConnection(from_discr)

groups = []
for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)):
for igrp, (fgrp, tgrp) in enumerate(
zip(from_discr.groups, to_discr.groups, strict=True)):
from arraycontext.metadata import NameHint
all_elements = actx.tag(NameHint(f"all_el_ind_grp{igrp}"),
actx.tag_axis(0,
Expand Down
12 changes: 7 additions & 5 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _check_discr_same_connectivity(discr, other):
if not all(
sg.discretization_key() == og.discretization_key()
and sg.nelements == og.nelements
for sg, og in zip(discr.groups, other.groups)):
for sg, og in zip(discr.groups, other.groups, strict=True)):
return False

return True
Expand Down Expand Up @@ -482,7 +482,8 @@ def cells(self):
grp.nunit_dofs,
grp.nelements * grp.nunit_dofs + 1,
grp.nunit_dofs)
for grp_offset, grp in zip(grp_offsets, self.vis_discr.groups)
for grp_offset, grp in zip(
grp_offsets, self.vis_discr.groups, strict=True)
])

return self.vis_discr.mesh.nelements, connectivity, offsets
Expand Down Expand Up @@ -1161,7 +1162,8 @@ def write_xdmf_file(self, file_name, names_and_fields,

grids = []
node_nr_base = 0
for igrp, (vgrp, gnodes) in enumerate(zip(connectivity.groups, nodes)):
for igrp, (vgrp, gnodes) in enumerate(
zip(connectivity.groups, nodes, strict=True)):
grp_name = f"Group_{igrp:05d}"
h5grp = h5grid.create_group(grp_name)

Expand Down Expand Up @@ -1318,7 +1320,7 @@ def make_visualizer(actx, discr, vis_order=None,
vis_discr = discr.copy(actx=actx, group_factory=VisGroupFactory(vis_order))

if all(grp.discretization_key() == vgrp.discretization_key()
for grp, vgrp in zip(discr.groups, vis_discr.groups)):
for grp, vgrp in zip(discr.groups, vis_discr.groups, strict=True)):
from warnings import warn
warn("Visualization discretization is identical to base discretization. "
"To avoid the creation of a separate discretization for "
Expand Down Expand Up @@ -1383,7 +1385,7 @@ def write_nodal_adjacency_vtk_file(file_name, mesh,
(mesh.ambient_dim, mesh.nelements),
dtype=mesh.vertices.dtype)

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
centroids[:, base_element_nr:base_element_nr + grp.nelements] = (
np.sum(mesh.vertices[:, grp.vertex_indices], axis=-1)
/ grp.vertex_indices.shape[-1])
Expand Down
2 changes: 1 addition & 1 deletion meshmode/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def complete_some(self):
raise ValueError(
"duplicate local/remote part pair in inter_rank_bdry_info")

for i_src_rank, recvd in zip(source_ranks, data):
for i_src_rank, recvd in zip(source_ranks, data, strict=True):
(remote_part_id, local_part_id,
remote_bdry_mesh, remote_group_infos) = recvd

Expand Down
2 changes: 1 addition & 1 deletion meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def check_dofarray_against_discr(discr, dof_ary: DOFArray):
"DOFArray has unexpected number of groups "
f"({len(dof_ary)}, expected: {len(discr.groups)})")

for i, (grp, grp_ary) in enumerate(zip(discr.groups, dof_ary)):
for i, (grp, grp_ary) in enumerate(zip(discr.groups, dof_ary, strict=True)):
expected_shape = (grp.nelements, grp.nunit_dofs)
if grp_ary.shape != expected_shape:
raise InconsistentDOFArray(
Expand Down
13 changes: 9 additions & 4 deletions meshmode/interop/firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _get_firedrake_facial_adjacency_groups(fdrake_mesh_topology,
to_keep = np.isin(int_elements, cells_to_use)
cells_to_use_inv = dict(zip(cells_to_use,
np.arange(np.size(cells_to_use),
dtype=IntType)))
dtype=IntType), strict=True))

# Keep the cells that we are using and change old cell index
# to new cell index
Expand Down Expand Up @@ -459,7 +459,8 @@ def _get_firedrake_orientations(fdrake_mesh, unflipped_group, vertices,
orient = np.ones(num_cells)
if normals:
for i, (normal, vert_indices) in enumerate(
zip(np.array(normals), unflipped_group.vertex_indices)):
zip(np.array(normals), unflipped_group.vertex_indices, strict=True)
):
edge = vertices[:, vert_indices[1]] - vertices[:, vert_indices[0]]
if np.cross(normal, edge) < 0:
orient[i] = -1.0
Expand Down Expand Up @@ -621,7 +622,8 @@ def import_firedrake_mesh(fdrake_mesh, cells_to_use=None,
vert_ndx_new2old = np.unique(vertex_indices.flatten())
vert_ndx_old2new = dict(zip(vert_ndx_new2old,
np.arange(np.size(vert_ndx_new2old),
dtype=vertex_indices.dtype)))
dtype=vertex_indices.dtype),
strict=True))
vertex_indices = \
np.vectorize(vert_ndx_old2new.__getitem__)(vertex_indices)

Expand Down Expand Up @@ -872,7 +874,10 @@ def export_mesh_to_firedrake(mesh, group_nr=None, comm=None):
group = mesh.groups[group_nr]
fd2mm_indices = np.unique(group.vertex_indices.flatten())
coords = mesh.vertices[:, fd2mm_indices].T
mm2fd_indices = dict(zip(fd2mm_indices, np.arange(np.size(fd2mm_indices))))
mm2fd_indices = dict(zip(
fd2mm_indices,
np.arange(np.size(fd2mm_indices)),
strict=True))
cells = np.vectorize(mm2fd_indices.__getitem__)(group.vertex_indices)

# Get a dmplex object and then a mesh topology
Expand Down
4 changes: 2 additions & 2 deletions meshmode/mesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ def _compute_nodal_adjacency_from_vertices(mesh: Mesh) -> NodalAdjacency:
_, nvertices = mesh.vertices.shape
vertex_to_element: list[list[int]] = [[] for i in range(nvertices)]

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
if grp.vertex_indices is None:
raise ValueError("unable to compute nodal adjacency without vertices")

Expand All @@ -1571,7 +1571,7 @@ def _compute_nodal_adjacency_from_vertices(mesh: Mesh) -> NodalAdjacency:
vertex_to_element[ivertex].append(base_element_nr + iel_grp)

element_to_element: list[set[int]] = [set() for i in range(mesh.nelements)]
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
assert grp.vertex_indices is not None

for iel_grp in range(grp.nelements):
Expand Down
4 changes: 2 additions & 2 deletions meshmode/mesh/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ def generate_regular_rect_mesh(
"lower topological dimension and map it.)")

axis_coords = [np.linspace(a_i, b_i, npoints_i)
for a_i, b_i, npoints_i in zip(a, b, npoints_per_axis)]
for a_i, b_i, npoints_i in zip(a, b, npoints_per_axis, strict=True)]

return generate_box_mesh(axis_coords, order=order,
periodic=periodic,
Expand Down Expand Up @@ -1654,7 +1654,7 @@ def warp_and_refine_until_resolved(
"(NaN or Inf)")

for base_element_nr, egrp in zip(
warped_mesh.base_element_nrs, warped_mesh.groups):
warped_mesh.base_element_nrs, warped_mesh.groups, strict=True):
if not isinstance(egrp, SimplexElementGroup):
raise TypeError(
f"Unsupported element group type: '{type(egrp).__name__}'")
Expand Down
3 changes: 2 additions & 1 deletion meshmode/mesh/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def get_mesh(self, return_tag_to_elements_map=False):

for el_vertices, el_nodes, el_type, el_markers in zip(
self.element_vertices, self.element_nodes, self.element_types,
self.element_markers):
self.element_markers,
strict=True):
if el_type is not group_el_type:
continue

Expand Down
14 changes: 7 additions & 7 deletions meshmode/mesh/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def unpack_single(ary: Optional[np.ndarray]) -> np.ndarray:

from pymbolic.geometric_algebra import MultiVector

mvs = [MultiVector(vec) for vec in spanning_object_array]
mvs: list[MultiVector] = [MultiVector(vec) for vec in spanning_object_array]

from operator import xor
outer_prod = -reduce(xor, mvs) # pylint: disable=invalid-unary-operand-type
Expand All @@ -660,7 +660,7 @@ def find_volume_mesh_element_orientations(

result: np.ndarray = np.empty(mesh.nelements, dtype=np.float64)

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
result_grp_view = result[base_element_nr:base_element_nr + grp.nelements]

try:
Expand Down Expand Up @@ -833,7 +833,7 @@ def perform_flips(
flip_flags = flip_flags.astype(bool)

new_groups = []
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups, strict=True):
grp_flip_flags = flip_flags[base_element_nr:base_element_nr + grp.nelements]

if grp_flip_flags.any():
Expand Down Expand Up @@ -928,7 +928,7 @@ def merge_disjoint_meshes(

group_vertex_indices = []
group_nodes = []
for mesh, vert_base in zip(meshes, vert_bases):
for mesh, vert_base in zip(meshes, vert_bases, strict=True):
for group in mesh.groups:
assert group.vertex_indices is not None
group_vertex_indices.append(group.vertex_indices + vert_base)
Expand All @@ -945,7 +945,7 @@ def merge_disjoint_meshes(

else:
new_groups = []
for mesh, vert_base in zip(meshes, vert_bases):
for mesh, vert_base in zip(meshes, vert_bases, strict=True):
for group in mesh.groups:
assert group.vertex_indices is not None
new_vertex_indices = group.vertex_indices + vert_base
Expand Down Expand Up @@ -999,7 +999,7 @@ def split_mesh_groups(
subgroup_to_group_map = {}

for igrp, (base_element_nr, grp) in enumerate(
zip(mesh.base_element_nrs, mesh.groups)
zip(mesh.base_element_nrs, mesh.groups, strict=True)
):
assert grp.vertex_indices is not None
grp_flags = element_flags[base_element_nr:base_element_nr + grp.nelements]
Expand Down Expand Up @@ -1615,7 +1615,7 @@ def make_mesh_grid(
meshes = []

for index in product(*(range(n) for n in shape)):
b = sum((i * o for i, o in zip(index, offset)), offset[0])
b = sum((i * o for i, o in zip(index, offset, strict=True)), offset[0])
meshes.append(affine_map(mesh, b=b))

return merge_disjoint_meshes(meshes, skip_tests=skip_tests)
Expand Down
7 changes: 5 additions & 2 deletions meshmode/mesh/refinement/no_adjacency.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def refine(self, refine_flags):
get_group_tessellation_info,
)

for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
for base_element_nr, grp in zip(
mesh.base_element_nrs, mesh.groups, strict=True):
el_tess_info = get_group_tessellation_info(grp)

# {{{ compute counts and index arrays
Expand Down Expand Up @@ -153,7 +154,9 @@ def refine(self, refine_flags):

for imidpoint, (iref_midpoint, (v1, v2)) in enumerate(zip(
el_tess_info.midpoint_indices,
el_tess_info.midpoint_vertex_pairs)):
el_tess_info.midpoint_vertex_pairs,
strict=True
)):

global_v1 = grp.vertex_indices[old_iel, v1]
global_v2 = grp.vertex_indices[old_iel, v2]
Expand Down
Loading