diff --git a/.gitignore b/.gitignore index a0232e3..57df96b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ ~$* *~ *.bak + +# user defined things +figs/* diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..2cecbf9 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,3 @@ +*.npy +*.npz +*.prd diff --git a/data/SL.npy b/data/SL.npy new file mode 100644 index 0000000..85b7aa8 Binary files /dev/null and b/data/SL.npy differ diff --git a/environment.yml b/environment.yml index a61c4e7..2c95e6e 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - h5py>=3.7.0 - hdf5>=1.12.1 - howardhinnant_date>=3.0.1 - - ipykernel>=6.19.2 + - ipykernel>=6.19.2 - ninja>=1.11.0 - nlohmann_json>=3.11.2 - numpy>=1.24.3 @@ -18,3 +18,6 @@ dependencies: - shellcheck>=0.8.0 - xtensor-fftw>=0.2.5 - xtensor>=0.24.2 + - parallelproj>=1.6.1 + - pytorch>=2.0 + - array-api-compat>=1.4 diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000..8d221a4 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,2 @@ +*.png +*.npy diff --git a/python/00_simulate_lm_data.py b/python/00_simulate_lm_data.py new file mode 100644 index 0000000..2378170 --- /dev/null +++ b/python/00_simulate_lm_data.py @@ -0,0 +1,223 @@ +from __future__ import annotations +import sys + +sys.path.append("../PETSIRD/python") + +import prd +import parallelproj +import utils +import array_api_compat.numpy as np +import matplotlib.pyplot as plt +from array_api_compat import to_device +from prd_io import write_prd_from_numpy_arrays +from pathlib import Path + +# ---------------------------------------------------------------- +# -- Choose you favorite array backend and device here ----------- +# ---------------------------------------------------------------- + +import numpy.array_api as xp + +dev: str = "cpu" + +# ---------------------------------------------------------------- +# ---------------------------------------------------------------- + +output_dir: str = "../data/sim_LM_acq_1" +output_sens_img_file: str = "sensitivity_image.npz" +output_prd_file: str = "simulated_lm.prd" +expected_num_trues: float = 1e6 + +np.random.seed(42) + +# create the output directory +Path(output_dir).mkdir(exist_ok=True, parents=True) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup the scanner / LOR geometry --------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# setup a line of response descriptor that describes the LOR start / endpoints of +# a "narrow" clinical PET scanner with 9 rings +lor_descriptor = utils.DemoPETScannerLORDescriptor( + xp, dev, num_rings=4, radial_trim=141 +) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a simple 3D test image ------------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# image properties +voxel_size = (2.66, 2.66, 2.66) +num_trans = 128 +num_ax = 2 * lor_descriptor.scanner.num_modules + +# setup a box like test image +img_shape = (num_trans, num_trans, num_ax) +n0, n1, n2 = img_shape + +# setup an image containing a box + +img = xp.asarray( + np.tile(np.load("../data/SL.npy")[..., None], (1, 1, num_ax)), + device=dev, + dtype=xp.float32, +) +img[:, :, :2] = 0 +img[:, :, -2:] = 0 + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a non-TOF projector and project ---------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +res_model = parallelproj.GaussianFilterOperator( + img_shape, sigma=4.5 / (2.355 * xp.asarray(voxel_size)) +) +projector = utils.RegularPolygonPETProjector( + lor_descriptor, img_shape, voxel_size, resolution_model=res_model +) +projector.tof = True # set this to True to get a time of flight projector + +# forward project the image +noise_free_sinogram = projector(img) + +# rescale the forward projection and image such that we get the expected number of trues +scale = expected_num_trues / float(xp.sum(noise_free_sinogram)) +noise_free_sinogram *= scale +img *= scale + +# calculate the sensitivity image +sens_img = projector.adjoint( + xp.ones(noise_free_sinogram.shape, device=dev, dtype=xp.float32) +) + +# add poisson noise to the noise free sinogram +noisy_sinogram = xp.asarray( + np.random.poisson(np.asarray(to_device(noise_free_sinogram, "cpu"))), device=dev +) +# ravel the noisy sinogram and the detector start and end "index" sinograms +noisy_sinogram = xp.reshape(noisy_sinogram, (noisy_sinogram.size,)) + +# get the two dimensional indices of all sinogram bins +start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() + +# generate two sinograms that contain the linearized detector start and end indices +sino_det_start_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds +) +sino_det_end_index = ( + lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds +) + +# repeat number of TOF bin times here +num_tof_bins = projector.tof_parameters.num_tofbins + +sino_det_start_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_start_index], axis=-1), + sino_det_start_index.size * num_tof_bins, +) + +sino_det_end_index = xp.reshape( + xp.stack(num_tof_bins * [sino_det_end_index], axis=-1), + sino_det_end_index.size * num_tof_bins, +) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- convert the index sinograms in to listmode data ------------------------ +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# generate listmode data from the noisy sinogram +event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) +# shuffle the event sinogram indices +np.random.shuffle(event_sino_inds) +# convert event sino indices to xp array +event_sino_inds = xp.asarray(event_sino_inds, device=dev) + +event_det_id_1 = xp.take(sino_det_start_index, event_sino_inds) +event_det_id_2 = xp.take(sino_det_end_index, event_sino_inds) +event_tof_bin = event_sino_inds % num_tof_bins + +print(f"number of simulated events: {event_det_id_1.shape[0]}") + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# ---- convert LM detector ID arrays into PRD here --------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# get a lookup table that contains the world coordinates of all scanner detectors +# this is a 2D array of shape (num_detectors, 3) +scanner_lut = lor_descriptor.scanner.all_lor_endpoints + +# generate a list of detector coordinates of our scanner +detector_list = [] +for i in range(scanner_lut.shape[0]): + detector_list.append( + prd.Detector( + id=int(i), + x=float(scanner_lut[i, 0]), + y=float(scanner_lut[i, 1]), + z=float(scanner_lut[i, 2]), + ) + ) + +# setup the edges of all TOF bins +tof_bin_edges = ( + xp.arange(num_tof_bins + 1, dtype=xp.float32) - ((num_tof_bins + 1) / 2 - 0.5) +) * projector.tof_parameters.tofbin_width + +# setup the scanner information containing detector and TOF information +# WARNING: DEFINITION OF TOF RESOLUTION (sigma vs FWHM) not clear yet +scanner_information = prd.ScannerInformation( + model_name="DummyPET", + detectors=detector_list, + tof_bin_edges=np.asarray(to_device(tof_bin_edges, "cpu")), + tof_resolution=projector.tof_parameters.sigma_tof, +) + +# write the data to PETSIRD +write_prd_from_numpy_arrays( + event_det_id_1, + event_det_id_2, + scanner_information, + tof_idx_array=event_tof_bin, + output_file=str(Path(output_dir) / output_prd_file), +) +print(f"wrote PETSIRD LM file to {str(Path(output_dir) / output_prd_file)}") + +# HACK: write the sensitivity image to file +# this is currently needed since it is not agreed on how to store +# all valid detector pair combinations + attn / sens values in the PRD file +np.savez( + Path(output_dir) / output_sens_img_file, + sens_img=np.asarray(to_device(sens_img, "cpu")), + voxel_size=np.asarray(voxel_size), + img_origin=np.asarray(to_device(projector.img_origin, "cpu")), +) +print(f"wrote sensitivity image to {str(Path(output_dir) / output_sens_img_file)}") + +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + +fig_dir = Path("../figs") +fig_dir.mkdir(exist_ok=True) + +vmax = 1.2 * xp.max(img) +fig, ax = plt.subplots(1, img.shape[2], figsize=(img.shape[2] * 2, 2)) +for i in range(img.shape[2]): + ax[i].imshow( + xp.asarray(to_device(img[:, :, i], "cpu")), vmin=0, vmax=vmax, cmap="Greys" + ) + ax[i].set_title(f"ground truth sl {i+1}", fontsize="small") + +fig.tight_layout() +fig.savefig(fig_dir / "simulated_phantom.png") diff --git a/python/01_reconstruct_lm_data.py b/python/01_reconstruct_lm_data.py new file mode 100644 index 0000000..d447255 --- /dev/null +++ b/python/01_reconstruct_lm_data.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import parallelproj +from array_api_compat import to_device +import matplotlib.pyplot as plt +from prd_io import read_prd_to_numpy_arrays +from pathlib import Path + +import array_api_compat.numpy as np + +# ---------------------------------------------------------------- +# -- Choose you favorite array backend and device here -i--------- +# ---------------------------------------------------------------- + +import numpy.array_api as xp + +dev = "cpu" + +# ---------------------------------------------------------------- +# ---------------------------------------------------------------- + +lm_data_dir: str = "../data/sim_LM_acq_1" +sens_img_file: str = "sensitivity_image.npz" +prd_file: str = "simulated_lm.prd" + +num_iter: int = 2 +num_subsets: int = 20 + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# read the LM file header and all event attributes +header, event_attributes = read_prd_to_numpy_arrays( + str(Path(lm_data_dir) / prd_file), xp, dev, read_tof=None, read_energy=False +) + +# read the detector coordinates into a 2D lookup table +scanner_lut = xp.asarray( + [[det.x, det.y, det.z] for det in header.scanner.detectors], + dtype=xp.float32, + device=dev, +) + +xstart = xp.take(scanner_lut, event_attributes[:, 0], axis=0) +xend = xp.take(scanner_lut, event_attributes[:, 1], axis=0) + +# check if we have TOF data and generate the corresponding TOF parameters we need for the +# TOF joseph projector +if event_attributes.shape[1] == 3: + tof = True + event_tof_bin = event_attributes[:, 2] + num_tof_bins = header.scanner.tof_bin_edges.shape[0] - 1 + tofbin_width = header.scanner.tof_bin_edges[1] - header.scanner.tof_bin_edges[0] + sigma_tof = xp.asarray([header.scanner.tof_resolution], dtype=xp.float32) + tofcenter_offset = xp.asarray([0], dtype=xp.float32) + nsigmas = 3.0 + print(f"read {event_attributes.shape[0]} TOF events") +else: + tof = False + print(f"read {event_attributes.shape[0]} non-TOF events") + +# HACK: write the sensitivity image to file +# this is currently needed since it is not agreed on how to store +# all valid detector pair combinations + attn / sens values in the PRD file +sens_img_data = np.load(Path(lm_data_dir) / sens_img_file) +sens_img = xp.asarray(sens_img_data["sens_img"], device=dev) +img_shape = sens_img.shape +voxel_size = xp.asarray(sens_img_data["voxel_size"], device=dev) +img_origin = xp.asarray(sens_img_data["img_origin"], device=dev) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# ---- LM recon using the event detector IDs and the scanner LUT ------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +res_model = parallelproj.GaussianFilterOperator( + img_shape, sigma=4.5 / (2.355 * xp.asarray(voxel_size)) +) + +recon = xp.ones(img_shape, dtype=xp.float32, device=dev) + +for it in range(num_iter): + for isub in range(num_subsets): + print(f"it {(it+1):03} / ss {(isub+1):03}", end="\r") + xs_sub = xstart[isub::num_subsets, :] + xe_sub = xend[isub::num_subsets, :] + + recon_sm = res_model(recon) + + if tof: + event_tof_bin_sub = event_tof_bin[isub::num_subsets] - num_tof_bins // 2 + exp = parallelproj.joseph3d_fwd_tof_lm( + xs_sub, + xe_sub, + recon, + img_origin, + voxel_size, + tofbin_width, + sigma_tof, + tofcenter_offset, + nsigmas, + event_tof_bin_sub, + ) + + ratio_back = parallelproj.joseph3d_back_tof_lm( + xs_sub, + xe_sub, + img_shape, + img_origin, + voxel_size, + 1 / exp, + tofbin_width, + sigma_tof, + tofcenter_offset, + nsigmas, + event_tof_bin_sub, + ) + else: + exp = parallelproj.joseph3d_fwd( + xs_sub, xe_sub, recon_sm, img_origin, voxel_size + ) + ratio_back = parallelproj.joseph3d_back( + xs_sub, xe_sub, img_shape, img_origin, voxel_size, 1 / exp + ) + + ratio_back_sm = res_model.adjoint(ratio_back) + + recon *= ratio_back_sm / (sens_img / num_subsets) + + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +fig_dir = Path("../figs") +fig_dir.mkdir(exist_ok=True) + +vmax = 0.055 +fig, ax = plt.subplots(1, recon.shape[2], figsize=(recon.shape[2] * 2, 2)) +for i in range(recon.shape[2]): + ax[i].imshow( + xp.asarray(to_device(recon[:, :, i], "cpu")), vmin=0, vmax=vmax, cmap="Greys" + ) + ax[i].set_title(f"LM recon sl {i+1}", fontsize="small") + +fig.tight_layout() +fig.savefig(fig_dir / "lm_reconstruction.png") diff --git a/python/parallelproj_sim.py b/python/parallelproj_sim.py new file mode 100644 index 0000000..e04777a --- /dev/null +++ b/python/parallelproj_sim.py @@ -0,0 +1,216 @@ +#TODO: - additive MLEM + +from __future__ import annotations + +import parallelproj +import utils +import array_api_compat.numpy as np +import matplotlib.pyplot as plt +from array_api_compat import to_device +from scipy.ndimage import gaussian_filter + +# device variable (cpu or cuda) that determines whether calculations +# are performed on the cpu or cuda gpu + +dev = "cpu" +expected_num_trues = 1e6 +num_iter = 3 +num_subsets = 20 +np.random.seed(1) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup the scanner / LOR geometry --------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# setup a line of response descriptor that describes the LOR start / endpoints of +# a "narrow" clinical PET scanner with 9 rings +lor_descriptor = utils.DemoPETScannerLORDescriptor( + np, dev, num_rings=4, radial_trim=141 +) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a simple 3D test image ------------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# image properties +voxel_size = (2.66, 2.66, 2.66) +num_trans = 140 +num_ax = 2 * lor_descriptor.scanner.num_modules + +# setup a box like test image +img_shape = (num_trans, num_trans, num_ax) +n0, n1, n2 = img_shape + +# setup an image containing a box +img = np.zeros(img_shape, dtype=np.float32, device=dev) +img[(n0 // 4) : (3 * n0 // 4), (n1 // 4) : (3 * n1 // 4), 2:-2] = 1 +img[(7*n0 // 16) : (9 * n0 // 16), (6*n1 // 16) : (8 * n1 // 16), 2:-2] = 2. + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a non-TOF projector and project ---------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# setup a simple image-based resolution model using 4.5mm FWHM Gaussian smoothing +res_model = parallelproj.GaussianFilterOperator(img_shape, sigma=4.5 / (2.355 * np.asarray(voxel_size))) +projector = utils.RegularPolygonPETProjector(lor_descriptor, img_shape, voxel_size, resolution_model=res_model) +projector.tof = False # set this to True to get a time of flight projector + +# forward project the image +noise_free_sinogram = projector(img) + +# rescale the forward projection and image such that we get the expected number of trues +scale = expected_num_trues / np.sum(noise_free_sinogram) +noise_free_sinogram *= scale +img *= scale + +# calculate the sensitivity image +sens_img = projector.adjoint(np.ones(noise_free_sinogram.shape, device=dev, dtype=np.float32)) + +# get the two dimensional indices of all sinogram bins +start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() + +# generate two sinograms that contain the linearized detector start and end indices +sino_det_start_index = lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds +sino_det_end_index = lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds + +# add poisson noise to the noise free sinogram +noisy_sinogram = np.random.poisson(noise_free_sinogram) + +# ravel the noisy sinogram and the detector start and end "index" sinograms +noisy_sinogram = noisy_sinogram.ravel() +sino_det_start_index = sino_det_start_index.ravel() +sino_det_end_index = sino_det_end_index.ravel() + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- convert the index sinograms in to listmode data ------------------------ +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# generate listmode data from the noisy sinogram +event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) + +## generate timestamps +#acquisition_length = 20 # say, in mins +#timestamps_in_lors = np.array([np.sort(np.random.uniform(0, acquisition_length, +# size=noisy_sinogram[l])) for l in range(len(noisy_sinogram))]) + +# shuffle the event sinogram indices +np.random.shuffle(event_sino_inds) + +## assign timestamps for events - need one forward run over event_sino_inds +#timestamps_iter_table = np.zeros_like(noisy_sinogram, dtype=np.int64) # possibly many counts +#timestamps_in_events = np.zeros_like(noisy_sinogram, dtype=np.float32) +#for ind in event_sino_inds: # sorry, this is slow and ugly +# timestamps_in_events[ind] = timestamps_in_lors[ind, timestamps_iter_table[ind]] +# timestamps_iter_table[ind] += 1 + +## at this stage - lors are shuffled but timestamps are out of sequential order +## need to sort globally event_sino_inds according to timestamps +#evend_sino_inds = event_sino_inds[np.argsort(timestamps_in_events)] + + +event_det_id_1 = sino_det_start_index[event_sino_inds] +event_det_id_2 = sino_det_end_index[event_sino_inds] + +print(f'number of events: {event_det_id_1.shape[0]}') + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- convert LM detector ID arrays into PRD here --------------------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +# get a lookup table that contains the world coordinates of all scanner detectors +# this is a 2D array of shape (num_detectors, 3) +scanner_lut = lor_descriptor.scanner.all_lor_endpoints + +# +# +# +# +# +# + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- read events back from PRD here ---------------------------------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +# +# +# +# +# +# + +# hack until we have the reader / writer implemented +xstart = scanner_lut[event_det_id_1, :] +xend = scanner_lut[event_det_id_2, :] + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- LM recon using the event detector IDs and the scanner LUT ------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +recon = np.ones(img_shape, dtype=np.float32, device=dev) + +for it in range(num_iter): + for isub in range(num_subsets): + print(f'it {(it+1):03} / ss {(isub+1):03}', end='\r') + xs_sub = xstart[isub::num_subsets,:] + xe_sub = xend[isub::num_subsets,:] + + recon_sm = res_model(recon) + + exp = parallelproj.joseph3d_fwd(xs_sub, xe_sub, recon_sm, projector.img_origin, voxel_size) + ratio_back = parallelproj.joseph3d_back(xs_sub, xe_sub, img_shape, projector.img_origin, voxel_size, 1/exp) + + ratio_back_sm = res_model.adjoint(ratio_back) + + recon *= (ratio_back_sm / (sens_img / num_subsets)) + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +# show the scanner geometry and one view in one sinogram plane +fig = plt.figure(figsize=(8, 8)) +ax = fig.add_subplot(projection="3d") +lor_descriptor.scanner.show_lor_endpoints(ax, show_linear_index=True, annotation_fontsize=4) + +# plot the LORs of the first n events +for i, col in enumerate(('red','green','blue')): + xs = scanner_lut[event_det_id_1[i], :] + xe= scanner_lut[event_det_id_2[i], :] + + ax.plot( + [xs[0], xe[0]], + [xs[1], xe[1]], + [xs[2], xe[2]], + color=col, + linewidth=1., + ) + +fig.tight_layout() +fig.show() + +vmax = 1.2*img.max() +fig2, ax2 = plt.subplots(3, recon.shape[2], figsize=(recon.shape[2] * 2, 3 * 2)) +for i in range(recon.shape[2]): + ax2[0,i].imshow(np.asarray(to_device(img[:, :, i], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys') + ax2[1,i].imshow(np.asarray(to_device(recon[:, :, i], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys') + ax2[2,i].imshow(gaussian_filter(np.asarray(to_device(recon[:, :, i], "cpu")), 1.5), vmin = 0, vmax = vmax, cmap = 'Greys') + + ax2[0,i].set_title(f'ground truth sl {i+1}', fontsize = 'small') + ax2[1,i].set_title(f'LM recon {i+1}', fontsize = 'small') + ax2[2,i].set_title(f'LM recon smoothed {i+1}', fontsize = 'small') + +fig2.tight_layout() +fig2.show() diff --git a/python/parallelproj_sim_tof.py b/python/parallelproj_sim_tof.py new file mode 100644 index 0000000..9b8d891 --- /dev/null +++ b/python/parallelproj_sim_tof.py @@ -0,0 +1,232 @@ +#TODO: - additive MLEM + +from __future__ import annotations + +import parallelproj +import utils +import array_api_compat.numpy as np +import matplotlib.pyplot as plt +from array_api_compat import to_device +from scipy.ndimage import gaussian_filter + +# device variable (cpu or cuda) that determines whether calculations +# are performed on the cpu or cuda gpu + +dev = "cpu" +expected_num_trues = 1e6 +num_iter = 3 +num_subsets = 20 +np.random.seed(1) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup the scanner / LOR geometry --------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# setup a line of response descriptor that describes the LOR start / endpoints of +# a "narrow" clinical PET scanner with 9 rings +lor_descriptor = utils.DemoPETScannerLORDescriptor( + np, dev, num_rings=4, radial_trim=141 +) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a simple 3D test image ------------------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# image properties +voxel_size = (2.66, 2.66, 2.66) +num_trans = 140 +num_ax = 2 * lor_descriptor.scanner.num_modules + +# setup a box like test image +img_shape = (num_trans, num_trans, num_ax) +n0, n1, n2 = img_shape + +# setup an image containing a box +img = np.zeros(img_shape, dtype=np.float32, device=dev) +img[(n0 // 4) : (3 * n0 // 4), (n1 // 4) : (3 * n1 // 4), 2:-2] = 1 +img[(7*n0 // 16) : (9 * n0 // 16), (6*n1 // 16) : (8 * n1 // 16), 2:-2] = 2. + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- setup a non-TOF projector and project ---------------------------------- +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# setup a simple image-based resolution model using 4.5mm FWHM Gaussian smoothing +res_model = parallelproj.GaussianFilterOperator(img_shape, sigma=4.5 / (2.355 * np.asarray(voxel_size))) +projector = utils.RegularPolygonPETProjector(lor_descriptor, img_shape, voxel_size, resolution_model=res_model) +projector.tof = True # set this to True to get a time of flight projector + +# forward project the image +noise_free_sinogram = projector(img) + +# rescale the forward projection and image such that we get the expected number of trues +scale = expected_num_trues / np.sum(noise_free_sinogram) +noise_free_sinogram *= scale +img *= scale +# +# calculate the sensitivity image +sens_img = projector.adjoint(np.ones(noise_free_sinogram.shape, device=dev, dtype=np.float32)) + +# get the two dimensional indices of all sinogram bins +start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices() + +# generate two sinograms that contain the linearized detector start and end indices +sino_det_start_index = lor_descriptor.scanner.num_lor_endpoints_per_module[0] * start_mods + start_inds +sino_det_end_index = lor_descriptor.scanner.num_lor_endpoints_per_module[0] * end_mods + end_inds + +# add poisson noise to the noise free sinogram +noisy_sinogram = np.random.poisson(noise_free_sinogram) + +## ravel the noisy sinogram and the detector start and end "index" sinograms +noisy_sinogram = noisy_sinogram.ravel() + +# repeat number of TOF bin times here +num_tof_bins = projector.tof_parameters.num_tofbins +sino_det_start_index = np.repeat(sino_det_start_index.ravel(), num_tof_bins) +sino_det_end_index = np.repeat(sino_det_end_index.ravel(), num_tof_bins) + +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- +# --- convert the index sinograms in to listmode data ------------------------ +# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- + +# generate listmode data from the noisy sinogram +event_sino_inds = np.repeat(np.arange(noisy_sinogram.shape[0]), noisy_sinogram) + +## generate timestamps +#acquisition_length = 20 # say, in mins +#timestamps_in_lors = np.array([np.sort(np.random.uniform(0, acquisition_length, +# size=noisy_sinogram[l])) for l in range(len(noisy_sinogram))]) + +# shuffle the event sinogram indices +np.random.shuffle(event_sino_inds) + +## assign timestamps for events - need one forward run over event_sino_inds +#timestamps_iter_table = np.zeros_like(noisy_sinogram, dtype=np.int64) # possibly many counts +#timestamps_in_events = np.zeros_like(noisy_sinogram, dtype=np.float32) +#for ind in event_sino_inds: # sorry, this is slow and ugly +# timestamps_in_events[ind] = timestamps_in_lors[ind, timestamps_iter_table[ind]] +# timestamps_iter_table[ind] += 1 + +## at this stage - lors are shuffled but timestamps are out of sequential order +## need to sort globally event_sino_inds according to timestamps +#evend_sino_inds = event_sino_inds[np.argsort(timestamps_in_events)] + +event_det_id_1 = sino_det_start_index[event_sino_inds] +event_det_id_2 = sino_det_end_index[event_sino_inds] +event_tof_bin = event_sino_inds % num_tof_bins + +print(f'number of events: {event_det_id_1.shape[0]}') + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- convert LM detector ID arrays into PRD here --------------------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +# get a lookup table that contains the world coordinates of all scanner detectors +# this is a 2D array of shape (num_detectors, 3) +scanner_lut = lor_descriptor.scanner.all_lor_endpoints + +# +# +# +# +# +# + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- read events back from PRD here ---------------------------------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +# +# +# +# +# +# + +# hack until we have the reader / writer implemented +xstart = scanner_lut[event_det_id_1, :] +xend = scanner_lut[event_det_id_2, :] +event_tof_bin = event_tof_bin +# tof bin width in mm +tofbin_width = projector.tof_parameters.tofbin_width +# sigma of the Gaussian TOF kernel in mm +sigma_tof = np.array([projector.tof_parameters.sigma_tof], dtype = np.float32) +tofcenter_offset = np.array([projector.tof_parameters.tofcenter_offset], dtype = np.float32) +nsigmas = projector.tof_parameters.num_sigmas + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +#---- LM recon using the event detector IDs and the scanner LUT ------------- +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- + +recon = np.ones(img_shape, dtype=np.float32, device=dev) + +for it in range(num_iter): + for isub in range(num_subsets): + print(f'it {(it+1):03} / ss {(isub+1):03}', end='\r') + xs_sub = xstart[isub::num_subsets,:] + xe_sub = xend[isub::num_subsets,:] + # in parallelproj the "0" TOF bin is the central TOF bin + # in PETSIRD the TOFbin number is non-negative + event_tof_bin_sub = event_tof_bin[isub::num_subsets] - num_tof_bins // 2 + + recon_sm = res_model(recon) + + exp = parallelproj.joseph3d_fwd_tof_lm(xs_sub, xe_sub, recon, projector.img_origin, voxel_size, tofbin_width, + sigma_tof, tofcenter_offset, nsigmas, event_tof_bin_sub) + + ratio_back = parallelproj.joseph3d_back_tof_lm(xs_sub, xe_sub, img_shape, projector.img_origin, + voxel_size, 1/exp, tofbin_width, sigma_tof, tofcenter_offset, nsigmas, event_tof_bin_sub) + + ratio_back_sm = res_model.adjoint(ratio_back) + + recon *= (ratio_back_sm / (sens_img / num_subsets)) + +#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- +# show the scanner geometry and one view in one sinogram plane +fig = plt.figure(figsize=(8, 8)) +ax = fig.add_subplot(projection="3d") +lor_descriptor.scanner.show_lor_endpoints(ax, show_linear_index=True, annotation_fontsize=4) + +# plot the LORs of the first n events +for i, col in enumerate(('red','green','blue')): + xs = scanner_lut[event_det_id_1[i], :] + xe= scanner_lut[event_det_id_2[i], :] + + ax.plot( + [xs[0], xe[0]], + [xs[1], xe[1]], + [xs[2], xe[2]], + color=col, + linewidth=1., + ) + +fig.tight_layout() +fig.show() + +vmax = 1.2*img.max() +fig2, ax2 = plt.subplots(3, recon.shape[2], figsize=(recon.shape[2] * 2, 3 * 2)) +for i in range(recon.shape[2]): + ax2[0,i].imshow(np.asarray(to_device(img[:, :, i], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys') + ax2[1,i].imshow(np.asarray(to_device(recon[:, :, i], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys') + ax2[2,i].imshow(gaussian_filter(np.asarray(to_device(recon[:, :, i], "cpu")), 1.5), vmin = 0, vmax = vmax, cmap = 'Greys') + + ax2[0,i].set_title(f'ground truth sl {i+1}', fontsize = 'small') + ax2[1,i].set_title(f'LM recon {i+1}', fontsize = 'small') + ax2[2,i].set_title(f'LM recon smoothed {i+1}', fontsize = 'small') + +fig2.tight_layout() +fig2.show() diff --git a/python/prd_io.py b/python/prd_io.py new file mode 100644 index 0000000..3836ddf --- /dev/null +++ b/python/prd_io.py @@ -0,0 +1,180 @@ +from __future__ import annotations +import sys + +sys.path.append("../PETSIRD/python") +import prd + +from numpy.array_api._array_object import Array +from types import ModuleType + + +def write_prd_from_numpy_arrays( + detector_1_id_array: Array, + detector_2_id_array: Array, + scanner_information: prd.ScannerInformation, + tof_idx_array: Array | None = None, + energy_1_idx_array: Array | None = None, + energy_2_idx_array: Array | None = None, + output_file: str | None = None, +) -> None: + """Write a PRD file from numpy arrays. Currently all into one time block + + Parameters + ---------- + detector_1_id_array : Array + array containing the detector 1 id for each event + detector_2_id_array : Array + array containing the detector 2 id for each event + scanner_information : prd.ScannerInformation + description of the scanner according to PETSIRD + (e.g. including all detector coordinates) + tof_idx_array : Array | None, optional + array containing the tof bin index of each event + energy_1_idx_array : Array | None, optional + array containing the energy 1 index of each event + energy_2_idx_array : Array | None, optional + array containing the energy 2 index of each event + output_file : str | None, optional + output file, if None write to stdout + """ + + num_events: int = detector_1_id_array.size + + events = [] + for i in range(num_events): + det_id_1 = int(detector_1_id_array[i]) + det_id_2 = int(detector_2_id_array[i]) + + if tof_idx_array is not None: + tof_idx = int(tof_idx_array[i]) + else: + tof_idx = 0 + + if energy_1_idx_array is not None: + energy_1_idx = int(energy_1_idx_array[i]) + else: + energy_1_idx = 0 + + if energy_2_idx_array is not None: + energy_2_idx = int(energy_2_idx_array[i]) + else: + energy_2_idx = 0 + + events.append( + prd.CoincidenceEvent( + detector_1_id=det_id_1, + detector_2_id=det_id_2, + tof_idx=tof_idx, + energy_1_idx=energy_1_idx, + energy_2_idx=energy_2_idx, + ) + ) + + time_block = prd.TimeBlock(id=0, prompt_events=events) + + if output_file is None: + with prd.BinaryPrdExperimentWriter(sys.stdout.buffer) as writer: + writer.write_header(prd.Header(scanner=scanner_information)) + writer.write_time_blocks((time_block,)) + else: + if output_file.endswith(".ndjson"): + with prd.NDJsonPrdExperimentWriter(output_file) as writer: + writer.write_header(prd.Header(scanner=scanner_information)) + writer.write_time_blocks((time_block,)) + else: + with prd.BinaryPrdExperimentWriter(output_file) as writer: + writer.write_header(prd.Header(scanner=scanner_information)) + writer.write_time_blocks((time_block,)) + + +def read_prd_to_numpy_arrays( + prd_file: str, + xp: ModuleType, + dev: str, + read_tof: bool | None = None, + read_energy: bool | None = None, +) -> tuple[prd.types.Header, Array]: + """Read all time blocks of a PETSIRD listmode file + + Parameters + ---------- + prd_file : str + the PETSIRD listmode file + xp : ModuleType + the array backend module + dev : str + device used for the returned arrays + read_tof : bool | None, optional + read the TOF bin information of every event + default None means that is is auto determined + based on the scanner information (length of tof bin edges) + read_energy : bool | None, optional + read the energy information of every event + default None means that is is auto determined + based on the scanner information (length of energy bin edges) + + Returns + ------- + tuple[prd.types.Header, Array] + PRD listmode file header, 2D array containing all event attributes + """ + with prd.BinaryPrdExperimentReader(prd_file) as reader: + # Read header and build lookup table + header = reader.read_header() + + # bool that decides whether the scanner has TOF and whether it is + # meaningful to read TOF + if read_tof is None: + r_tof: bool = len(header.scanner.tof_bin_edges) > 1 + else: + r_tof = read_tof + + # bool that decides whether the scanner has energy and whether it is + # meaningful to read energy + if read_energy is None: + r_energy: bool = len(header.scanner.energy_bin_edges) > 1 + else: + r_energy = read_energy + + # loop over all time blocks and read all meaningful event attributes + for time_block in reader.read_time_blocks(): + if r_tof and r_energy: + event_attribute_list = [ + [ + e.detector_1_id, + e.detector_2_id, + e.tof_idx, + e.energy_1_idx, + e.energy_2_idx, + ] + for e in time_block.prompt_events + ] + elif r_tof and (not r_energy): + event_attribute_list = [ + [ + e.detector_1_id, + e.detector_2_id, + e.tof_idx, + ] + for e in time_block.prompt_events + ] + elif (not r_tof) and r_energy: + event_attribute_list = [ + [ + e.detector_1_id, + e.detector_2_id, + e.energy_1_idx, + e.energy_2_idx, + ] + for e in time_block.prompt_events + ] + else: + event_attribute_list = [ + [ + e.detector_1_id, + e.detector_2_id, + ] + for e in time_block.prompt_events + ] + + return header, xp.asarray(event_attribute_list, device=dev) diff --git a/python/utils.py b/python/utils.py new file mode 100644 index 0000000..1953103 --- /dev/null +++ b/python/utils.py @@ -0,0 +1,1199 @@ +from __future__ import annotations + +import enum +import abc +from dataclasses import dataclass +import array_api_compat.numpy as np +import numpy.typing as npt +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d.art3d import Line3DCollection + +from types import ModuleType +from array_api_compat import device, to_device, size + +import parallelproj + + +@dataclass +class TOFParameters: + """ + generic time of flight (TOF) parameters for a scanner with 385ps FWHM TOF + + num_tofbins: int + number of time of flight bins + tofbin_width: float + width of the TOF bin in spatial units (mm) + sigma_tof: float + standard deviation of Gaussian TOF kernel in spatial units (mm) + num_sigmas: float + number of sigmas after which TOF kernel is truncated + tofcenter_offset: float + offset of center of central TOF bin from LOR center in spatial units (mm) + """ + num_tofbins: int = 29 + tofbin_width: float = 13 * 0.01302 * 299.792 / 2 # 13 TOF "small" TOF bins of 0.01302[ns] * (speed of light / 2) [mm/ns] + sigma_tof: float = (299.792 / 2) * ( + 0.385 / 2.355) # (speed_of_light [mm/ns] / 2) * TOF FWHM [ns] / 2.355 + num_sigmas: float = 3. + tofcenter_offset: float = 0 + + +class SinogramSpatialAxisOrder(enum.Enum): + """order of spatial axis in a sinogram R (radial), V (view), P (plane)""" + + RVP = enum.auto() + """[radial,view,plane]""" + RPV = enum.auto() + """[radial,plane,view]""" + VRP = enum.auto() + """[view,radial,plane]""" + VPR = enum.auto() + """[view,plane,radial]""" + PRV = enum.auto() + """[plane,radial,view]""" + PVR = enum.auto() + """[plane,view,radial]""" + + +class PETScannerModule(abc.ABC): + + def __init__( + self, + xp: ModuleType, + dev: str, + num_lor_endpoints: int, + affine_transformation_matrix: npt.NDArray | None = None) -> None: + """abstract base class for PET scanner module + + Parameters + ---------- + xp: ModuleType + array module to use for storing the LOR endpoints + dev: str + device to use for storing the LOR endpoints + num_lor_endpoints : int + number of LOR endpoints in the module + affine_transformation_matrix : npt.NDArray | None, optional + 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None + if None, the 4x4 identity matrix is used + """ + + self._xp = xp + self._dev = dev + self._num_lor_endpoints = num_lor_endpoints + self._lor_endpoint_numbers = self.xp.arange(num_lor_endpoints, + device=self.dev) + + if affine_transformation_matrix is None: + self._affine_transformation_matrix = self.xp.eye(4, + device=self.dev) + else: + self._affine_transformation_matrix = affine_transformation_matrix + + @property + def xp(self) -> ModuleType: + """array module to use for storing the LOR endpoints""" + return self._xp + + @property + def dev(self) -> str: + """device to use for storing the LOR endpoints""" + return self._dev + + @property + def num_lor_endpoints(self) -> int: + """total number of LOR endpoints in the module + + Returns + ------- + int + """ + return self._num_lor_endpoints + + @property + def lor_endpoint_numbers(self) -> npt.NDArray: + """array enumerating all the LOR endpoints in the module + + Returns + ------- + npt.NDArray + """ + return self._lor_endpoint_numbers + + @property + def affine_transformation_matrix(self) -> npt.NDArray: + """4x4 affine transformation matrix + + Returns + ------- + npt.NDArray + """ + return self._affine_transformation_matrix + + @abc.abstractmethod + def get_raw_lor_endpoints(self, + inds: npt.NDArray | None = None) -> npt.NDArray: + """mapping from LOR endpoint indices within module to an array of "raw" world coordinates + + Parameters + ---------- + inds : npt.NDArray | None, optional + an non-negative integer array of indices, default None + if None means all possible indices [0, ... , num_lor_endpoints - 1] + + Returns + ------- + npt.NDArray + a 3 x len(inds) float array with the world coordinates of the LOR endpoints + """ + if inds is None: + inds = self.lor_endpoint_numbers + raise NotImplementedError + + def get_lor_endpoints(self, + inds: npt.NDArray | None = None) -> npt.NDArray: + """mapping from LOR endpoint indices within module to an array of "transformed" world coordinates + + Parameters + ---------- + inds : npt.NDArray | None, optional + an non-negative integer array of indices, default None + if None means all possible indices [0, ... , num_lor_endpoints - 1] + + Returns + ------- + npt.NDArray + a 3 x len(inds) float array with the world coordinates of the LOR endpoints including an affine transformation + """ + + raw_lor_endpoints = self.get_raw_lor_endpoints(inds) + + tmp = self.xp.ones((raw_lor_endpoints.shape[0], 4), device=self.dev) + tmp[:, :-1] = raw_lor_endpoints + + return (tmp @ self.affine_transformation_matrix.T)[:, :3] + + def show_lor_endpoints(self, + ax: plt.Axes, + annotation_fontsize: float = 0, + annotation_prefix: str = '', + annotation_offset: int = 0, + transformed: bool = True, + **kwargs) -> None: + """show the LOR coordinates in a 3D scatter plot + + Parameters + ---------- + ax : plt.Axes + 3D matplotlib axes + annotation_fontsize : float, optional + fontsize of LOR endpoint number annotation, by default 0 + annotation_prefix : str, optional + prefix for annotation, by default '' + annotation_offset : int, optional + number to add to crystal number, by default 0 + transformed : bool, optional + use transformed instead of raw coordinates, by default True + """ + + if transformed: + all_lor_endpoints = self.get_lor_endpoints() + else: + all_lor_endpoints = self.get_raw_lor_endpoints() + + # convert to numpy array + all_lor_endpoints = np.asarray(to_device(all_lor_endpoints, 'cpu')) + + ax.scatter(all_lor_endpoints[:, 0], all_lor_endpoints[:, 1], + all_lor_endpoints[:, 2], **kwargs) + + ax.set_box_aspect([ + ub - lb for lb, ub in (getattr(ax, f'get_{a}lim')() for a in 'xyz') + ]) + + ax.set_xlabel('x0') + ax.set_ylabel('x1') + ax.set_zlabel('x2') + + if annotation_fontsize > 0: + for i in self.lor_endpoint_numbers: + ax.text(all_lor_endpoints[i, 0], + all_lor_endpoints[i, 1], + all_lor_endpoints[i, 2], + f'{annotation_prefix}{i+annotation_offset}', + fontsize=annotation_fontsize) + + +class RegularPolygonPETScannerModule(PETScannerModule): + + def __init__( + self, + xp: ModuleType, + dev: str, + radius: float, + num_sides: int, + num_lor_endpoints_per_side: int, + lor_spacing: float, + ax0: int = 2, + ax1: int = 1, + affine_transformation_matrix: npt.NDArray | None = None) -> None: + """regular Polygon PET scanner module + + Parameters + ---------- + xp: ModuleType + array module to use for storing the LOR endpoints + device: str + device to use for storing the LOR endpoints + radius : float + inner radius of the regular polygon + num_sides: int + number of sides of the regular polygon + num_lor_endpoints_per_sides: int + number of LOR endpoints per side + lor_spacing : float + spacing between the LOR endpoints in the polygon direction + ax0 : int, optional + axis number for the first direction, by default 2 + ax1 : int, optional + axis number for the second direction, by default 1 + affine_transformation_matrix : npt.NDArray | None, optional + 4x4 affine transformation matrix applied to the LOR endpoint coordinates, default None + if None, the 4x4 identity matrix is used + """ + + self._radius = radius + self._num_sides = num_sides + self._num_lor_endpoints_per_side = num_lor_endpoints_per_side + self._ax0 = ax0 + self._ax1 = ax1 + self._lor_spacing = lor_spacing + super().__init__(xp, dev, num_sides * num_lor_endpoints_per_side, + affine_transformation_matrix) + + @property + def radius(self) -> float: + """inner radius of the regular polygon + + Returns + ------- + float + """ + return self._radius + + @property + def num_sides(self) -> int: + """number of sides of the regular polygon + + Returns + ------- + int + """ + return self._num_sides + + @property + def num_lor_endpoints_per_side(self) -> int: + """number of LOR endpoints per side + + Returns + ------- + int + """ + return self._num_lor_endpoints_per_side + + @property + def ax0(self) -> int: + """axis number for the first module direction + + Returns + ------- + int + """ + return self._ax0 + + @property + def ax1(self) -> int: + """axis number for the second module direction + + Returns + ------- + int + """ + return self._ax1 + + @property + def lor_spacing(self) -> float: + """spacing between the LOR endpoints in a module along the polygon + + Returns + ------- + float + """ + return self._lor_spacing + + # abstract method from base class to be implemented + def get_raw_lor_endpoints(self, + inds: npt.NDArray | None = None) -> npt.NDArray: + if inds is None: + inds = self.lor_endpoint_numbers + + side = inds // self.num_lor_endpoints_per_side + tmp = inds - side * self.num_lor_endpoints_per_side + tmp = self.xp.astype( + tmp, float) - (self.num_lor_endpoints_per_side / 2 - 0.5) + + phi = 2 * self.xp.pi * self.xp.astype(side, float) / self.num_sides + + lor_endpoints = self.xp.zeros((self.num_lor_endpoints, 3), + device=self.dev) + lor_endpoints[:, self.ax0] = self.xp.cos( + phi) * self.radius - self.xp.sin(phi) * self.lor_spacing * tmp + lor_endpoints[:, self.ax1] = self.xp.sin( + phi) * self.radius + self.xp.cos(phi) * self.lor_spacing * tmp + + return lor_endpoints + + +class ModularizedPETScannerGeometry: + """description of a PET scanner geometry consisting of LOR endpoint modules""" + + def __init__(self, modules: tuple[PETScannerModule]): + """ + Parameters + ---------- + modules : tuple[PETScannerModule] + a tuple of scanner modules + """ + + # member variable that determines whether we want to use + # a numpy or cupy array to store the array of all lor endpoints + self._modules = modules + self._num_modules = len(self._modules) + self._num_lor_endpoints_per_module = self.xp.asarray( + [x.num_lor_endpoints for x in self._modules], device=self.dev) + self._num_lor_endpoints = int( + self.xp.sum(self._num_lor_endpoints_per_module)) + + self.setup_all_lor_endpoints() + + def setup_all_lor_endpoints(self) -> None: + """calculate the position of all lor endpoints by iterating over + the modules and calculating the transformed coordinates of all + module endpoints + """ + + self._all_lor_endpoints_index_offset = self.xp.asarray([ + int(sum(self._num_lor_endpoints_per_module[:i])) + for i in range(size(self._num_lor_endpoints_per_module)) + ], + device=self.dev) + + self._all_lor_endpoints = self.xp.zeros((self._num_lor_endpoints, 3), + device=self.dev, + dtype=self.xp.float32) + + for i, module in enumerate(self._modules): + self._all_lor_endpoints[ + int(self._all_lor_endpoints_index_offset[i]):int( + self._all_lor_endpoints_index_offset[i] + + module.num_lor_endpoints), :] = module.get_lor_endpoints() + + self._all_lor_endpoints_module_number = [ + int(self._num_lor_endpoints_per_module[i]) * [i] + for i in range(self._num_modules) + ] + + self._all_lor_endpoints_module_number = self.xp.asarray( + [i for r in self._all_lor_endpoints_module_number for i in r], + device=self.dev) + + @property + def modules(self) -> tuple[PETScannerModule]: + """tuple of modules defining the scanner""" + return self._modules + + @property + def num_modules(self) -> int: + """the number of modules defining the scanner""" + return self._num_modules + + @property + def num_lor_endpoints_per_module(self) -> npt.NDArray: + """numpy array showing how many LOR endpoints are in every module""" + return self._num_lor_endpoints_per_module + + @property + def num_lor_endpoints(self) -> int: + """the total number of LOR endpoints in the scanner""" + return self._num_lor_endpoints + + @property + def all_lor_endpoints_index_offset(self) -> npt.NDArray: + """the offset in the linear (flattend) index for all LOR endpoints""" + return self._all_lor_endpoints_index_offset + + @property + def all_lor_endpoints_module_number(self) -> npt.NDArray: + """the module number of all LOR endpoints""" + return self._all_lor_endpoints_module_number + + @property + def all_lor_endpoints(self) -> npt.NDArray: + """the world coordinates of all LOR endpoints""" + return self._all_lor_endpoints + + @property + def xp(self) -> ModuleType: + """module indicating whether the LOR endpoints are stored as numpy or cupy array""" + return self._modules[0].xp + + @property + def dev(self) -> str: + return self._modules[0].dev + + def linear_lor_endpoint_index( + self, + module: npt.NDArray, + index_in_module: npt.NDArray, + ) -> npt.NDArray: + """transform the module + index_in_modules indices into a flattened / linear LOR endpoint index + + Parameters + ---------- + module : npt.NDArray + containing module numbers + index_in_module : npt.NDArray + containing index in modules + + Returns + ------- + npt.NDArray + the flattened LOR endpoint index + """ + # index_in_module = self._xp.asarray(index_in_module) + + return self.xp.take(self.all_lor_endpoints_index_offset, + module) + index_in_module + + def get_lor_endpoints(self, module: npt.NDArray, + index_in_module: npt.NDArray) -> npt.NDArray: + """get the coordinates for LOR endpoints defined by module and index in module + + Parameters + ---------- + module : npt.NDArray + the module number of the LOR endpoints + index_in_module : npt.NDArray + the index in module number of the LOR endpoints + + Returns + ------- + npt.NDArray | cpt.NDArray + the 3 world coordinates of the LOR endpoints + """ + return self.xp.take(self.all_lor_endpoints, + self.linear_lor_endpoint_index( + module, index_in_module), + axis=0) + + def show_lor_endpoints(self, + ax: plt.Axes, + show_linear_index: bool = True, + **kwargs) -> None: + """show all LOR endpoints in a 3D plot + + Parameters + ---------- + ax : plt.Axes + a 3D matplotlib axes + show_linear_index : bool, optional + annotate the LOR endpoints with the linear LOR endpoint index + **kwargs : keyword arguments + passed to show_lor_endpoints() of the scanner module + """ + for i, module in enumerate(self.modules): + if show_linear_index: + offset = np.asarray( + to_device(self.all_lor_endpoints_index_offset[i], 'cpu')) + prefix = f'' + else: + offset = 0 + prefix = f'{i},' + + module.show_lor_endpoints(ax, + annotation_offset=offset, + annotation_prefix=prefix, + **kwargs) + + +class RegularPolygonPETScannerGeometry(ModularizedPETScannerGeometry): + """description of a PET scanner geometry consisting stacked regular polygons""" + + def __init__(self, xp: ModuleType, dev: str, radius: float, num_sides: int, + num_lor_endpoints_per_side: int, lor_spacing: float, + num_rings: int, ring_positions: npt.NDArray, + symmetry_axis: int) -> None: + """ + Parameters + ---------- + xp: ModuleType + array module to use for storing the LOR endpoints + dev: str + device to use for storing the LOR endpoints + radius : float + radius of the scanner + num_sides : int + number of sides (faces) of each regular polygon + num_lor_endpoints_per_side : int + number of LOR endpoints in each side (face) of each polygon + lor_spacing : float + spacing between the LOR endpoints in each side + num_rings : int + the number of rings (regular polygons) + ring_positions : npt.NDArray + 1D array with the coordinate of the rings along the ring axis + symmetry_axis : int + the ring axis (0,1,2) + """ + + self._radius = radius + self._num_sides = num_sides + self._num_lor_endpoints_per_side = num_lor_endpoints_per_side + self._num_rings = num_rings + self._lor_spacing = lor_spacing + self._symmetry_axis = symmetry_axis + self._ring_positions = ring_positions + + if symmetry_axis == 0: + self._ax0 = 2 + self._ax1 = 1 + elif symmetry_axis == 1: + self._ax0 = 0 + self._ax1 = 2 + elif symmetry_axis == 2: + self._ax0 = 1 + self._ax1 = 0 + + modules = [] + + for ring in range(num_rings): + aff_mat = xp.eye(4, device=dev) + aff_mat[symmetry_axis, -1] = ring_positions[ring] + + modules.append( + RegularPolygonPETScannerModule( + xp, + dev, + radius, + num_sides, + num_lor_endpoints_per_side=num_lor_endpoints_per_side, + lor_spacing=lor_spacing, + affine_transformation_matrix=aff_mat, + ax0=self._ax0, + ax1=self._ax1)) + + modules = tuple(modules) + super().__init__(modules) + + self._all_lor_endpoints_index_in_ring = self.xp.arange( + self.num_lor_endpoints, device=dev + ) - self.all_lor_endpoints_ring_number * self.num_lor_endpoints_per_module[ + 0] + + @property + def radius(self) -> float: + """radius of the scanner""" + return self._radius + + @property + def num_sides(self) -> int: + """number of sides (faces) of each polygon""" + return self._num_sides + + @property + def num_lor_endpoints_per_side(self) -> int: + """number of LOR endpoints per side (face) in each polygon""" + return self._num_lor_endpoints_per_side + + @property + def num_rings(self) -> int: + """number of rings (regular polygons)""" + return self._num_rings + + @property + def lor_spacing(self) -> float: + """the spacing between the LOR endpoints in every side (face) of each polygon""" + return self._lor_spacing + + @property + def symmetry_axis(self) -> int: + """The symmetry axis. Also called axial (or ring) direction.""" + return self._symmetry_axis + + @property + def all_lor_endpoints_ring_number(self) -> npt.NDArray: + """the ring (regular polygon) number of all LOR endpoints""" + return self._all_lor_endpoints_module_number + + @property + def all_lor_endpoints_index_in_ring(self) -> npt.NDArray: + """the index withing the ring (regular polygon) number of all LOR endpoints""" + return self._all_lor_endpoints_index_in_ring + + @property + def num_lor_endpoints_per_ring(self) -> int: + """the number of LOR endpoints per ring (regular polygon)""" + return int(self._num_lor_endpoints_per_module[0]) + + @property + def ring_positions(self) -> npt.NDArray: + """the ring (regular polygon) positions""" + return self._ring_positions + + +class DemoPETScanner(RegularPolygonPETScannerGeometry): + + def __init__(self, + xp: ModuleType, + dev: str, + num_rings: int = 36, + symmetry_axis: int = 2) -> None: + + ring_positions = 5.32 * xp.arange( + num_rings, device=dev, dtype=xp.float32) + (xp.astype( + xp.arange(num_rings, device=dev) // 9, xp.float32)) * 2.8 + ring_positions -= 0.5 * xp.max(ring_positions) + super().__init__(xp, + dev, + radius=0.5 * (744.1 + 2 * 8.51), + num_sides=34, + num_lor_endpoints_per_side=16, + lor_spacing=4.03125, + num_rings=num_rings, + ring_positions=ring_positions, + symmetry_axis=symmetry_axis) + + +class PETLORDescriptor(abc.ABC): + """abstract base class to describe which modules / indices in modules of a + modularized PET scanner are in coincidence; defining geometrical LORs""" + + def __init__(self, scanner: ModularizedPETScannerGeometry) -> None: + """ + Parameters + ---------- + scanner : ModularizedPETScannerGeometry + a modularized PET scanner + """ + self._scanner = scanner + + @abc.abstractmethod + def get_lor_coordinates(self, + **kwargs) -> tuple[npt.ArrayLike, npt.ArrayLike]: + """return the start and end coordinates of all (or a subset of) LORs""" + raise NotImplementedError + + @property + def scanner(self) -> ModularizedPETScannerGeometry: + """the scanner for which coincidences are described""" + return self._scanner + + @property + def xp(self) -> ModuleType: + """array module to use for storing the LOR endpoints""" + return self.scanner.xp + + @property + def dev(self) -> str: + """device to use for storing the LOR endpoints""" + return self.scanner.dev + + +class RegularPolygonPETLORDescriptor(PETLORDescriptor): + + def __init__( + self, + scanner: RegularPolygonPETScannerGeometry, + radial_trim: int = 3, + max_ring_difference: int | None = None, + ) -> None: + """Coincidence descriptor for a regular polygon PET scanner where + we have coincidences within and between "rings (polygons of modules)" + The geometrical LORs can be sorted into a sinogram having a + "plane", "view" and "radial" axis. + + Parameters + ---------- + scanner : RegularPolygonPETScannerGeometry + a regular polygon PET scanner + radial_trim : int, optional + number of geometrial LORs to disregard in the radial direction, by default 3 + max_ring_difference : int | None, optional + maximim ring difference to consider for coincidences, by default None means + all ring differences are included + """ + + super().__init__(scanner) + + self._radial_trim = radial_trim + + if max_ring_difference is None: + self._max_ring_difference = self.scanner.num_rings - 1 + else: + self._max_ring_difference = max_ring_difference + + self._num_rad = (self.scanner.num_lor_endpoints_per_ring + + 1) - 2 * self._radial_trim + self._num_views = self.scanner.num_lor_endpoints_per_ring // 2 + + self._setup_plane_indices() + self._setup_view_indices() + + @property + def radial_trim(self) -> int: + """number of geometrial LORs to disregard in the radial direction""" + return self._radial_trim + + @property + def max_ring_difference(self) -> int: + """the maximum ring difference""" + return self._max_ring_difference + + @property + def num_planes(self) -> int: + """number of planes in the sinogram""" + return self._num_planes + + @property + def num_rad(self) -> int: + """number of radial elements in the sinogram""" + return self._num_rad + + @property + def num_views(self) -> int: + """number of views in the sinogram""" + return self._num_views + + @property + def start_plane_index(self) -> npt.NDArray: + """start plane for all planes""" + return self._start_plane_index + + @property + def end_plane_index(self) -> npt.NDArray: + """end plane for all planes""" + return self._end_plane_index + + @property + def start_in_ring_index(self) -> npt.NDArray: + """start index within ring for all views - shape (num_view, num_rad)""" + return self._start_in_ring_index + + @property + def end_in_ring_index(self) -> npt.NDArray: + """end index within ring for all views - shape (num_view, num_rad)""" + return self._end_in_ring_index + + def _setup_plane_indices(self) -> None: + """setup the start / end plane indices (similar to a Michelogram) + """ + self._start_plane_index = self.xp.arange(self.scanner.num_rings, + dtype=self.xp.int32, + device=self.dev) + self._end_plane_index = self.xp.arange(self.scanner.num_rings, + dtype=self.xp.int32, + device=self.dev) + + for i in range(1, self._max_ring_difference + 1): + tmp1 = self.xp.arange(self.scanner.num_rings - i, + dtype=self.xp.int16, + device=self.dev) + tmp2 = self.xp.arange(self.scanner.num_rings - i, + dtype=self.xp.int16, + device=self.dev) + i + + self._start_plane_index = self.xp.concat( + (self._start_plane_index, tmp1, tmp2)) + self._end_plane_index = self.xp.concat( + (self._end_plane_index, tmp2, tmp1)) + + self._num_planes = self._start_plane_index.shape[0] + + def _setup_view_indices(self) -> None: + """setup the start / end view indices + """ + n = self.scanner.num_lor_endpoints_per_ring + + m = 2 * (n // 2) + + self._start_in_ring_index = self.xp.zeros( + (self._num_views, self._num_rad), + dtype=self.xp.int32, + device=self.dev) + self._end_in_ring_index = self.xp.zeros( + (self._num_views, self._num_rad), + dtype=self.xp.int32, + device=self.dev) + + for view in np.arange(self._num_views): + self._start_in_ring_index[view, :] = ( + self.xp.concat( + (self.xp.arange(m) // 2, self.xp.asarray([n // 2]))) - + view)[self._radial_trim:-self._radial_trim] + self._end_in_ring_index[view, :] = ( + self.xp.concat( + (self.xp.asarray([-1]), -((self.xp.arange(m) + 4) // 2))) - + view)[self._radial_trim:-self._radial_trim] + + # shift the negative indices + self._start_in_ring_index = self.xp.where( + self._start_in_ring_index >= 0, self._start_in_ring_index, + self._start_in_ring_index + n) + self._end_in_ring_index = self.xp.where(self._end_in_ring_index >= 0, + self._end_in_ring_index, + self._end_in_ring_index + n) + + def get_lor_indices( + self, + views: None | npt.ArrayLike = None, + sinogram_order=SinogramSpatialAxisOrder.RVP + ) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: + """return the start and end indices of all LORs / or a subset of views + + Parameters + ---------- + views : None | npt.ArrayLike, optional + the views to consider, by default None means all views + sinogram_order : SinogramSpatialAxisOrder, optional + the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP + + Returns + ------- + start_mods, end_mods, start_inds, end_inds + """ + + if views is None: + views = self.xp.arange(self.num_views, device=self.dev) + + # setup the module and in_module (in_ring) indices for all LORs in PVR order + start_inring_inds = self.xp.reshape( + self.xp.take(self.start_in_ring_index, views, axis=0), (-1, )) + end_inring_inds = self.xp.reshape( + self.xp.take(self.end_in_ring_index, views, axis=0), (-1, )) + + start_mods, start_inds = self.xp.meshgrid(self.start_plane_index, + start_inring_inds, + indexing='ij') + end_mods, end_inds = self.xp.meshgrid(self.end_plane_index, + end_inring_inds, + indexing='ij') + + # reshape to PVR dimensions (radial moving fastest, planes moving slowest) + sinogram_spatial_shape = (self.num_planes, views.shape[0], + self.num_rad) + start_mods = self.xp.reshape(start_mods, sinogram_spatial_shape) + end_mods = self.xp.reshape(end_mods, sinogram_spatial_shape) + start_inds = self.xp.reshape(start_inds, sinogram_spatial_shape) + end_inds = self.xp.reshape(end_inds, sinogram_spatial_shape) + + if sinogram_order is not SinogramSpatialAxisOrder.PVR: + if sinogram_order is SinogramSpatialAxisOrder.RVP: + new_order = (2, 1, 0) + elif sinogram_order is SinogramSpatialAxisOrder.RPV: + new_order = (2, 0, 1) + elif sinogram_order is SinogramSpatialAxisOrder.VRP: + new_order = (1, 2, 0) + elif sinogram_order is SinogramSpatialAxisOrder.VPR: + new_order = (1, 0, 2) + elif sinogram_order is SinogramSpatialAxisOrder.PRV: + new_order = (0, 2, 1) + + start_mods = self.xp.permute_dims(start_mods, new_order) + end_mods = self.xp.permute_dims(end_mods, new_order) + + start_inds = self.xp.permute_dims(start_inds, new_order) + end_inds = self.xp.permute_dims(end_inds, new_order) + + return start_mods, end_mods, start_inds, end_inds + + def get_lor_coordinates( + self, + views: None | npt.ArrayLike = None, + sinogram_order=SinogramSpatialAxisOrder.RVP + ) -> tuple[npt.ArrayLike, npt.ArrayLike]: + + """return the start and end coordinates of all LORs / or a subset of views + + Parameters + ---------- + views : None | npt.ArrayLike, optional + the views to consider, by default None means all views + sinogram_order : SinogramSpatialAxisOrder, optional + the order of the sinogram axes, by default SinogramSpatialAxisOrder.RVP + + Returns + ------- + xstart, xend : npt.ArrayLike + 2 dimensional floating point arrays containing the start and end coordinates of all LORs + """ + + start_mods, end_mods, start_inds, end_inds = self.get_lor_indices(views, sinogram_order) + sinogram_spatial_shape = start_mods.shape + + start_mods = self.xp.reshape(start_mods, (-1, )) + start_inds = self.xp.reshape(start_inds, (-1, )) + + end_mods = self.xp.reshape(end_mods, (-1, )) + end_inds = self.xp.reshape(end_inds, (-1, )) + + x_start = self.xp.reshape( + self.scanner.get_lor_endpoints(start_mods, start_inds), + sinogram_spatial_shape + (3, )) + x_end = self.xp.reshape( + self.scanner.get_lor_endpoints(end_mods, end_inds), + sinogram_spatial_shape + (3, )) + + return x_start, x_end + + def show_views(self, + ax: plt.Axes, + views: npt.ArrayLike, + planes: npt.ArrayLike, + lw: float = 0.2, + **kwargs) -> None: + """show all LORs of a single view in a given plane + + Parameters + ---------- + ax : plt.Axes + a 3D matplotlib axes + view : int + the view number + plane : int + the plane number + lw : float, optional + the line width, by default 0.2 + """ + + xs, xe = self.get_lor_coordinates( + views=views, sinogram_order=SinogramSpatialAxisOrder.RVP) + xs = self.xp.reshape(self.xp.take(xs, planes, axis=2), (-1, 3)) + xe = self.xp.reshape(self.xp.take(xe, planes, axis=2), (-1, 3)) + + p1s = np.asarray(to_device(xs, 'cpu')) + p2s = np.asarray(to_device(xe, 'cpu')) + + ls = np.hstack([p1s, p2s]).copy() + ls = ls.reshape((-1, 2, 3)) + lc = Line3DCollection(ls, linewidths=lw, **kwargs) + ax.add_collection(lc) + + +class DemoPETScannerLORDescriptor(RegularPolygonPETLORDescriptor): + + def __init__(self, + xp: ModuleType, + dev: str, + num_rings: int = 9, + radial_trim: int = 65, + max_ring_difference: int | None = None, + symmetry_axis: int = 2) -> None: + + scanner = DemoPETScanner(xp, + dev, + num_rings, + symmetry_axis=symmetry_axis) + + super().__init__(scanner, + radial_trim=radial_trim, + max_ring_difference=max_ring_difference) + + +class RegularPolygonPETProjector(parallelproj.LinearOperator): + + def __init__(self, + lor_descriptor: RegularPolygonPETLORDescriptor, + img_shape: tuple[int, int, int], + voxel_size: tuple[float, float, float], + img_origin: None | npt.ArrayLike = None, + views: None | npt.ArrayLike = None, + resolution_model: None | parallelproj.LinearOperator = None, + tof: bool = False): + """Regular polygon PET projector + + Parameters + ---------- + lor_descriptor : RegularPolygonPETLORDescriptor + descriptor of the LOR start / end points + img_shape : tuple[int, int, int] + shape of the image to be projected + voxel_size : tuple[float, float, float] + the voxel size of the image to be projected + img_origin : None | npt.ArrayLike, optional + the origin of the image to be projected, by default None + means that image is "centered" in the scanner + views : None | npt.ArrayLike, optional + sinogram views to be projected, by default None + means that all views are being projected + resolution_model : None | parallelproj.LinearOperator, optional + an image-based resolution model applied before forward projection, by default None + means an isotropic 4.5mm FWHM Gaussian smoothing is used + tof: bool, optional, default False + whether to use non-TOF or TOF projections + """ + + super().__init__() + self._dev = lor_descriptor.dev + + self._lor_descriptor = lor_descriptor + self._img_shape = img_shape + self._voxel_size = self.xp.asarray(voxel_size, + dtype=self.xp.float32, + device=self._dev) + + if img_origin is None: + self._img_origin = (-(self.xp.asarray( + self._img_shape, dtype=self.xp.float32, device=self._dev) / 2) + + 0.5) * self._voxel_size + else: + self._img_origin = self.xp.asarray(img_origin, + dtype=self.xp.float32, + device=self._dev) + + if views is None: + self._views = self.xp.arange(self._lor_descriptor.num_views, + device=self._dev) + else: + self._views = views + + self._resolution_model = resolution_model + + self._xstart, self._xend = lor_descriptor.get_lor_coordinates( + views=self._views, sinogram_order=SinogramSpatialAxisOrder['RVP']) + + self._tof = tof + self._tof_parameters = TOFParameters() + + @property + def in_shape(self) -> tuple[int, int, int]: + return self._img_shape + + @property + def out_shape(self) -> tuple[int, int, int]: + if self.tof: + out_shape = (self._lor_descriptor.num_rad, self._views.shape[0], + self._lor_descriptor.num_planes, + self.tof_parameters.num_tofbins) + else: + out_shape = (self._lor_descriptor.num_rad, self._views.shape[0], + self._lor_descriptor.num_planes) + + return out_shape + + @property + def xp(self) -> ModuleType: + return self._lor_descriptor.xp + + @property + def tof(self) -> bool: + return self._tof + + @tof.setter + def tof(self, value: bool) -> None: + if not isinstance(value, bool): + raise ValueError('tof must be a boolean') + self._tof = value + + @property + def tof_parameters(self) -> TOFParameters: + return self._tof_parameters + + @tof_parameters.setter + def tof_parameters(self, value: TOFParameters) -> None: + if not isinstance(value, TOFParameters): + raise ValueError('tof_parameters must be a TOFParameters object') + self._tof_parameters = value + + @property + def img_origin(self) -> npt.NDArray: + return self._img_origin + + def _apply(self, x): + """nonTOF forward projection of input image x including image based resolution model""" + + dev = device(x) + + if self._resolution_model is not None: + x_sm = self._resolution_model(x) + else: + x_sm = x + + if not self.tof: + x_fwd = parallelproj.joseph3d_fwd(self._xstart, self._xend, x_sm, + self._img_origin, + self._voxel_size) + else: + x_fwd = parallelproj.joseph3d_fwd_tof_sino( + self._xstart, self._xend, x_sm, self._img_origin, + self._voxel_size, self._tof_parameters.tofbin_width, + self.xp.asarray([self._tof_parameters.sigma_tof], + dtype=self.xp.float32, + device=dev), + self.xp.asarray([self._tof_parameters.tofcenter_offset], + dtype=self.xp.float32, + device=dev), self.tof_parameters.num_sigmas, + self.tof_parameters.num_tofbins) + + return x_fwd + + def _adjoint(self, y): + """nonTOF back projection of sinogram y""" + dev = device(y) + + if not self.tof: + y_back = parallelproj.joseph3d_back(self._xstart, self._xend, + self._img_shape, + self._img_origin, + self._voxel_size, y) + else: + y_back = parallelproj.joseph3d_back_tof_sino( + self._xstart, self._xend, self._img_shape, self._img_origin, + self._voxel_size, y, self._tof_parameters.tofbin_width, + self.xp.asarray([self._tof_parameters.sigma_tof], + dtype=self.xp.float32, + device=dev), + self.xp.asarray([self._tof_parameters.tofcenter_offset], + dtype=self.xp.float32, + device=dev), self.tof_parameters.num_sigmas, + self.tof_parameters.num_tofbins) + + if self._resolution_model is not None: + y_back = self._resolution_model.adjoint(y_back) + + return y_back + + + +def distributed_subset_order(n: int) -> list[int]: + """subset order that maximizes distance between subsets + + Parameters + ---------- + n : int + number of subsets + + Returns + ------- + list[int] + """ + l = [x for x in range(n)] + o = [] + + for i in range(n): + if (i % 2) == 0: + o.append(l.pop(0)) + else: + o.append(l.pop(len(l)//2)) + + return o +