diff --git a/python/00_simulate_lm_data.py b/python/00_simulate_lm_data.py index 7b0ca1a..b2c71af 100644 --- a/python/00_simulate_lm_data.py +++ b/python/00_simulate_lm_data.py @@ -6,6 +6,7 @@ import prd import parallelproj import parallelproj_utils +import utils import array_api_compat.numpy as np import matplotlib.pyplot as plt from array_api_compat import to_device @@ -84,6 +85,9 @@ ) projector.tof = True # set this to True to get a time of flight projector +# repeat number of TOF bin times here +num_tof_bins = projector.tof_parameters.num_tofbins + # forward project the image noise_free_sinogram = projector(img) @@ -101,50 +105,14 @@ noisy_sinogram = xp.asarray( np.random.poisson(np.asarray(to_device(noise_free_sinogram, "cpu"))), device=dev ) -# ravel the noisy sinogram and the detector start and end "index" sinograms -noisy_sinogram = xp.reshape(noisy_sinogram, (noisy_sinogram.size,)) - -# get the two dimensional indices of all sinogram bins -start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() - -# generate two sinograms that contain the linearized detector start and end indices -sino_det_start_index = ( - lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds -) -sino_det_end_index = ( - lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds -) -# repeat number of TOF bin times here -num_tof_bins = projector.tof_parameters.num_tofbins +# ------------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------------- -sino_det_start_index = xp.reshape( - xp.stack(num_tof_bins * [sino_det_start_index], axis=-1), - sino_det_start_index.size * num_tof_bins, +event_det_id_1, event_det_id_2, event_tof_bin = utils.noisy_tof_sinogram_to_lm( + noisy_sinogram, lor_descriptor, xp, dev ) - -sino_det_end_index = xp.reshape( - xp.stack(num_tof_bins * [sino_det_end_index], axis=-1), - sino_det_end_index.size * num_tof_bins, -) - -# ---------------------------------------------------------------------------- -# ---------------------------------------------------------------------------- -# --- convert the index sinograms in to listmode data ------------------------ -# ---------------------------------------------------------------------------- -# ---------------------------------------------------------------------------- - -# generate listmode data from the noisy sinogram -event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) -# shuffle the event sinogram indices -np.random.shuffle(event_sino_inds) -# convert event sino indices to xp array -event_sino_inds = xp.asarray(event_sino_inds, device=dev) - -event_det_id_1 = xp.take(sino_det_start_index, event_sino_inds) -event_det_id_2 = xp.take(sino_det_end_index, event_sino_inds) -event_tof_bin = event_sino_inds % num_tof_bins - print(f"number of simulated events: {event_det_id_1.shape[0]}") # ---------------------------------------------------------------------------- diff --git a/python/parallelproj_utils.py b/python/parallelproj_utils.py index 1953103..b76a1b5 100644 --- a/python/parallelproj_utils.py +++ b/python/parallelproj_utils.py @@ -4,7 +4,7 @@ import abc from dataclasses import dataclass import array_api_compat.numpy as np -import numpy.typing as npt +from numpy.array_api._array_object import Array import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Line3DCollection @@ -62,7 +62,7 @@ def __init__( xp: ModuleType, dev: str, num_lor_endpoints: int, - affine_transformation_matrix: npt.NDArray | None = None) -> None: + affine_transformation_matrix: Array | None = None) -> None: """abstract base class for PET scanner module Parameters @@ -73,7 +73,7 @@ def __init__( device to use for storing the LOR endpoints num_lor_endpoints : int number of LOR endpoints in the module - affine_transformation_matrix : npt.NDArray | None, optional + affine_transformation_matrix : Array | None, optional 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None if None, the 4x4 identity matrix is used """ @@ -111,39 +111,39 @@ def num_lor_endpoints(self) -> int: return self._num_lor_endpoints @property - def lor_endpoint_numbers(self) -> npt.NDArray: + def lor_endpoint_numbers(self) -> Array: """array enumerating all the LOR endpoints in the module Returns ------- - npt.NDArray + Array """ return self._lor_endpoint_numbers @property - def affine_transformation_matrix(self) -> npt.NDArray: + def affine_transformation_matrix(self) -> Array: """4x4 affine transformation matrix Returns ------- - npt.NDArray + Array """ return self._affine_transformation_matrix @abc.abstractmethod def get_raw_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: """mapping from LOR endpoint indices within module to an array of "raw" world coordinates Parameters ---------- - inds : npt.NDArray | None, optional + inds : Array | None, optional an non-negative integer array of indices, default None if None means all possible indices [0, ... , num_lor_endpoints - 1] Returns ------- - npt.NDArray + Array a 3 x len(inds) float array with the world coordinates of the LOR endpoints """ if inds is None: @@ -151,18 +151,18 @@ def get_raw_lor_endpoints(self, raise NotImplementedError def get_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: """mapping from LOR endpoint indices within module to an array of "transformed" world coordinates Parameters ---------- - inds : npt.NDArray | None, optional + inds : Array | None, optional an non-negative integer array of indices, default None if None means all possible indices [0, ... , num_lor_endpoints - 1] Returns ------- - npt.NDArray + Array a 3 x len(inds) float array with the world coordinates of the LOR endpoints including an affine transformation """ @@ -236,7 +236,7 @@ def __init__( lor_spacing: float, ax0: int = 2, ax1: int = 1, - affine_transformation_matrix: npt.NDArray | None = None) -> None: + affine_transformation_matrix: Array | None = None) -> None: """regular Polygon PET scanner module Parameters @@ -257,7 +257,7 @@ def __init__( axis number for the first direction, by default 2 ax1 : int, optional axis number for the second direction, by default 1 - affine_transformation_matrix : npt.NDArray | None, optional + affine_transformation_matrix : Array | None, optional 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None if None, the 4x4 identity matrix is used """ @@ -333,7 +333,7 @@ def lor_spacing(self) -> float: # abstract method from base class to be implemented def get_raw_lor_endpoints(self, - inds: npt.NDArray | None = None) -> npt.NDArray: + inds: Array | None = None) -> Array: if inds is None: inds = self.lor_endpoint_numbers @@ -418,7 +418,7 @@ def num_modules(self) -> int: return self._num_modules @property - def num_lor_endpoints_per_module(self) -> npt.NDArray: + def num_lor_endpoints_per_module(self) -> Array: """numpy array showing how many LOR endpoints are in every module""" return self._num_lor_endpoints_per_module @@ -428,17 +428,17 @@ def num_lor_endpoints(self) -> int: return self._num_lor_endpoints @property - def all_lor_endpoints_index_offset(self) -> npt.NDArray: + def all_lor_endpoints_index_offset(self) -> Array: """the offset in the linear (flattend) index for all LOR endpoints""" return self._all_lor_endpoints_index_offset @property - def all_lor_endpoints_module_number(self) -> npt.NDArray: + def all_lor_endpoints_module_number(self) -> Array: """the module number of all LOR endpoints""" return self._all_lor_endpoints_module_number @property - def all_lor_endpoints(self) -> npt.NDArray: + def all_lor_endpoints(self) -> Array: """the world coordinates of all LOR endpoints""" return self._all_lor_endpoints @@ -453,21 +453,21 @@ def dev(self) -> str: def linear_lor_endpoint_index( self, - module: npt.NDArray, - index_in_module: npt.NDArray, - ) -> npt.NDArray: + module: Array, + index_in_module: Array, + ) -> Array: """transform the module + index_in_modules indices into a flattened / linear LOR endpoint index Parameters ---------- - module : npt.NDArray + module : Array containing module numbers - index_in_module : npt.NDArray + index_in_module : Array containing index in modules Returns ------- - npt.NDArray + Array the flattened LOR endpoint index """ # index_in_module = self._xp.asarray(index_in_module) @@ -475,20 +475,20 @@ def linear_lor_endpoint_index( return self.xp.take(self.all_lor_endpoints_index_offset, module) + index_in_module - def get_lor_endpoints(self, module: npt.NDArray, - index_in_module: npt.NDArray) -> npt.NDArray: + def get_lor_endpoints(self, module: Array, + index_in_module: Array) -> Array: """get the coordinates for LOR endpoints defined by module and index in module Parameters ---------- - module : npt.NDArray + module : Array the module number of the LOR endpoints - index_in_module : npt.NDArray + index_in_module : Array the index in module number of the LOR endpoints Returns ------- - npt.NDArray | cpt.NDArray + Array | cpt.NDArray the 3 world coordinates of the LOR endpoints """ return self.xp.take(self.all_lor_endpoints, @@ -531,7 +531,7 @@ class RegularPolygonPETScannerGeometry(ModularizedPETScannerGeometry): def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, num_lor_endpoints_per_side: int, lor_spacing: float, - num_rings: int, ring_positions: npt.NDArray, + num_rings: int, ring_positions: Array, symmetry_axis: int) -> None: """ Parameters @@ -550,7 +550,7 @@ def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, spacing between the LOR endpoints in each side num_rings : int the number of rings (regular polygons) - ring_positions : npt.NDArray + ring_positions : Array 1D array with the coordinate of the rings along the ring axis symmetry_axis : int the ring axis (0,1,2) @@ -592,8 +592,7 @@ def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, ax0=self._ax0, ax1=self._ax1)) - modules = tuple(modules) - super().__init__(modules) + super().__init__(tuple(modules)) self._all_lor_endpoints_index_in_ring = self.xp.arange( self.num_lor_endpoints, device=dev @@ -631,12 +630,12 @@ def symmetry_axis(self) -> int: return self._symmetry_axis @property - def all_lor_endpoints_ring_number(self) -> npt.NDArray: + def all_lor_endpoints_ring_number(self) -> Array: """the ring (regular polygon) number of all LOR endpoints""" return self._all_lor_endpoints_module_number @property - def all_lor_endpoints_index_in_ring(self) -> npt.NDArray: + def all_lor_endpoints_index_in_ring(self) -> Array: """the index withing the ring (regular polygon) number of all LOR endpoints""" return self._all_lor_endpoints_index_in_ring @@ -646,7 +645,7 @@ def num_lor_endpoints_per_ring(self) -> int: return int(self._num_lor_endpoints_per_module[0]) @property - def ring_positions(self) -> npt.NDArray: + def ring_positions(self) -> Array: """the ring (regular polygon) positions""" return self._ring_positions @@ -675,7 +674,7 @@ def __init__(self, class PETLORDescriptor(abc.ABC): - """abstract base class to describe which modules / indices in modules of a + """abstract base class to describe which modules / indices in modules of a modularized PET scanner are in coincidence; defining geometrical LORs""" def __init__(self, scanner: ModularizedPETScannerGeometry) -> None: @@ -683,13 +682,17 @@ def __init__(self, scanner: ModularizedPETScannerGeometry) -> None: Parameters ---------- scanner : ModularizedPETScannerGeometry - a modularized PET scanner + a modularized PET scanner """ self._scanner = scanner @abc.abstractmethod - def get_lor_coordinates(self, - **kwargs) -> tuple[npt.ArrayLike, npt.ArrayLike]: + def get_lor_indices(self) -> tuple[Array, Array, Array, Array]: + """return the start and end indices of all LORs / or a subset of LORs""" + raise NotImplementedError + + @abc.abstractmethod + def get_lor_coordinates(self) -> tuple[Array, Array]: """return the start and end coordinates of all (or a subset of) LORs""" raise NotImplementedError @@ -716,9 +719,10 @@ def __init__( scanner: RegularPolygonPETScannerGeometry, radial_trim: int = 3, max_ring_difference: int | None = None, + sinogram_oder: SinogramSpatialAxisOrder = SinogramSpatialAxisOrder.RVP ) -> None: """Coincidence descriptor for a regular polygon PET scanner where - we have coincidences within and between "rings (polygons of modules)" + we have coincidences within and between "rings (polygons of modules)" The geometrical LORs can be sorted into a sinogram having a "plane", "view" and "radial" axis. @@ -731,24 +735,34 @@ def __init__( max_ring_difference : int | None, optional maximim ring difference to consider for coincidences, by default None means all ring differences are included + sinogram_order : SinogramSpatialAxisOrder, optional + the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP """ super().__init__(scanner) + self._scanner = scanner self._radial_trim = radial_trim if max_ring_difference is None: - self._max_ring_difference = self.scanner.num_rings - 1 + self._max_ring_difference = scanner.num_rings - 1 else: self._max_ring_difference = max_ring_difference - self._num_rad = (self.scanner.num_lor_endpoints_per_ring + + self._num_rad = (scanner.num_lor_endpoints_per_ring + 1) - 2 * self._radial_trim - self._num_views = self.scanner.num_lor_endpoints_per_ring // 2 + self._num_views = scanner.num_lor_endpoints_per_ring // 2 + + self._sinogram_order = sinogram_oder self._setup_plane_indices() self._setup_view_indices() + @property + def scanner(self) -> RegularPolygonPETScannerGeometry: + """the scanner for which coincidences are described""" + return self._scanner + @property def radial_trim(self) -> int: """number of geometrial LORs to disregard in the radial direction""" @@ -775,25 +789,30 @@ def num_views(self) -> int: return self._num_views @property - def start_plane_index(self) -> npt.NDArray: + def start_plane_index(self) -> Array: """start plane for all planes""" return self._start_plane_index @property - def end_plane_index(self) -> npt.NDArray: + def end_plane_index(self) -> Array: """end plane for all planes""" return self._end_plane_index @property - def start_in_ring_index(self) -> npt.NDArray: + def start_in_ring_index(self) -> Array: """start index within ring for all views - shape (num_view, num_rad)""" return self._start_in_ring_index @property - def end_in_ring_index(self) -> npt.NDArray: + def end_in_ring_index(self) -> Array: """end index within ring for all views - shape (num_view, num_rad)""" return self._end_in_ring_index + @property + def sinogram_order(self) -> SinogramSpatialAxisOrder: + """the order of the sinogram axes""" + return self._sinogram_order + def _setup_plane_indices(self) -> None: """setup the start / end plane indices (similar to a Michelogram) """ @@ -855,21 +874,17 @@ def _setup_view_indices(self) -> None: def get_lor_indices( self, - views: None | npt.ArrayLike = None, - sinogram_order=SinogramSpatialAxisOrder.RVP - ) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: + views: None | Array = None + ) -> tuple[Array, Array, Array, Array]: """return the start and end indices of all LORs / or a subset of views Parameters ---------- - views : None | npt.ArrayLike, optional + views : None | Array, optional the views to consider, by default None means all views - sinogram_order : SinogramSpatialAxisOrder, optional - the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP - Returns ------- - start_mods, end_mods, start_inds, end_inds + start_mods, end_mods, start_inds, end_inds """ if views is None: @@ -896,17 +911,19 @@ def get_lor_indices( start_inds = self.xp.reshape(start_inds, sinogram_spatial_shape) end_inds = self.xp.reshape(end_inds, sinogram_spatial_shape) - if sinogram_order is not SinogramSpatialAxisOrder.PVR: - if sinogram_order is SinogramSpatialAxisOrder.RVP: + if self.sinogram_order is not SinogramSpatialAxisOrder.PVR: + if self.sinogram_order is SinogramSpatialAxisOrder.RVP: new_order = (2, 1, 0) - elif sinogram_order is SinogramSpatialAxisOrder.RPV: + elif self.sinogram_order is SinogramSpatialAxisOrder.RPV: new_order = (2, 0, 1) - elif sinogram_order is SinogramSpatialAxisOrder.VRP: + elif self.sinogram_order is SinogramSpatialAxisOrder.VRP: new_order = (1, 2, 0) - elif sinogram_order is SinogramSpatialAxisOrder.VPR: + elif self.sinogram_order is SinogramSpatialAxisOrder.VPR: new_order = (1, 0, 2) - elif sinogram_order is SinogramSpatialAxisOrder.PRV: + elif self.sinogram_order is SinogramSpatialAxisOrder.PRV: new_order = (0, 2, 1) + else: + new_order = (0, 1, 2) start_mods = self.xp.permute_dims(start_mods, new_order) end_mods = self.xp.permute_dims(end_mods, new_order) @@ -914,32 +931,31 @@ def get_lor_indices( start_inds = self.xp.permute_dims(start_inds, new_order) end_inds = self.xp.permute_dims(end_inds, new_order) - return start_mods, end_mods, start_inds, end_inds + return start_mods, end_mods, start_inds, end_inds def get_lor_coordinates( self, - views: None | npt.ArrayLike = None, - sinogram_order=SinogramSpatialAxisOrder.RVP - ) -> tuple[npt.ArrayLike, npt.ArrayLike]: + views: None | Array = None, + ) -> tuple[Array, Array]: """return the start and end coordinates of all LORs / or a subset of views Parameters ---------- - views : None | npt.ArrayLike, optional + views : None | Array, optional the views to consider, by default None means all views sinogram_order : SinogramSpatialAxisOrder, optional the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP Returns ------- - xstart, xend : npt.ArrayLike + xstart, xend : Array 2 dimensional floating point arrays containing the start and end coordinates of all LORs """ - start_mods, end_mods, start_inds, end_inds = self.get_lor_indices(views, sinogram_order) + start_mods, end_mods, start_inds, end_inds = self.get_lor_indices(views) sinogram_spatial_shape = start_mods.shape - + start_mods = self.xp.reshape(start_mods, (-1, )) start_inds = self.xp.reshape(start_inds, (-1, )) @@ -957,8 +973,8 @@ def get_lor_coordinates( def show_views(self, ax: plt.Axes, - views: npt.ArrayLike, - planes: npt.ArrayLike, + views: Array, + planes: Array, lw: float = 0.2, **kwargs) -> None: """show all LORs of a single view in a given plane @@ -975,8 +991,7 @@ def show_views(self, the line width, by default 0.2 """ - xs, xe = self.get_lor_coordinates( - views=views, sinogram_order=SinogramSpatialAxisOrder.RVP) + xs, xe = self.get_lor_coordinates(views=views) xs = self.xp.reshape(self.xp.take(xs, planes, axis=2), (-1, 3)) xe = self.xp.reshape(self.xp.take(xe, planes, axis=2), (-1, 3)) @@ -1015,8 +1030,8 @@ def __init__(self, lor_descriptor: RegularPolygonPETLORDescriptor, img_shape: tuple[int, int, int], voxel_size: tuple[float, float, float], - img_origin: None | npt.ArrayLike = None, - views: None | npt.ArrayLike = None, + img_origin: None | Array = None, + views: None | Array = None, resolution_model: None | parallelproj.LinearOperator = None, tof: bool = False): """Regular polygon PET projector @@ -1029,10 +1044,10 @@ def __init__(self, shape of the image to be projected voxel_size : tuple[float, float, float] the voxel size of the image to be projected - img_origin : None | npt.ArrayLike, optional - the origin of the image to be projected, by default None + img_origin : None | Array, optional + the origin of the image to be projected, by default None means that image is "centered" in the scanner - views : None | npt.ArrayLike, optional + views : None | Array, optional sinogram views to be projected, by default None means that all views are being projected resolution_model : None | parallelproj.LinearOperator, optional @@ -1068,8 +1083,7 @@ def __init__(self, self._resolution_model = resolution_model - self._xstart, self._xend = lor_descriptor.get_lor_coordinates( - views=self._views, sinogram_order=SinogramSpatialAxisOrder['RVP']) + self._xstart, self._xend = lor_descriptor.get_lor_coordinates(views=self._views) self._tof = tof self._tof_parameters = TOFParameters() @@ -1115,7 +1129,7 @@ def tof_parameters(self, value: TOFParameters) -> None: self._tof_parameters = value @property - def img_origin(self) -> npt.NDArray: + def img_origin(self) -> Array: return self._img_origin def _apply(self, x): @@ -1185,7 +1199,7 @@ def distributed_subset_order(n: int) -> list[int]: Returns ------- list[int] - """ + """ l = [x for x in range(n)] o = [] @@ -1196,4 +1210,3 @@ def distributed_subset_order(n: int) -> list[int]: o.append(l.pop(len(l)//2)) return o - diff --git a/python/utils.py b/python/utils.py new file mode 100644 index 0000000..113d0ab --- /dev/null +++ b/python/utils.py @@ -0,0 +1,79 @@ +import array_api_compat.numpy as np +import parallelproj_utils +from numpy.array_api._array_object import Array +from types import ModuleType + + +def noisy_tof_sinogram_to_lm( + noisy_sinogram: Array, + lor_descriptor: parallelproj_utils.PETLORDescriptor, + xp: ModuleType, + dev: str, +) -> tuple[Array, Array, Array]: + """convert a noisy sinogram to listmode data + + Parameters + ---------- + noisy_sinogram : Array + sinogram containing Poisson noise (integer values) + lor_descriptor : parallelproj_utils.PETLORDescriptor + description of the LOR geometry + xp : ModuleType + array module + dev : str + device + + Returns + ------- + tuple[Array, Array, Array] + event_det_id_1, event_det_id_2, event_tof_bin + """ + if noisy_sinogram.ndim != 4: + raise ValueError( + f"noisy_sinogram must be a 4D array, the last axis must be the TOF axis, but has shape {noisy_sinogram.shape}" + ) + + num_tof_bins = noisy_sinogram.shape[3] + + # ravel the noisy sinogram and the detector start and end "index" sinograms + noisy_sinogram = xp.reshape(noisy_sinogram, (noisy_sinogram.size,)) + + # get the two dimensional indices of all sinogram bins + start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() + + # generate two sinograms that contain the linearized detector start and end indices + sino_det_start_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds + ) + sino_det_end_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds + ) + + sino_det_start_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_start_index], axis=-1), + sino_det_start_index.size * num_tof_bins, + ) + + sino_det_end_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_end_index], axis=-1), + sino_det_end_index.size * num_tof_bins, + ) + + # ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- + # --- convert the index sinograms in to listmode data ------------------------ + # ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- + + # generate listmode data from the noisy sinogram + event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) + # shuffle the event sinogram indices + np.random.shuffle(event_sino_inds) + # convert event sino indices to xp array + event_sino_inds = xp.asarray(event_sino_inds, device=dev) + + event_det_id_1 = xp.take(sino_det_start_index, event_sino_inds) + event_det_id_2 = xp.take(sino_det_end_index, event_sino_inds) + event_tof_bin = event_sino_inds % num_tof_bins + + return event_det_id_1, event_det_id_2, event_tof_bin