From 7ddd77fc94ff29382fa188edc6181f49163f4d5f Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Tue, 8 Oct 2024 17:25:59 +0100 Subject: [PATCH] Split catalogue and index masks in multi-galaxy mode. --- swiftgalaxy/halo_catalogues.py | 123 +++++++++++++++++-------------- tests/test_halo_catalogues.py | 8 +-- tests/test_iterator.py | 127 ++++++++++++++++++++++++++++++++- 3 files changed, 200 insertions(+), 58 deletions(-) diff --git a/swiftgalaxy/halo_catalogues.py b/swiftgalaxy/halo_catalogues.py index 3be97a6..1395441 100644 --- a/swiftgalaxy/halo_catalogues.py +++ b/swiftgalaxy/halo_catalogues.py @@ -117,7 +117,8 @@ class _HaloCatalogue(ABC): _user_spatial_offsets: Optional[cosmo_array] = None _multi_galaxy: bool = False - _multi_galaxy_mask_index: Optional[int] = None + _multi_galaxy_catalogue_mask: Optional[int] = None + _multi_galaxy_index_mask: Optional[int] = None _multi_count: int _index_attr: Optional[str] @@ -141,6 +142,13 @@ def _mask_multi_galaxy(self, index) -> None: iterating over many :class:`~swiftgalaxy.reader.SWIFTGalaxy`'s) to restrict the catalogue to a single row, for use during one such iteration. + In most cases the mask into the list of indices (targets provided by user) + and the mask into the catalogue (from the relevant halo catalogue) are the + same, but SOAP for example implicitly sorts the target list when defining + the swiftsimio mask for the halo catalogue, so the two indices differ in + that case. The SOAP class therefore implements its own _mask_multi_galaxy + function that overrides this one. + Parameters ---------- index : :obj:`int` @@ -150,7 +158,8 @@ def _mask_multi_galaxy(self, index) -> None: -------- swiftgalaxy.halo_catalogues._HaloCatalogue._unmask_multi_galaxy """ - self._multi_galaxy_mask_index = index + self._multi_galaxy_catalogue_mask = index + self._multi_galaxy_index_mask = index return def _unmask_multi_galaxy(self) -> None: @@ -165,7 +174,8 @@ def _unmask_multi_galaxy(self) -> None: -------- swiftgalaxy.halo_catalogues._HaloCatalogue._mask_multi_galaxy """ - self._multi_galaxy_mask_index = None + self._multi_galaxy_catalogue_mask = None + self._multi_galaxy_index_mask = None return def _get_user_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: @@ -220,7 +230,7 @@ def _get_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: out : :class:`~swiftsimio.masks.SWIFTMask` The spatial mask to select particles in the region of interest. """ - if self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: raise RuntimeError( "Halo catalogue has multiple galaxies and is not currently masked." ) @@ -246,7 +256,7 @@ def _get_extra_mask(self, sg: "SWIFTGalaxy") -> MaskCollection: The extra mask to be applied after spatial masking. """ if self.extra_mask == "bound_only": - if self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: raise RuntimeError( "Halo catalogue has multiple galaxies and is not currently masked." ) @@ -282,8 +292,8 @@ def _mask_index(self) -> Optional[Union[int, list[int]]]: else: index = getattr(self, self._index_attr) return ( - index[self._multi_galaxy_mask_index] - if self._multi_galaxy_mask_index is not None + index[self._multi_galaxy_index_mask] + if self._multi_galaxy_index_mask is not None else index ) @@ -340,8 +350,8 @@ def __getattr__(self, attr: str) -> Any: except AttributeError: return None obj = getattr(self._catalogue, attr) - if self._multi_galaxy_mask_index is not None: - return _MaskHelper(obj, self._multi_galaxy_mask_index) + if self._multi_galaxy_catalogue_mask is not None: + return _MaskHelper(obj, self._multi_galaxy_catalogue_mask) else: return obj @@ -357,7 +367,7 @@ def count(self) -> int: out : :obj:`int` The number of galaxies in the catalogue. """ - if self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: return self._multi_count else: return 1 @@ -681,7 +691,10 @@ def _mask_multi_galaxy(self, index) -> None: -------- swiftgalaxy.halo_catalogues._HaloCatalogue._unmask_multi_galaxy """ - self._multi_galaxy_mask_index = np.argsort(self._soap_index)[index] + self._multi_galaxy_catalogue_mask = np.argsort(np.argsort(self._soap_index))[ + index + ] + self._multi_galaxy_index_mask = index return @property @@ -821,8 +834,8 @@ def centre(self) -> cosmo_array: obj = self._catalogue for attr in self.centre_type.split("."): obj = getattr(obj, attr) - if self._multi_galaxy_mask_index is not None: - return obj[self._multi_galaxy_mask_index] + if self._multi_galaxy_catalogue_mask is not None: + return obj[self._multi_galaxy_catalogue_mask] return obj.squeeze() @property @@ -842,8 +855,8 @@ def velocity_centre(self) -> cosmo_array: obj = self._catalogue for attr in self.velocity_centre_type.split("."): obj = getattr(obj, attr) - if self._multi_galaxy_mask_index is not None: - return obj[self._multi_galaxy_mask_index] + if self._multi_galaxy_catalogue_mask is not None: + return obj[self._multi_galaxy_catalogue_mask] return obj.squeeze() def __repr__(self) -> str: @@ -1081,11 +1094,11 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask: from velociraptor.swift.swift import generate_spatial_mask # super()._get_spatial_mask guards getting here in multi-galaxy mode - # without a self._multi_galaxy_mask_index + # without a self._multi_galaxy_catalogue_mask return generate_spatial_mask( ( - self._particles[self._multi_galaxy_mask_index] - if self._multi_galaxy_mask_index is not None + self._particles[self._multi_galaxy_catalogue_mask] + if self._multi_galaxy_catalogue_mask is not None else self._particles ), snapshot_filename, @@ -1119,8 +1132,8 @@ def _generate_bound_only_mask(self, sg: "SWIFTGalaxy") -> MaskCollection: **generate_bound_mask( sg, ( - self._particles[self._multi_galaxy_mask_index] - if self._multi_galaxy_mask_index is not None + self._particles[self._multi_galaxy_catalogue_mask] + if self._multi_galaxy_catalogue_mask is not None else self._particles ), )._asdict() @@ -1164,7 +1177,7 @@ def _region_centre(self) -> cosmo_array: if not self._particles[0].groups_instance.catalogue.units.comoving else 1.0 ) - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_catalogue_mask is None: return cosmo_array( [ [ @@ -1181,11 +1194,11 @@ def _region_centre(self) -> cosmo_array: else: return cosmo_array( [ - self._particles[self._multi_galaxy_mask_index].x.to_value(u.Mpc) + self._particles[self._multi_galaxy_catalogue_mask].x.to_value(u.Mpc) / length_factor, - self._particles[self._multi_galaxy_mask_index].y.to_value(u.Mpc) + self._particles[self._multi_galaxy_catalogue_mask].y.to_value(u.Mpc) / length_factor, - self._particles[self._multi_galaxy_mask_index].z.to_value(u.Mpc) + self._particles[self._multi_galaxy_catalogue_mask].z.to_value(u.Mpc) / length_factor, ], u.Mpc, @@ -1213,7 +1226,7 @@ def _region_aperture(self) -> cosmo_array: if not self._particles[0].groups_instance.catalogue.units.comoving else 1.0 ) - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_catalogue_mask is None: return cosmo_array( [ particles.r_size.to_value(u.Mpc) / length_factor @@ -1225,7 +1238,9 @@ def _region_aperture(self) -> cosmo_array: ).squeeze() else: return cosmo_array( - self._particles[self._multi_galaxy_mask_index].r_size.to_value(u.Mpc) + self._particles[self._multi_galaxy_catalogue_mask].r_size.to_value( + u.Mpc + ) / length_factor, u.Mpc, comoving=True, @@ -1304,10 +1319,10 @@ def centre(self) -> cosmo_array: comoving=False, # velociraptor gives physical centres! cosmo_factor=cosmo_factor(a**1, self.scale_factor), ).to_comoving() - if self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: return centre - elif self._multi_galaxy and self._multi_galaxy_mask_index is not None: - return centre[self._multi_galaxy_mask_index] + elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None: + return centre[self._multi_galaxy_catalogue_mask] else: return centre.squeeze() @@ -1356,10 +1371,10 @@ def velocity_centre(self) -> cosmo_array: comoving=False, cosmo_factor=cosmo_factor(a**0, self.scale_factor), ).to_comoving() - if self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: return vcentre - elif self._multi_galaxy and self._multi_galaxy_mask_index is not None: - return vcentre[self._multi_galaxy_mask_index] + elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None: + return vcentre[self._multi_galaxy_catalogue_mask] else: return vcentre.squeeze() @@ -1734,7 +1749,7 @@ def _region_centre(self) -> cosmo_array: """ cats = [self._catalogue] if not self._multi_galaxy else self._catalogue assert isinstance(cats, List) # placate mypy - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_catalogue_mask is None: return cosmo_array( [cat.pos.to(u.kpc) for cat in cats], # maybe comoving, ensure physical comoving=False, @@ -1742,7 +1757,7 @@ def _region_centre(self) -> cosmo_array: ).to_comoving() else: return cosmo_array( - cats[self._multi_galaxy_mask_index].pos.to( + cats[self._multi_galaxy_catalogue_mask].pos.to( u.kpc ), # maybe comoving, ensure physical comoving=False, @@ -1770,7 +1785,7 @@ def _region_aperture(self) -> cosmo_array: assert isinstance(cats, List) # placate mypy if "total_rmax" in cats[0].radii.keys(): # spatial extent information is present - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_catalogue_mask is None: return cosmo_array( [ cat.radii["total_rmax"].to(u.kpc) for cat in cats @@ -1782,7 +1797,7 @@ def _region_aperture(self) -> cosmo_array: ).to_comoving() else: return cosmo_array( - cats[self._multi_galaxy_mask_index] + cats[self._multi_galaxy_catalogue_mask] .radii["total_rmax"] .to(u.kpc), # maybe comoving, ensure physical comoving=False, @@ -1827,9 +1842,9 @@ def _mask_catalogue(self) -> Union["CaesarHalo", "CaesarGalaxy"]: out : :class:`~loader.Halo` or :class:`~loader.Galaxy` The item from the caesar list selected by the mask. """ - if self._multi_galaxy and self._multi_galaxy_mask_index is not None: - cat = self._catalogue[self._multi_galaxy_mask_index] - elif self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None: + cat = self._catalogue[self._multi_galaxy_catalogue_mask] + elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: raise RuntimeError("Tried to mask catalogue without mask index!") else: cat = self._catalogue @@ -1849,9 +1864,11 @@ def centre(self) -> cosmo_array: The centre(s) of the object(s) of interest. """ centre_attr = {"": "pos", "minpot": "minpotpos"}[self.centre_type] - if self._multi_galaxy_mask_index is not None: + if self._multi_galaxy_catalogue_mask is not None: return cosmo_array( - getattr(self._catalogue[self._multi_galaxy_mask_index], centre_attr).to( + getattr( + self._catalogue[self._multi_galaxy_catalogue_mask], centre_attr + ).to( u.kpc ), # maybe comoving, ensure physical comoving=False, @@ -1881,10 +1898,10 @@ def velocity_centre(self) -> cosmo_array: The velocity centre provided by the halo catalogue. """ vcentre_attr = {"": "vel", "minpot": "minpotvel"}[self.centre_type] - if self._multi_galaxy_mask_index is not None: + if self._multi_galaxy_catalogue_mask is not None: return cosmo_array( getattr( - self._catalogue[self._multi_galaxy_mask_index], vcentre_attr + self._catalogue[self._multi_galaxy_catalogue_mask], vcentre_attr ).to(u.km / u.s), comoving=False, cosmo_factor=cosmo_factor(a**0, self._caesar.simulation.scale_factor), @@ -1926,9 +1943,9 @@ def __getattr__(self, attr: str) -> Any: return object.__getattribute__(self, "_catalogue") except AttributeError: return None - if self._multi_galaxy_mask_index is not None: - return getattr(self._catalogue[self._multi_galaxy_mask_index], attr) - elif self._multi_galaxy and self._multi_galaxy_mask_index is None: + if self._multi_galaxy_catalogue_mask is not None: + return getattr(self._catalogue[self._multi_galaxy_catalogue_mask], attr) + elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is None: return [getattr(cat, attr) for cat in self._catalogue] else: return getattr(self._catalogue, attr) @@ -2150,10 +2167,10 @@ def _region_centre(self) -> cosmo_array: out : :class:`~swiftsimio.objects.cosmo_array` The coordinates of the centres of the spatial mask regions. """ - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_index_mask is None: return self._centre else: - return self._centre[self._multi_galaxy_mask_index] + return self._centre[self._multi_galaxy_index_mask] @property def _region_aperture(self) -> cosmo_array: @@ -2171,7 +2188,7 @@ def _region_aperture(self) -> cosmo_array: regions. """ if self._user_spatial_offsets is not None: - if self._multi_galaxy_mask_index is None: + if self._multi_galaxy_index_mask is None: return np.repeat( np.max(np.abs(self._user_spatial_offsets)), len(self._centre) ) @@ -2213,8 +2230,8 @@ def centre(self) -> cosmo_array: out : :class:`~swiftsimio.objects.cosmo_array` The centre(s) of the object(s) of interest. """ - if self._multi_galaxy_mask_index is not None: - return self._centre[self._multi_galaxy_mask_index] + if self._multi_galaxy_index_mask is not None: + return self._centre[self._multi_galaxy_index_mask] return self._centre @property @@ -2230,6 +2247,6 @@ def velocity_centre(self) -> cosmo_array: out : :class:`~swiftsimio.objects.cosmo_array` The velocity coordinate origin. """ - if self._multi_galaxy_mask_index is not None: - return self._velocity_centre[self._multi_galaxy_mask_index] + if self._multi_galaxy_index_mask is not None: + return self._velocity_centre[self._multi_galaxy_index_mask] return self._velocity_centre diff --git a/tests/test_halo_catalogues.py b/tests/test_halo_catalogues.py index 157dd2d..bbda788 100644 --- a/tests/test_halo_catalogues.py +++ b/tests/test_halo_catalogues.py @@ -232,7 +232,7 @@ def test_multi_flags(self, hf_multi): Check that the multi-target nature of the cataloue is recognized. """ assert hf_multi._multi_galaxy - assert hf_multi._multi_galaxy_mask_index is None + assert hf_multi._multi_galaxy_catalogue_mask is None assert hf_multi._multi_count == 2 assert hf_multi.count == 2 @@ -240,7 +240,7 @@ def test_mask_multi_galaxy(self, hf_multi): """ Check that we can mask the catalogue to focus on one object, and unmask. """ - assert hf_multi._multi_galaxy_mask_index is None + assert hf_multi._multi_galaxy_catalogue_mask is None assert hf_multi.count > 1 assert hf_multi._region_centre.shape == (hf_multi.count, 3) assert hf_multi._region_aperture.shape == (hf_multi.count,) @@ -248,14 +248,14 @@ def test_mask_multi_galaxy(self, hf_multi): assert hf_multi.velocity_centre.shape == (hf_multi.count, 3) mask_index = 0 hf_multi._mask_multi_galaxy(mask_index) - assert hf_multi._multi_galaxy_mask_index == mask_index + assert hf_multi._multi_galaxy_catalogue_mask == mask_index assert hf_multi.count == 1 assert hf_multi._region_centre.shape == (3,) assert hf_multi._region_aperture.shape == tuple() assert hf_multi.centre.shape == (3,) assert hf_multi.velocity_centre.shape == (3,) hf_multi._unmask_multi_galaxy() - assert hf_multi._multi_galaxy_mask_index is None + assert hf_multi._multi_galaxy_catalogue_mask is None assert hf_multi.count > 1 assert hf_multi._region_centre.shape == (hf_multi.count, 3) assert hf_multi._region_aperture.shape == (hf_multi.count,) diff --git a/tests/test_iterator.py b/tests/test_iterator.py index 7fde38e..cc78e17 100644 --- a/tests/test_iterator.py +++ b/tests/test_iterator.py @@ -18,6 +18,10 @@ class TestSWIFTGalaxies: def test_eval_sparse_optimized_solution(self, toysnap): + """ + Check that the sparse solution is chosen when optimal and matches expectations + for a case that we can work out by hand. + """ # place a single target in the centre of each cell # this should make sparse iteration optimal # at a cost of 2 cell reads @@ -86,6 +90,10 @@ def test_eval_sparse_optimized_solution(self, toysnap): ) def test_eval_dense_optimized_solution(self, toysnap): + """ + Check that the dense solution is chosen when optimal and matches expectations + for a case that we can work out by hand. + """ # Place a single target in the centre of each cell # and ones straddling many faces, vertices and corners # this should make dense iteration optimal @@ -183,6 +191,10 @@ def test_eval_dense_optimized_solution(self, toysnap): ) def test_iteration_order(self, sgs): + """ + Check that the iteration order agrees with that computed in the two iteration + optimizations. + """ # force dense solution sgs._solution = sgs._dense_optimized_solution assert np.allclose( @@ -199,6 +211,10 @@ def test_iteration_order(self, sgs): ) def test_iterate(self, sgs): + """ + Check that we iterate over the right number of SWIFTGalaxy objects and that + they behave like a SWIFTGalaxy created on its own for each target. + """ count = 0 for sg_from_sgs in sgs: sg = SWIFTGalaxy( @@ -217,6 +233,9 @@ def test_iterate(self, sgs): @pytest.mark.parametrize("extra_mask", ["bound_only", None]) def test_preload(self, toysnap_withfof, hf_multi, extra_mask): + """ + Make sure that data that we ask to have pre-loaded is actually pre-loaded. + """ hf_multi.extra_mask = extra_mask sgs = SWIFTGalaxies( ( @@ -225,7 +244,7 @@ def test_preload(self, toysnap_withfof, hf_multi, extra_mask): else toysnap_filename ), hf_multi, - preload={ # just to keep warnings quiet + preload={ "gas.particle_ids", "dark_matter.particle_ids", "stars.particle_ids", @@ -243,7 +262,45 @@ def test_preload(self, toysnap_withfof, hf_multi, extra_mask): is not None ) + def test_warn_on_no_preload(self, toysnap): + """ + Check that we warn users if they don't specify anything to pre-load since this + probably indicates that they're using the SWIFTGalaxies class inefficiently. + """ + with pytest.warns(RuntimeWarning, match="No data specified to preload"): + SWIFTGalaxies( + toysnap_filename, + ToyHF(index=[0, 1]), + ) + + def test_warn_on_read_not_preloaded(self, sgs): + """ + Check that we warn users when data is loaded while iterating over a SWIFTGalaxies + since this probably indicates that they're using the class inefficiently. + """ + assert "gas.coordinates" not in sgs.preload + for sg in sgs: + with pytest.warns(RuntimeWarning, match="should it be preloaded"): + sg.gas.coordinates + + def test_exception_on_repeated_targets(self, toysnap): + """ + Due to especially swiftsimio's masking behaviour having duplicate targets in the + list for a SWIFTGalaxies causes all kinds of problems, so make sure we raise an + exception if a user tries to do this. + """ + with pytest.raises(ValueError, match="must not contain duplicates"): + SWIFTGalaxies( + toysnap_filename, + ToyHF(index=[0, 0]), + ) + def test_map(self, toysnap_withfof, hf_multi): + """ + Check that the map method returns results in the same order as the input target + list. We're careful in this test to make sure that the iteration order is + different from the input list order. + """ sgs = SWIFTGalaxies( ( toysoap_virtual_snapshot_filename @@ -287,3 +344,71 @@ def f(sg): assert sgs.map(f) == getattr( sgs.halo_catalogue, sgs.halo_catalogue._index_attr[1:] ) + + def test_arbitrary_index_ordering(self, toysnap_withfof, hf_multi): + """ + Check that SWIFTGalaxies gives consistent results for any order of target objects. + + Especially important for velociraptor where some logic had to be added to avoid + hdf5 complaining about an unsorted list of indices to read from file. + """ + + def f(sg): + if isinstance(hf_multi, Standalone): + return int( + np.argwhere( + np.all( + sg.halo_catalogue.centre == sg.halo_catalogue._centre, + axis=1, + ) + ).squeeze() + ) + # _index_attr has leading underscore, access through property with [1:] + return getattr(sg.halo_catalogue, sg.halo_catalogue._index_attr[1:]) + + sgs_forwards = SWIFTGalaxies( + ( + toysoap_virtual_snapshot_filename + if isinstance(hf_multi, SOAP) + else toysnap_filename + ), + hf_multi, + preload={ # just to keep warnings quiet + "gas.particle_ids", + "dark_matter.particle_ids", + "stars.particle_ids", + "black_holes.particle_ids", + }, + ) + map_forwards = sgs_forwards.map(f) + if isinstance(hf_multi, Standalone): + hf_multi._centre = hf_multi._centre[::-1] + hf_multi._velocity_centre = hf_multi._velocity_centre[::-1] + else: + setattr( + hf_multi, + hf_multi._index_attr, + getattr(hf_multi, hf_multi._index_attr)[::-1], + ) + sgs_backwards = SWIFTGalaxies( + ( + toysoap_virtual_snapshot_filename + if isinstance(hf_multi, SOAP) + else toysnap_filename + ), + hf_multi, + preload={ # just to keep warnings quiet + "gas.particle_ids", + "dark_matter.particle_ids", + "stars.particle_ids", + "black_holes.particle_ids", + }, + ) + map_backwards = sgs_backwards.map(f) + if isinstance(hf_multi, Standalone): + # because the reversed catalogue compares to the reversed catalogue in f, + # we get the reverse-of-the-reverse in map_backwards, which should match + # map_forwards + assert map_forwards == map_backwards + return + assert map_forwards == map_backwards[::-1]