Skip to content

Commit

Permalink
Split catalogue and index masks in multi-galaxy mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Oct 8, 2024
1 parent e148dc2 commit 7ddd77f
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 58 deletions.
123 changes: 70 additions & 53 deletions swiftgalaxy/halo_catalogues.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class _HaloCatalogue(ABC):

_user_spatial_offsets: Optional[cosmo_array] = None
_multi_galaxy: bool = False
_multi_galaxy_mask_index: Optional[int] = None
_multi_galaxy_catalogue_mask: Optional[int] = None
_multi_galaxy_index_mask: Optional[int] = None
_multi_count: int
_index_attr: Optional[str]

Expand All @@ -141,6 +142,13 @@ def _mask_multi_galaxy(self, index) -> None:
iterating over many :class:`~swiftgalaxy.reader.SWIFTGalaxy`'s) to
restrict the catalogue to a single row, for use during one such iteration.
In most cases the mask into the list of indices (targets provided by user)
and the mask into the catalogue (from the relevant halo catalogue) are the
same, but SOAP for example implicitly sorts the target list when defining
the swiftsimio mask for the halo catalogue, so the two indices differ in
that case. The SOAP class therefore implements its own _mask_multi_galaxy
function that overrides this one.
Parameters
----------
index : :obj:`int`
Expand All @@ -150,7 +158,8 @@ def _mask_multi_galaxy(self, index) -> None:
--------
swiftgalaxy.halo_catalogues._HaloCatalogue._unmask_multi_galaxy
"""
self._multi_galaxy_mask_index = index
self._multi_galaxy_catalogue_mask = index
self._multi_galaxy_index_mask = index
return

def _unmask_multi_galaxy(self) -> None:
Expand All @@ -165,7 +174,8 @@ def _unmask_multi_galaxy(self) -> None:
--------
swiftgalaxy.halo_catalogues._HaloCatalogue._mask_multi_galaxy
"""
self._multi_galaxy_mask_index = None
self._multi_galaxy_catalogue_mask = None
self._multi_galaxy_index_mask = None
return

def _get_user_spatial_mask(self, snapshot_filename: str) -> SWIFTMask:
Expand Down Expand Up @@ -220,7 +230,7 @@ def _get_spatial_mask(self, snapshot_filename: str) -> SWIFTMask:
out : :class:`~swiftsimio.masks.SWIFTMask`
The spatial mask to select particles in the region of interest.
"""
if self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
raise RuntimeError(
"Halo catalogue has multiple galaxies and is not currently masked."
)
Expand All @@ -246,7 +256,7 @@ def _get_extra_mask(self, sg: "SWIFTGalaxy") -> MaskCollection:
The extra mask to be applied after spatial masking.
"""
if self.extra_mask == "bound_only":
if self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
raise RuntimeError(
"Halo catalogue has multiple galaxies and is not currently masked."
)
Expand Down Expand Up @@ -282,8 +292,8 @@ def _mask_index(self) -> Optional[Union[int, list[int]]]:
else:
index = getattr(self, self._index_attr)
return (
index[self._multi_galaxy_mask_index]
if self._multi_galaxy_mask_index is not None
index[self._multi_galaxy_index_mask]
if self._multi_galaxy_index_mask is not None
else index
)

Expand Down Expand Up @@ -340,8 +350,8 @@ def __getattr__(self, attr: str) -> Any:
except AttributeError:
return None
obj = getattr(self._catalogue, attr)
if self._multi_galaxy_mask_index is not None:
return _MaskHelper(obj, self._multi_galaxy_mask_index)
if self._multi_galaxy_catalogue_mask is not None:
return _MaskHelper(obj, self._multi_galaxy_catalogue_mask)
else:
return obj

Expand All @@ -357,7 +367,7 @@ def count(self) -> int:
out : :obj:`int`
The number of galaxies in the catalogue.
"""
if self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
return self._multi_count
else:
return 1
Expand Down Expand Up @@ -681,7 +691,10 @@ def _mask_multi_galaxy(self, index) -> None:
--------
swiftgalaxy.halo_catalogues._HaloCatalogue._unmask_multi_galaxy
"""
self._multi_galaxy_mask_index = np.argsort(self._soap_index)[index]
self._multi_galaxy_catalogue_mask = np.argsort(np.argsort(self._soap_index))[
index
]
self._multi_galaxy_index_mask = index
return

