From e21c7dd4d8ec9cd799afaf30937e36d8767653f2 Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Thu, 16 Nov 2023 21:01:31 +0100 Subject: [PATCH 1/2] mv tof sinogram to lm generation to separate function --- python/00_simulate_lm_data.py | 50 ++------- python/parallelproj_utils.py | 185 ++++++++++++++++++---------------- python/utils.py | 79 +++++++++++++++ 3 files changed, 187 insertions(+), 127 deletions(-) create mode 100644 python/utils.py diff --git a/python/00_simulate_lm_data.py b/python/00_simulate_lm_data.py index 7b0ca1a..b2c71af 100644 --- a/python/00_simulate_lm_data.py +++ b/python/00_simulate_lm_data.py @@ -6,6 +6,7 @@ import prd import parallelproj import parallelproj_utils +import utils import array_api_compat.numpy as np import matplotlib.pyplot as plt from array_api_compat import to_device @@ -84,6 +85,9 @@ ) projector.tof = True # set this to True to get a time of flight projector +# repeat number of TOF bin times here +num_tof_bins = projector.tof_parameters.num_tofbins + # forward project the image noise_free_sinogram = projector(img) @@ -101,50 +105,14 @@ noisy_sinogram = xp.asarray( np.random.poisson(np.asarray(to_device(noise_free_sinogram, "cpu"))), device=dev ) -# ravel the noisy sinogram and the detector start and end "index" sinograms -noisy_sinogram = xp.reshape(noisy_sinogram, (noisy_sinogram.size,)) - -# get the two dimensional indices of all sinogram bins -start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() - -# generate two sinograms that contain the linearized detector start and end indices -sino_det_start_index = ( - lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds -) -sino_det_end_index = ( - lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds -) -# repeat number of TOF bin times here -num_tof_bins = projector.tof_parameters.num_tofbins +# ------------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------------- -sino_det_start_index = xp.reshape( - xp.stack(num_tof_bins * [sino_det_start_index], axis=-1), - sino_det_start_index.size * num_tof_bins, +event_det_id_1, event_det_id_2, event_tof_bin = utils.noisy_tof_sinogram_to_lm( + noisy_sinogram, lor_descriptor, xp, dev ) - -sino_det_end_index = xp.reshape( - xp.stack(num_tof_bins * [sino_det_end_index], axis=-1), - sino_det_end_index.size * num_tof_bins, -) - -# ---------------------------------------------------------------------------- -# ---------------------------------------------------------------------------- -# --- convert the index sinograms in to listmode data ------------------------ -# ---------------------------------------------------------------------------- -# ---------------------------------------------------------------------------- - -# generate listmode data from the noisy sinogram -event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) -# shuffle the event sinogram indices -np.random.shuffle(event_sino_inds) -# convert event sino indices to xp array -event_sino_inds = xp.asarray(event_sino_inds, device=dev) - -event_det_id_1 = xp.take(sino_det_start_index, event_sino_inds) -event_det_id_2 = xp.take(sino_det_end_index, event_sino_inds) -event_tof_bin = event_sino_inds % num_tof_bins - print(f"number of simulated events: {event_det_id_1.shape[0]}") # ---------------------------------------------------------------------------- diff --git a/python/parallelproj_utils.py b/python/parallelproj_utils.py index 1953103..b76a1b5 100644 --- a/python/parallelproj_utils.py +++ b/python/parallelproj_utils.py @@ -4,7 +4,7 @@ import abc from dataclasses import dataclass import array_api_compat.numpy as np -import numpy.typing as npt +from numpy.array_api._array_object import Array import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Line3DCollection @@ -62,7 +62,7 @@ def __init__( xp: ModuleType, dev: str, num_lor_endpoints: int, - affine_transformation_matrix: npt.NDArray | None = None) -> None: + affine_transformation_matrix: Array | None = None) -> None: """abstract base class for PET scanner module Parameters @@ -73,7 +73,7 @@ def __init__( device to use for storing the LOR endpoints num_lor_endpoints : int number of LOR endpoints in the module - affine_transformation_matrix : npt.NDArray | None, optional + affine_transformation_matrix : Array | None, optional 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None if None, the 4x4 identity matrix is used """ @@ -111,39 +111,39 @@ def num_lor_endpoints(self) -> int: return self._num_lor_endpoints @property - def lor_endpoint_numbers(self) -> npt.NDArray: + def lor_endpoint_numbers(self) -> Array: """array enumerating all the LOR endpoints in the module Returns ------- - npt.NDArray + Array """ return self._lor_endpoint_numbers @property - def affine_transformation_matrix(self) -> npt.NDArray: + def affine_transformation_matrix(self) -> Array: """4x4 affine transformation matrix Returns ------- - npt.NDArray + Array """ return self._affine_transformation_matrix @abc.abstractmethod def get_raw_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: """mapping from LOR endpoint indices within module to an array of "raw" world coordinates Parameters ---------- - inds : npt.NDArray | None, optional + inds : Array | None, optional an non-negative integer array of indices, default None if None means all possible indices [0, ... , num_lor_endpoints - 1] Returns ------- - npt.NDArray + Array a 3 x len(inds) float array with the world coordinates of the LOR endpoints """ if inds is None: @@ -151,18 +151,18 @@ def get_raw_lor_endpoints(self, raise NotImplementedError def get_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: """mapping from LOR endpoint indices within module to an array of "transformed" world coordinates Parameters ---------- - inds : npt.NDArray | None, optional + inds : Array | None, optional an non-negative integer array of indices, default None if None means all possible indices [0, ... , num_lor_endpoints - 1] Returns ------- - npt.NDArray + Array a 3 x len(inds) float array with the world coordinates of the LOR endpoints including an affine transformation """ @@ -236,7 +236,7 @@ def __init__( lor_spacing: float, ax0: int = 2, ax1: int = 1, - affine_transformation_matrix: npt.NDArray | None = None) -> None: + affine_transformation_matrix: Array | None = None) -> None: """regular Polygon PET scanner module Parameters @@ -257,7 +257,7 @@ def __init__( axis number for the first direction, by default 2 ax1 : int, optional axis number for the second direction, by default 1 - affine_transformation_matrix : npt.NDArray | None, optional + affine_transformation_matrix : Array | None, optional 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None if None, the 4x4 identity matrix is used """ @@ -333,7 +333,7 @@ def lor_spacing(self) -> float: # abstract method from base class to be implemented def get_raw_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: if inds is None: inds = self.lor_endpoint_numbers @@ -418,7 +418,7 @@ def num_modules(self) -> int: return self._num_modules @property - def num_lor_endpoints_per_module(self) -> npt.NDArray: + def num_lor_endpoints_per_module(self) -> Array: """numpy array showing how many LOR endpoints are in every module""" return self._num_lor_endpoints_per_module @@ -428,17 +428,17 @@ def num_lor_endpoints(self) -> int: return self._num_lor_endpoints @property - def all_lor_endpoints_index_offset(self) -> npt.NDArray: + def all_lor_endpoints_index_offset(self) -> Array: """the offset in the linear (flattend) index for all LOR endpoints""" return self._all_lor_endpoints_index_offset @property - def all_lor_endpoints_module_number(self) -> npt.NDArray: + def all_lor_endpoints_module_number(self) -> Array: """the module number of all LOR endpoints""" return self._all_lor_endpoints_module_number @property - def all_lor_endpoints(self) -> npt.NDArray: + def all_lor_endpoints(self) -> Array: """the world coordinates of all LOR endpoints""" return self._all_lor_endpoints @@ -453,21 +453,21 @@ def dev(self) -> str: def linear_lor_endpoint_index( self, - module: npt.NDArray, - index_in_module: npt.NDArray, - ) -> npt.NDArray: + module: Array, + index_in_module: Array, + ) -> Array: """transform the module + index_in_modules indices into a flattened / linear LOR endpoint index Parameters ---------- - module : npt.NDArray + module : Array containing module numbers - index_in_module : npt.NDArray + index_in_module : Array containing index in modules Returns ------- - npt.NDArray + Array the flattened LOR endpoint index """ # index_in_module = self._xp.asarray(index_in_module) @@ -475,20 +475,20 @@ def linear_lor_endpoint_index( return self.xp.take(self.all_lor_endpoints_index_offset, module) + index_in_module - def get_lor_endpoints(self, module: npt.NDArray, - index_in_module: npt.NDArray) -> npt.NDArray: + def get_lor_endpoints(self, module: Array, + index_in_module: Array) -> Array: """get the coordinates for LOR endpoints defined by module and index in module Parameters ---------- - module : npt.NDArray + module : Array the module number of the LOR endpoints - index_in_module : npt.NDArray + index_in_module : Array the index in module number of the LOR endpoints Returns ------- - npt.NDArray | cpt.NDArray + Array | cpt.NDArray the 3 world coordinates of the LOR endpoints """ return self.xp.take(self.all_lor_endpoints, @@ -531,7 +531,7 @@ class RegularPolygonPETScannerGeometry(ModularizedPETScannerGeometry): def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, num_lor_endpoints_per_side: int, lor_spacing: float, - num_rings: int, ring_positions: npt.NDArray, + num_rings: int, ring_positions: Array, symmetry_axis: int) -> None: """ Parameters @@ -550,7 +550,7 @@ def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, spacing between the LOR endpoints in each side num_rings : int the number of rings (regular polygons) - ring_positions : npt.NDArray + ring_positions : Array 1D array with the coordinate of the rings along the ring axis symmetry_axis : int the ring axis (0,1,2) @@ -592,8 +592,7 @@ def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, ax0=self._ax0, ax1=self._ax1)) - modules = tuple(modules) - super().__init__(modules) + super().__init__(tuple(modules)) self._all_lor_endpoints_index_in_ring = self.xp.arange( self.num_lor_endpoints, device=dev @@ -631,12 +630,12 @@ def symmetry_axis(self) -> int: return self._symmetry_axis @property - def all_lor_endpoints_ring_number(self) -> npt.NDArray: + def all_lor_endpoints_ring_number(self) -> Array: """the ring (regular polygon) number of all LOR endpoints""" return self._all_lor_endpoints_module_number @property - def all_lor_endpoints_index_in_ring(self) -> npt.NDArray: + def all_lor_endpoints_index_in_ring(self) -> Array: """the index withing the ring (regular polygon) number of all LOR endpoints""" return self._all_lor_endpoints_index_in_ring @@ -646,7 +645,7 @@ def num_lor_endpoints_per_ring(self) -> int: return int(self._num_lor_endpoints_per_module[0]) @property - def ring_positions(self) -> npt.NDArray: + def ring_positions(self) -> Array: """the ring (regular polygon) positions""" return self._ring_positions @@ -675,7 +674,7 @@ def __init__(self, class PETLORDescriptor(abc.ABC): - """abstract base class to describe which modules / indices in modules of a + """abstract base class to describe which modules / indices in modules of a modularized PET scanner are in coincidence; defining geometrical LORs""" def __init__(self, scanner: ModularizedPETScannerGeometry) -> None: @@ -683,13 +682,17 @@ def __init__(self, scanner: ModularizedPETScannerGeometry) -> None: Parameters ---------- scanner : ModularizedPETScannerGeometry - a modularized PET scanner + a modularized PET scanner """ self._scanner = scanner @abc.abstractmethod - def get_lor_coordinates(self, - **kwargs) -> tuple[npt.ArrayLike, npt.ArrayLike]: + def get_lor_indices(self) -> tuple[Array, Array, Array, Array]: + """return the start and end indices of all LORs / or a subset of LORs""" + raise NotImplementedError + + @abc.abstractmethod + def get_lor_coordinates(self) -> tuple[Array, Array]: """return the start and end coordinates of all (or a subset of) LORs""" raise NotImplementedError @@ -716,9 +719,10 @@ def __init__( scanner: RegularPolygonPETScannerGeometry, radial_trim: int = 3, max_ring_difference: int | None = None, + sinogram_oder: SinogramSpatialAxisOrder = SinogramSpatialAxisOrder.RVP ) -> None: """Coincidence descriptor for a regular polygon PET scanner where - we have coincidences within and between "rings (polygons of modules)" + we have coincidences within and between "rings (polygons of modules)" The geometrical LORs can be sorted into a sinogram having a "plane", "view" and "radial" axis. @@ -731,24 +735,34 @@ def __init__( max_ring_difference : int | None, optional maximim ring difference to consider for coincidences, by default None means all ring differences are included + sinogram_order : SinogramSpatialAxisOrder, optional + the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP """ super().__init__(scanner) + self._scanner = scanner self._radial_trim = radial_trim if max_ring_difference is None: - self._max_ring_difference = self.scanner.num_rings - 1 + self._max_ring_difference = scanner.num_rings - 1 else: self._max_ring_difference = max_ring_difference - self._num_rad = (self.scanner.num_lor_endpoints_per_ring + + self._num_rad = (scanner.num_lor_endpoints_per_ring + 1) - 2 * self._radial_trim - self._num_views = self.scanner.num_lor_endpoints_per_ring // 2 + self._num_views = scanner.num_lor_endpoints_per_ring // 2 + + self._sinogram_order = sinogram_oder self._setup_plane_indices() self._setup_view_indices() + @property + def scanner(self) -> RegularPolygonPETScannerGeometry: + """the scanner for which coincidences are described""" + return self._scanner + @property def radial_trim(self) -> int: """number of geometrial LORs to disregard in the radial direction""" @@ -775,25 +789,30 @@ def num_views(self) -> int: return self._num_views @property - def start_plane_index(self) -> npt.NDArray: + def start_plane_index(self) -> Array: """start plane for all planes""" return self._start_plane_index @property - def end_plane_index(self) -> npt.NDArray: + def end_plane_index(self) -> Array: """end plane for all planes""" return self._end_plane_index @property - def start_in_ring_index(self) -> npt.NDArray: + def start_in_ring_index(self) -> Array: """start index within ring for all views - shape (num_view, num_rad)""" return self._start_in_ring_index @property - def end_in_ring_index(self) -> npt.NDArray: + def end_in_ring_index(self) -> Array: """end index within ring for all views - shape (num_view, num_rad)""" return self._end_in_ring_index + @property + def sinogram_order(self) -> SinogramSpatialAxisOrder: + """the order of the sinogram axes""" + return self._sinogram_order + def _setup_plane_indices(self) -> None: """setup the start / end plane indices (similar to a Michelogram) """ @@ -855,21 +874,17 @@ def _setup_view_indices(self) -> None: def get_lor_indices( self, - views: None | npt.ArrayLike = None, - sinogram_order=SinogramSpatialAxisOrder.RVP - ) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: + views: None | Array = None + ) -> tuple[Array, Array, Array, Array]: """return the start and end indices of all LORs / or a subset of views Parameters ---------- - views : None | npt.ArrayLike, optional + views : None | Array, optional the views to consider, by default None means all views - sinogram_order : SinogramSpatialAxisOrder, optional - the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP - Returns ------- - start_mods, end_mods, start_inds, end_inds + start_mods, end_mods, start_inds, end_inds """ if views is None: @@ -896,17 +911,19 @@ def get_lor_indices( start_inds = self.xp.reshape(start_inds, sinogram_spatial_shape) end_inds = self.xp.reshape(end_inds, sinogram_spatial_shape) - if sinogram_order is not SinogramSpatialAxisOrder.PVR: - if sinogram_order is SinogramSpatialAxisOrder.RVP: + if self.sinogram_order is not SinogramSpatialAxisOrder.PVR: + if self.sinogram_order is SinogramSpatialAxisOrder.RVP: new_order = (2, 1, 0) - elif sinogram_order is SinogramSpatialAxisOrder.RPV: + elif self.sinogram_order is SinogramSpatialAxisOrder.RPV: new_order = (2, 0, 1) - elif sinogram_order is SinogramSpatialAxisOrder.VRP: + elif self.sinogram_order is SinogramSpatialAxisOrder.VRP: new_order = (1, 2, 0) - elif sinogram_order is SinogramSpatialAxisOrder.VPR: + elif self.sinogram_order is SinogramSpatialAxisOrder.VPR: new_order = (1, 0, 2) - elif sinogram_order is SinogramSpatialAxisOrder.PRV: + elif self.sinogram_order is SinogramSpatialAxisOrder.PRV: new_order = (0, 2, 1) + else: + new_order = (0, 1, 2) start_mods = self.xp.permute_dims(start_mods, new_order) end_mods = self.xp.permute_dims(end_mods, new_order) @@ -914,32 +931,31 @@ def get_lor_indices( start_inds = self.xp.permute_dims(start_inds, new_order) end_inds = self.xp.permute_dims(end_inds, new_order) - return start_mods, end_mods, start_inds, end_inds + return start_mods, end_mods, start_inds, end_inds def get_lor_coordinates( self, - views: None | npt.ArrayLike = None, - sinogram_order=SinogramSpatialAxisOrder.RVP - ) -> tuple[npt.ArrayLike, npt.ArrayLike]: + views: None | Array = None, + ) -> tuple[Array, Array]: """return the start and end coordinates of all LORs / or a subset of views Parameters ---------- - views : None | npt.ArrayLike, optional + views : None | Array, optional the views to consider, by default None means all views sinogram_order : SinogramSpatialAxisOrder, optional the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP Returns ------- - xstart, xend : npt.ArrayLike + xstart, xend : Array 2 dimensional floating point arrays containing the start and end coordinates of all LORs """ - start_mods, end_mods, start_inds, end_inds = self.get_lor_indices(views, sinogram_order) + start_mods, end_mods, start_inds, end_inds = self.get_lor_indices(views) sinogram_spatial_shape = start_mods.shape - + start_mods = self.xp.reshape(start_mods, (-1, )) start_inds = self.xp.reshape(start_inds, (-1, )) @@ -957,8 +973,8 @@ def get_lor_coordinates( def show_views(self, ax: plt.Axes, - views: npt.ArrayLike, - planes: npt.ArrayLike, + views: Array, + planes: Array, lw: float = 0.2, **kwargs) -> None: """show all LORs of a single view in a given plane @@ -975,8 +991,7 @@ def show_views(self, the line width, by default 0.2 """ - xs, xe = self.get_lor_coordinates( - views=views, sinogram_order=SinogramSpatialAxisOrder.RVP) + xs, xe = self.get_lor_coordinates(views=views) xs = self.xp.reshape(self.xp.take(xs, planes, axis=2), (-1, 3)) xe = self.xp.reshape(self.xp.take(xe, planes, axis=2), (-1, 3)) @@ -1015,8 +1030,8 @@ def __init__(self, lor_descriptor: RegularPolygonPETLORDescriptor, img_shape: tuple[int, int, int], voxel_size: tuple[float, float, float], - img_origin: None | npt.ArrayLike = None, - views: None | npt.ArrayLike = None, + img_origin: None | Array = None, + views: None | Array = None, resolution_model: None | parallelproj.LinearOperator = None, tof: bool = False): """Regular polygon PET projector @@ -1029,10 +1044,10 @@ def __init__(self, shape of the image to be projected voxel_size : tuple[float, float, float] the voxel size of the image to be projected - img_origin : None | npt.ArrayLike, optional - the origin of the image to be projected, by default None + img_origin : None | Array, optional + the origin of the image to be projected, by default None means that image is "centered" in the scanner - views : None | npt.ArrayLike, optional + views : None | Array, optional sinogram views to be projected, by default None means that all views are being projected resolution_model : None | parallelproj.LinearOperator, optional @@ -1068,8 +1083,7 @@ def __init__(self, self._resolution_model = resolution_model - self._xstart, self._xend = lor_descriptor.get_lor_coordinates( - views=self._views, sinogram_order=SinogramSpatialAxisOrder['RVP']) + self._xstart, self._xend = lor_descriptor.get_lor_coordinates(views=self._views) self._tof = tof self._tof_parameters = TOFParameters() @@ -1115,7 +1129,7 @@ def tof_parameters(self, value: TOFParameters) -> None: self._tof_parameters = value @property - def img_origin(self) -> npt.NDArray: + def img_origin(self) -> Array: return self._img_origin def _apply(self, x): @@ -1185,7 +1199,7 @@ def distributed_subset_order(n: int) -> list[int]: Returns ------- list[int] - """ + """ l = [x for x in range(n)] o = [] @@ -1196,4 +1210,3 @@ def distributed_subset_order(n: int) -> list[int]: o.append(l.pop(len(l)//2)) return o - diff --git a/python/utils.py b/python/utils.py new file mode 100644 index 0000000..113d0ab --- /dev/null +++ b/python/utils.py @@ -0,0 +1,79 @@ +import array_api_compat.numpy as np +import parallelproj_utils +from numpy.array_api._array_object import Array +from types import ModuleType + + +def noisy_tof_sinogram_to_lm( + noisy_sinogram: Array, + lor_descriptor: parallelproj_utils.PETLORDescriptor, + xp: ModuleType, + dev: str, +) -> tuple[Array, Array, Array]: + """convert a noisy sinogram to listmode data + + Parameters + ---------- + noisy_sinogram : Array + sinogram containing Poisson noise (integer values) + lor_descriptor : parallelproj_utils.PETLORDescriptor + description of the LOR geometry + xp : ModuleType + array module + dev : str + device + + Returns + ------- + tuple[Array, Array, Array] + event_det_id_1, event_det_id_2, event_tof_bin + """ + if noisy_sinogram.ndim != 4: + raise ValueError( + f"noisy_sinogram must be a 4D array, the last axis must be the TOF axis, but has shape {noisy_sinogram.shape}" + ) + + num_tof_bins = noisy_sinogram.shape[3] + + # ravel the noisy sinogram and the detector start and end "index" sinograms + noisy_sinogram = xp.reshape(noisy_sinogram, (noisy_sinogram.size,)) + + # get the two dimensional indices of all sinogram bins + start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() + + # generate two sinograms that contain the linearized detector start and end indices + sino_det_start_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds + ) + sino_det_end_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds + ) + + sino_det_start_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_start_index], axis=-1), + sino_det_start_index.size * num_tof_bins, + ) + + sino_det_end_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_end_index], axis=-1), + sino_det_end_index.size * num_tof_bins, + ) + + # ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- + # --- convert the index sinograms in to listmode data ------------------------ + # ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- + + # generate listmode data from the noisy sinogram + event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) + # shuffle the event sinogram indices + np.random.shuffle(event_sino_inds) + # convert event sino indices to xp array + event_sino_inds = xp.asarray(event_sino_inds, device=dev) + + event_det_id_1 = xp.take(sino_det_start_index, event_sino_inds) + event_det_id_2 = xp.take(sino_det_end_index, event_sino_inds) + event_tof_bin = event_sino_inds % num_tof_bins + + return event_det_id_1, event_det_id_2, event_tof_bin From 46090a350ff4964521f9f586bee017c7be1a7deb Mon Sep 17 00:00:00 2001 From: Georg Schramm Date: Thu, 16 Nov 2023 22:20:32 +0100 Subject: [PATCH 2/2] simulate and write two time blocks with different images --- python/00_simulate_lm_data.py | 72 ++++++++++----- python/01_reconstruct_lm_data.py | 7 +- python/prd_io.py | 147 ++++++++++++++++--------------- 3 files changed, 132 insertions(+), 94 deletions(-) diff --git a/python/00_simulate_lm_data.py b/python/00_simulate_lm_data.py index b2c71af..89d1891 100644 --- a/python/00_simulate_lm_data.py +++ b/python/00_simulate_lm_data.py @@ -61,15 +61,20 @@ img_shape = (num_trans, num_trans, num_ax) n0, n1, n2 = img_shape -# setup an image containing a box - -img = xp.asarray( +# setup the image for the 1st time frame +img_f1 = xp.asarray( np.tile(np.load("../data/SL.npy")[..., None], (1, 1, num_ax)), device=dev, dtype=xp.float32, ) -img[:, :, :2] = 0 -img[:, :, -2:] = 0 +img_f1[:, :, :2] = 0 +img_f1[:, :, -2:] = 0 + +img_f1 /= xp.max(img_f1) + +# setup the image for the 2nd time frame +img_f2 = xp.sqrt(img_f1) + # ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- @@ -89,31 +94,41 @@ num_tof_bins = projector.tof_parameters.num_tofbins # forward project the image -noise_free_sinogram = projector(img) +noise_free_sinogram_f1 = projector(img_f1) +noise_free_sinogram_f2 = projector(img_f2) # rescale the forward projection and image such that we get the expected number of trues -scale = expected_num_trues / float(xp.sum(noise_free_sinogram)) -noise_free_sinogram *= scale -img *= scale +scale = expected_num_trues / float(xp.sum(noise_free_sinogram_f1)) +noise_free_sinogram_f1 *= scale +noise_free_sinogram_f2 *= scale +img_f1 *= scale +img_f2 *= scale # calculate the sensitivity image sens_img = projector.adjoint( - xp.ones(noise_free_sinogram.shape, device=dev, dtype=xp.float32) + xp.ones(noise_free_sinogram_f1.shape, device=dev, dtype=xp.float32) ) # add poisson noise to the noise free sinogram -noisy_sinogram = xp.asarray( - np.random.poisson(np.asarray(to_device(noise_free_sinogram, "cpu"))), device=dev +noisy_sinogram_f1 = xp.asarray( + np.random.poisson(np.asarray(to_device(noise_free_sinogram_f1, "cpu"))), device=dev +) +noisy_sinogram_f2 = xp.asarray( + np.random.poisson(np.asarray(to_device(noise_free_sinogram_f2, "cpu"))), device=dev ) # ------------------------------------------------------------------------------------- # ------------------------------------------------------------------------------------- # ------------------------------------------------------------------------------------- -event_det_id_1, event_det_id_2, event_tof_bin = utils.noisy_tof_sinogram_to_lm( - noisy_sinogram, lor_descriptor, xp, dev +event_det_id_1_f1, event_det_id_2_f1, event_tof_bin_f1 = utils.noisy_tof_sinogram_to_lm( + noisy_sinogram_f1, lor_descriptor, xp, dev ) -print(f"number of simulated events: {event_det_id_1.shape[0]}") +event_det_id_1_f2, event_det_id_2_f2, event_tof_bin_f2 = utils.noisy_tof_sinogram_to_lm( + noisy_sinogram_f2, lor_descriptor, xp, dev +) +print(f"number of simulated events f1: {event_det_id_1_f1.shape[0]}") +print(f"number of simulated events f2: {event_det_id_1_f2.shape[0]}") # ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- @@ -153,10 +168,10 @@ # write the data to PETSIRD write_prd_from_numpy_arrays( - event_det_id_1, - event_det_id_2, + [event_det_id_1_f1, event_det_id_1_f2], + [event_det_id_2_f1, event_det_id_2_f2], scanner_information, - tof_idx_array=event_tof_bin, + tof_idx_array_blocks=[event_tof_bin_f1, event_tof_bin_f2], output_file=str(Path(output_dir) / output_prd_file), ) print(f"wrote PETSIRD LM file to {str(Path(output_dir) / output_prd_file)}") @@ -179,13 +194,22 @@ fig_dir = Path("../figs") fig_dir.mkdir(exist_ok=True) -vmax = 1.2 * xp.max(img) -fig, ax = plt.subplots(1, img.shape[2], figsize=(img.shape[2] * 2, 2)) -for i in range(img.shape[2]): - ax[i].imshow( - xp.asarray(to_device(img[:, :, i], "cpu")), vmin=0, vmax=vmax, cmap="Greys" +vmax = 1.2 * xp.max(img_f1) +fig, ax = plt.subplots( + 2, img_shape[2], figsize=(img_shape[2] * 2, 2 * 2), sharex=True, sharey=True +) +for i in range(img_shape[2]): + ax[0, i].imshow( + xp.asarray(to_device(img_f1[:, :, i], "cpu")), vmin=0, vmax=vmax, cmap="Greys" ) - ax[i].set_title(f"ground truth sl {i+1}", fontsize="small") + ax[0, i].set_title(f"sl {i+1}", fontsize="small") + + ax[1, i].imshow( + xp.asarray(to_device(img_f2[:, :, i], "cpu")), vmin=0, vmax=vmax, cmap="Greys" + ) + +ax[0, 0].set_ylabel("ground truth frame 1", fontsize="small") +ax[1, 0].set_ylabel("ground truth frame 2", fontsize="small") fig.tight_layout() fig.savefig(fig_dir / "simulated_phantom.png") diff --git a/python/01_reconstruct_lm_data.py b/python/01_reconstruct_lm_data.py index d447255..a3fc9f6 100644 --- a/python/01_reconstruct_lm_data.py +++ b/python/01_reconstruct_lm_data.py @@ -31,7 +31,12 @@ # read the LM file header and all event attributes header, event_attributes = read_prd_to_numpy_arrays( - str(Path(lm_data_dir) / prd_file), xp, dev, read_tof=None, read_energy=False + str(Path(lm_data_dir) / prd_file), + xp, + dev, + read_tof=None, + read_energy=False, + time_block_ids=range(1, 2), ) # read the detector coordinates into a 2D lookup table diff --git a/python/prd_io.py b/python/prd_io.py index 3836ddf..259f8ae 100644 --- a/python/prd_io.py +++ b/python/prd_io.py @@ -6,15 +6,16 @@ from numpy.array_api._array_object import Array from types import ModuleType +from typing import Sequence def write_prd_from_numpy_arrays( - detector_1_id_array: Array, - detector_2_id_array: Array, + detector_1_id_array_blocks: list[Array], + detector_2_id_array_blocks: list[Array], scanner_information: prd.ScannerInformation, - tof_idx_array: Array | None = None, - energy_1_idx_array: Array | None = None, - energy_2_idx_array: Array | None = None, + tof_idx_array_blocks: list[Array] | None = None, + energy_1_idx_array_blocks: list[Array] | None = None, + energy_2_idx_array_blocks: list[Array] | None = None, output_file: str | None = None, ) -> None: """Write a PRD file from numpy arrays. Currently all into one time block @@ -38,53 +39,56 @@ def write_prd_from_numpy_arrays( output file, if None write to stdout """ - num_events: int = detector_1_id_array.size + time_blocks = [] - events = [] - for i in range(num_events): - det_id_1 = int(detector_1_id_array[i]) - det_id_2 = int(detector_2_id_array[i]) + for id in range(len(detector_1_id_array_blocks)): + num_events: int = detector_1_id_array_blocks[id].size - if tof_idx_array is not None: - tof_idx = int(tof_idx_array[i]) - else: - tof_idx = 0 + events = [] + for i in range(num_events): + det_id_1 = int(detector_1_id_array_blocks[id][i]) + det_id_2 = int(detector_2_id_array_blocks[id][i]) - if energy_1_idx_array is not None: - energy_1_idx = int(energy_1_idx_array[i]) - else: - energy_1_idx = 0 + if tof_idx_array_blocks is not None: + tof_idx = int(tof_idx_array_blocks[id][i]) + else: + tof_idx = 0 - if energy_2_idx_array is not None: - energy_2_idx = int(energy_2_idx_array[i]) - else: - energy_2_idx = 0 - - events.append( - prd.CoincidenceEvent( - detector_1_id=det_id_1, - detector_2_id=det_id_2, - tof_idx=tof_idx, - energy_1_idx=energy_1_idx, - energy_2_idx=energy_2_idx, + if energy_1_idx_array_blocks is not None: + energy_1_idx = int(energy_1_idx_array_blocks[id][i]) + else: + energy_1_idx = 0 + + if energy_2_idx_array_blocks is not None: + energy_2_idx = int(energy_2_idx_array_blocks[id][i]) + else: + energy_2_idx = 0 + + events.append( + prd.CoincidenceEvent( + detector_1_id=det_id_1, + detector_2_id=det_id_2, + tof_idx=tof_idx, + energy_1_idx=energy_1_idx, + energy_2_idx=energy_2_idx, + ) ) - ) - time_block = prd.TimeBlock(id=0, prompt_events=events) + time_blocks.append(prd.TimeBlock(id=id, prompt_events=events)) if output_file is None: with prd.BinaryPrdExperimentWriter(sys.stdout.buffer) as writer: writer.write_header(prd.Header(scanner=scanner_information)) - writer.write_time_blocks((time_block,)) + writer.write_time_blocks(time_blocks) else: if output_file.endswith(".ndjson"): with prd.NDJsonPrdExperimentWriter(output_file) as writer: writer.write_header(prd.Header(scanner=scanner_information)) - writer.write_time_blocks((time_block,)) + writer.write_time_blocks(time_blocks) else: with prd.BinaryPrdExperimentWriter(output_file) as writer: writer.write_header(prd.Header(scanner=scanner_information)) - writer.write_time_blocks((time_block,)) + writer.write_time_blocks(time_blocks) def read_prd_to_numpy_arrays( @@ -93,6 +97,7 @@ def read_prd_to_numpy_arrays( dev: str, read_tof: bool | None = None, read_energy: bool | None = None, + time_block_ids: Sequence[int] | None = None, ) -> tuple[prd.types.Header, Array]: """Read all time blocks of a PETSIRD listmode file @@ -137,44 +142,48 @@ def read_prd_to_numpy_arrays( r_energy = read_energy # loop over all time blocks and read all meaningful event attributes + event_attribute_list = [] + for time_block in reader.read_time_blocks(): - if r_tof and r_energy: - event_attribute_list = [ - [ - e.detector_1_id, - e.detector_2_id, - e.tof_idx, - e.energy_1_idx, - e.energy_2_idx, + if (time_block_ids is None) or time_block.id in time_block_ids: + print(f"reading time block {time_block.id}") + if r_tof and r_energy: + event_attribute_list += [ + [ + e.detector_1_id, + e.detector_2_id, + e.tof_idx, + e.energy_1_idx, + e.energy_2_idx, + ] + for e in time_block.prompt_events ] - for e in time_block.prompt_events - ] - elif r_tof and (not r_energy): - event_attribute_list = [ - [ - e.detector_1_id, - e.detector_2_id, - e.tof_idx, + elif r_tof and (not r_energy): + event_attribute_list += [ + [ + e.detector_1_id, + e.detector_2_id, + e.tof_idx, + ] + for e in time_block.prompt_events ] - for e in time_block.prompt_events - ] - elif (not r_tof) and r_energy: - event_attribute_list = [ - [ - e.detector_1_id, - e.detector_2_id, - e.energy_1_idx, - e.energy_2_idx, + elif (not r_tof) and r_energy: + event_attribute_list += [ + [ + e.detector_1_id, + e.detector_2_id, + e.energy_1_idx, + e.energy_2_idx, + ] + for e in time_block.prompt_events ] - for e in time_block.prompt_events - ] - else: - event_attribute_list = [ - [ - e.detector_1_id, - e.detector_2_id, + else: + event_attribute_list += [ + [ + e.detector_1_id, + e.detector_2_id, + ] + for e in time_block.prompt_events ] - for e in time_block.prompt_events - ] return header, xp.asarray(event_attribute_list, device=dev)