Skip to content

Commit

Permalink
add better corematrix reader, set host_row and top_host_row to itself…
Browse files Browse the repository at this point in the history
… if top host
  • Loading branch information
michaelbuehlmann committed Sep 7, 2023
1 parent ca0e175 commit 0831bf4
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions haccytrees/coretrees/coretree_reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Mapping
from typing import Mapping, Union, List

import h5py
import numba
import numpy as np

from ..simulations import Simulation

# These fields will always be loaded from the HDF5 files
_essential_fields = ["core_tag", "host_core", "snapnum", "central", "merged"]


@numba.jit(nopython=True)
def _count_coreforest_rows(core_tag):
Expand Down Expand Up @@ -37,12 +41,12 @@ def _get_top_host_row(host_row, top_host_row):
if host_row[i] < 0:
continue
_top_host_row = host_row[i]
while host_row[_top_host_row] >= 0:
while host_row[_top_host_row] >= 0 and host_row[_top_host_row] != _top_host_row:
_top_host_row = host_row[_top_host_row]
top_host_row[i] = _top_host_row


def coreforest_matrix(forest: Mapping[str, np.ndarray], simulation: Simulation):
def coreforest2matrix(forest: Mapping[str, np.ndarray], simulation: Simulation):
# first pass: count rows
nrows = _count_coreforest_rows(forest["core_tag"])
ncols = len(simulation.cosmotools_steps)
Expand Down Expand Up @@ -95,3 +99,30 @@ def coreforest_matrix(forest: Mapping[str, np.ndarray], simulation: Simulation):
forest_matrices["core_state"] = _state

return forest_matrices


def corematrix_reader(
filename: str, simulation: Union[Simulation, str], include_fields: List[str] = None
):
if isinstance(simulation, str):
if simulation[:-4] == ".cfg":
simulation = Simulation.parse_config(simulation)
else:
simulation = Simulation.simulations[simulation]

with h5py.File(filename) as forest_file:
if include_fields is None:
include_fields = list(forest_file["data"].keys())
else:
for k in _essential_fields:
if k not in include_fields:
include_fields.append(k)
forest_data = {k: forest_file["data"][k][:] for k in include_fields}

# set host_core to itself for centrals
forest_data["host_core"][forest_data["central"] == 1] = forest_data["core_tag"][
forest_data["central"] == 1
]

forest_matrices = coreforest2matrix(forest_data, simulation)
return forest_matrices

0 comments on commit 0831bf4

Please sign in to comment.