@property
Expand Down Expand Up @@ -821,8 +834,8 @@ def centre(self) -> cosmo_array:
obj = self._catalogue
for attr in self.centre_type.split("."):
obj = getattr(obj, attr)
if self._multi_galaxy_mask_index is not None:
return obj[self._multi_galaxy_mask_index]
if self._multi_galaxy_catalogue_mask is not None:
return obj[self._multi_galaxy_catalogue_mask]
return obj.squeeze()

@property
Expand All @@ -842,8 +855,8 @@ def velocity_centre(self) -> cosmo_array:
obj = self._catalogue
for attr in self.velocity_centre_type.split("."):
obj = getattr(obj, attr)
if self._multi_galaxy_mask_index is not None:
return obj[self._multi_galaxy_mask_index]
if self._multi_galaxy_catalogue_mask is not None:
return obj[self._multi_galaxy_catalogue_mask]
return obj.squeeze()

def __repr__(self) -> str:
Expand Down Expand Up @@ -1081,11 +1094,11 @@ def _generate_spatial_mask(self, snapshot_filename: str) -> SWIFTMask:
from velociraptor.swift.swift import generate_spatial_mask

# super()._get_spatial_mask guards getting here in multi-galaxy mode
# without a self._multi_galaxy_mask_index
# without a self._multi_galaxy_catalogue_mask
return generate_spatial_mask(
(
self._particles[self._multi_galaxy_mask_index]
if self._multi_galaxy_mask_index is not None
self._particles[self._multi_galaxy_catalogue_mask]
if self._multi_galaxy_catalogue_mask is not None
else self._particles
),
snapshot_filename,
Expand Down Expand Up @@ -1119,8 +1132,8 @@ def _generate_bound_only_mask(self, sg: "SWIFTGalaxy") -> MaskCollection:
**generate_bound_mask(
sg,
(
self._particles[self._multi_galaxy_mask_index]
if self._multi_galaxy_mask_index is not None
self._particles[self._multi_galaxy_catalogue_mask]
if self._multi_galaxy_catalogue_mask is not None
else self._particles
),
)._asdict()
Expand Down Expand Up @@ -1164,7 +1177,7 @@ def _region_centre(self) -> cosmo_array:
if not self._particles[0].groups_instance.catalogue.units.comoving
else 1.0
)
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_catalogue_mask is None:
return cosmo_array(
[
[
Expand All @@ -1181,11 +1194,11 @@ def _region_centre(self) -> cosmo_array:
else:
return cosmo_array(
[
self._particles[self._multi_galaxy_mask_index].x.to_value(u.Mpc)
self._particles[self._multi_galaxy_catalogue_mask].x.to_value(u.Mpc)
/ length_factor,
self._particles[self._multi_galaxy_mask_index].y.to_value(u.Mpc)
self._particles[self._multi_galaxy_catalogue_mask].y.to_value(u.Mpc)
/ length_factor,
self._particles[self._multi_galaxy_mask_index].z.to_value(u.Mpc)
self._particles[self._multi_galaxy_catalogue_mask].z.to_value(u.Mpc)
/ length_factor,
],
u.Mpc,
Expand Down Expand Up @@ -1213,7 +1226,7 @@ def _region_aperture(self) -> cosmo_array:
if not self._particles[0].groups_instance.catalogue.units.comoving
else 1.0
)
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_catalogue_mask is None:
return cosmo_array(
[
particles.r_size.to_value(u.Mpc) / length_factor
Expand All @@ -1225,7 +1238,9 @@ def _region_aperture(self) -> cosmo_array:
).squeeze()
else:
return cosmo_array(
self._particles[self._multi_galaxy_mask_index].r_size.to_value(u.Mpc)
self._particles[self._multi_galaxy_catalogue_mask].r_size.to_value(
u.Mpc
)
/ length_factor,
u.Mpc,
comoving=True,
Expand Down Expand Up @@ -1304,10 +1319,10 @@ def centre(self) -> cosmo_array:
comoving=False, # velociraptor gives physical centres!
cosmo_factor=cosmo_factor(a**1, self.scale_factor),
).to_comoving()
if self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
return centre
elif self._multi_galaxy and self._multi_galaxy_mask_index is not None:
return centre[self._multi_galaxy_mask_index]
elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None:
return centre[self._multi_galaxy_catalogue_mask]
else:
return centre.squeeze()

Expand Down Expand Up @@ -1356,10 +1371,10 @@ def velocity_centre(self) -> cosmo_array:
comoving=False,
cosmo_factor=cosmo_factor(a**0, self.scale_factor),
).to_comoving()
if self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
return vcentre
elif self._multi_galaxy and self._multi_galaxy_mask_index is not None:
return vcentre[self._multi_galaxy_mask_index]
elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None:
return vcentre[self._multi_galaxy_catalogue_mask]
else:
return vcentre.squeeze()

Expand Down Expand Up @@ -1734,15 +1749,15 @@ def _region_centre(self) -> cosmo_array:
"""
cats = [self._catalogue] if not self._multi_galaxy else self._catalogue
assert isinstance(cats, List) # placate mypy
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_catalogue_mask is None:
return cosmo_array(
[cat.pos.to(u.kpc) for cat in cats], # maybe comoving, ensure physical
comoving=False,
cosmo_factor=cosmo_factor(a**1, self._caesar.simulation.scale_factor),
).to_comoving()
else:
return cosmo_array(
cats[self._multi_galaxy_mask_index].pos.to(
cats[self._multi_galaxy_catalogue_mask].pos.to(
u.kpc
), # maybe comoving, ensure physical
comoving=False,
Expand Down Expand Up @@ -1770,7 +1785,7 @@ def _region_aperture(self) -> cosmo_array:
assert isinstance(cats, List) # placate mypy
if "total_rmax" in cats[0].radii.keys():
# spatial extent information is present
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_catalogue_mask is None:
return cosmo_array(
[
cat.radii["total_rmax"].to(u.kpc) for cat in cats
Expand All @@ -1782,7 +1797,7 @@ def _region_aperture(self) -> cosmo_array:
).to_comoving()
else:
return cosmo_array(
cats[self._multi_galaxy_mask_index]
cats[self._multi_galaxy_catalogue_mask]
.radii["total_rmax"]
.to(u.kpc), # maybe comoving, ensure physical
comoving=False,
Expand Down Expand Up @@ -1827,9 +1842,9 @@ def _mask_catalogue(self) -> Union["CaesarHalo", "CaesarGalaxy"]:
out : :class:`~loader.Halo` or :class:`~loader.Galaxy`
The item from the caesar list selected by the mask.
"""
if self._multi_galaxy and self._multi_galaxy_mask_index is not None:
cat = self._catalogue[self._multi_galaxy_mask_index]
elif self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy and self._multi_galaxy_catalogue_mask is not None:
cat = self._catalogue[self._multi_galaxy_catalogue_mask]
elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
raise RuntimeError("Tried to mask catalogue without mask index!")
else:
cat = self._catalogue
Expand All @@ -1849,9 +1864,11 @@ def centre(self) -> cosmo_array:
The centre(s) of the object(s) of interest.
"""
centre_attr = {"": "pos", "minpot": "minpotpos"}[self.centre_type]
if self._multi_galaxy_mask_index is not None:
if self._multi_galaxy_catalogue_mask is not None:
return cosmo_array(
getattr(self._catalogue[self._multi_galaxy_mask_index], centre_attr).to(
getattr(
self._catalogue[self._multi_galaxy_catalogue_mask], centre_attr
).to(
u.kpc
), # maybe comoving, ensure physical
comoving=False,
Expand Down Expand Up @@ -1881,10 +1898,10 @@ def velocity_centre(self) -> cosmo_array:
The velocity centre provided by the halo catalogue.
"""
vcentre_attr = {"": "vel", "minpot": "minpotvel"}[self.centre_type]
if self._multi_galaxy_mask_index is not None:
if self._multi_galaxy_catalogue_mask is not None:
return cosmo_array(
getattr(
self._catalogue[self._multi_galaxy_mask_index], vcentre_attr
self._catalogue[self._multi_galaxy_catalogue_mask], vcentre_attr
).to(u.km / u.s),
comoving=False,
cosmo_factor=cosmo_factor(a**0, self._caesar.simulation.scale_factor),
Expand Down Expand Up @@ -1926,9 +1943,9 @@ def __getattr__(self, attr: str) -> Any:
return object.__getattribute__(self, "_catalogue")
except AttributeError:
return None
if self._multi_galaxy_mask_index is not None:
return getattr(self._catalogue[self._multi_galaxy_mask_index], attr)
elif self._multi_galaxy and self._multi_galaxy_mask_index is None:
if self._multi_galaxy_catalogue_mask is not None:
return getattr(self._catalogue[self._multi_galaxy_catalogue_mask], attr)
elif self._multi_galaxy and self._multi_galaxy_catalogue_mask is None:
return [getattr(cat, attr) for cat in self._catalogue]
else:
return getattr(self._catalogue, attr)
Expand Down Expand Up @@ -2150,10 +2167,10 @@ def _region_centre(self) -> cosmo_array:
out : :class:`~swiftsimio.objects.cosmo_array`
The coordinates of the centres of the spatial mask regions.
"""
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_index_mask is None:
return self._centre
else:
return self._centre[self._multi_galaxy_mask_index]
return self._centre[self._multi_galaxy_index_mask]

@property
def _region_aperture(self) -> cosmo_array:
Expand All @@ -2171,7 +2188,7 @@ def _region_aperture(self) -> cosmo_array:
regions.
"""
if self._user_spatial_offsets is not None:
if self._multi_galaxy_mask_index is None:
if self._multi_galaxy_index_mask is None:
return np.repeat(
np.max(np.abs(self._user_spatial_offsets)), len(self._centre)
)
Expand Down Expand Up @@ -2213,8 +2230,8 @@ def centre(self) -> cosmo_array:
out : :class:`~swiftsimio.objects.cosmo_array`
The centre(s) of the object(s) of interest.
"""
if self._multi_galaxy_mask_index is not None:
return self._centre[self._multi_galaxy_mask_index]
if self._multi_galaxy_index_mask is not None:
return self._centre[self._multi_galaxy_index_mask]
return self._centre

@property
Expand All @@ -2230,6 +2247,6 @@ def velocity_centre(self) -> cosmo_array:
out : :class:`~swiftsimio.objects.cosmo_array`
The velocity coordinate origin.
"""
if self._multi_galaxy_mask_index is not None:
return self._velocity_centre[self._multi_galaxy_mask_index]
if self._multi_galaxy_index_mask is not None:
return self._velocity_centre[self._multi_galaxy_index_mask]
return self._velocity_centre
8 changes: 4 additions & 4 deletions tests/test_halo_catalogues.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,30 +232,30 @@ def test_multi_flags(self, hf_multi):
Check that the multi-target nature of the cataloue is recognized.
"""
assert hf_multi._multi_galaxy
assert hf_multi._multi_galaxy_mask_index is None
assert hf_multi._multi_galaxy_catalogue_mask is None
assert hf_multi._multi_count == 2
assert hf_multi.count == 2

def test_mask_multi_galaxy(self, hf_multi):
"""
Check that we can mask the catalogue to focus on one object, and unmask.
"""
assert hf_multi._multi_galaxy_mask_index is None
assert hf_multi._multi_galaxy_catalogue_mask is None
assert hf_multi.count > 1
assert hf_multi._region_centre.shape == (hf_multi.count, 3)
assert hf_multi._region_aperture.shape == (hf_multi.count,)
assert hf_multi.centre.shape == (hf_multi.count, 3)
assert hf_multi.velocity_centre.shape == (hf_multi.count, 3)
mask_index = 0
hf_multi._mask_multi_galaxy(mask_index)
assert hf_multi._multi_galaxy_mask_index == mask_index
assert hf_multi._multi_galaxy_catalogue_mask == mask_index
assert hf_multi.count == 1
assert hf_multi._region_centre.shape == (3,)
assert hf_multi._region_aperture.shape == tuple()
assert hf_multi.centre.shape == (3,)
assert hf_multi.velocity_centre.shape == (3,)
hf_multi._unmask_multi_galaxy()
assert hf_multi._multi_galaxy_mask_index is None
assert hf_multi._multi_galaxy_catalogue_mask is None
assert hf_multi.count > 1
assert hf_multi._region_centre.shape == (hf_multi.count, 3)
assert hf_multi._region_aperture.shape == (hf_multi.count,)
Expand Down
Loading

0 comments on commit 7ddd77f

Please sign in to comment.