Skip to content

Commit

Permalink
add LM recon (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Nov 14, 2023
1 parent a4c72d1 commit ffc63f5
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 24 deletions.
107 changes: 83 additions & 24 deletions python/parallelproj_sim.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
#TODO: - absolute scale of recon (maybe with sinogram recon)
# - additive MLEM

from __future__ import annotations

import parallelproj
import utils
import array_api_compat.numpy as np
import matplotlib.pyplot as plt
from array_api_compat import to_device
from scipy.ndimage import gaussian_filter

# device variable (cpu or cuda) that determines whether calculations
# are performed on the cpu or cuda gpu

dev = "cpu"
expected_num_trues = 5e6
expected_num_trues = 1e6
num_iter = 2
num_subsets = 20
np.random.seed(1)

# ----------------------------------------------------------------------------
Expand Down Expand Up @@ -42,6 +49,7 @@
# setup an image containing a box
img = np.zeros(img_shape, dtype=np.float32, device=dev)
img[(n0 // 4) : (3 * n0 // 4), (n1 // 4) : (3 * n1 // 4), :] = 1
img[(7*n0 // 16) : (9 * n0 // 16), (6*n1 // 16) : (8 * n1 // 16), :] = 2

# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
Expand All @@ -63,9 +71,6 @@
# calculate the sensitivity image
sens_img = projector.adjoint(np.ones(noise_free_sinogram.shape, device=dev, dtype=np.float32))

# get a lookup table that contains the world coordinates of all scanner detectors
scanner_lut = lor_descriptor.scanner.all_lor_endpoints

# get the two dimensional indices of all sinogram bins
start_mods, end_mods, start_inds, end_inds = lor_descriptor.get_lor_indices()

Expand All @@ -81,6 +86,9 @@
sino_det_start_index = sino_det_start_index.ravel()
sino_det_end_index = sino_det_end_index.ravel()

# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# --- convert the index sinograms in to listmode data ------------------------
# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------

Expand All @@ -92,39 +100,90 @@
event_det_id_1 = sino_det_start_index[event_sino_inds]
event_det_id_2 = sino_det_end_index[event_sino_inds]

# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
# ----------------------------------------------------------------------------
print(f'number of events: {event_det_id_1.shape[0]}')

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#---- convert LM detector ID arrays into PRD here ---------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

# get a lookup table that contains the world coordinates of all scanner detectors
# this is a 2D array of shape (num_detectors, 3)
scanner_lut = lor_descriptor.scanner.all_lor_endpoints

#
#
#
#
#
#

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#---- read events back from PRD here ----------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

#
#
#
#
#
#

# hack until we have the reader / writer implemented
xstart = scanner_lut[event_det_id_1, :]
xend = scanner_lut[event_det_id_2, :]

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#---- LM recon using the event detector IDs and the scanner LUT -------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

recon = np.ones(img_shape, dtype=np.float32, device=dev)

for it in range(num_iter):
for isub in range(num_subsets):
print(f'it {(it+1):03} / ss {(isub+1):03}', end='\r')
xs_sub = xstart[isub::num_subsets,:]
xe_sub = xend[isub::num_subsets,:]
exp = parallelproj.joseph3d_fwd(xs_sub, xe_sub, recon, projector.img_origin, voxel_size)
tmp = parallelproj.joseph3d_back(xs_sub, xe_sub, img_shape, projector.img_origin, voxel_size, 1/exp)
recon *= (tmp / (sens_img / num_subsets))

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# show the scanner geometry and one view in one sinogram plane
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection="3d")
lor_descriptor.scanner.show_lor_endpoints(ax, show_linear_index=True, annotation_fontsize=4)

# plot the LORs of the first n events
for i, col in enumerate(('red','green','blue')):
xstart = scanner_lut[event_det_id_1[i], :]
xend = scanner_lut[event_det_id_2[i], :]
xs = scanner_lut[event_det_id_1[i], :]
xe= scanner_lut[event_det_id_2[i], :]

ax.plot(
[xstart[0], xend[0]],
[xstart[1], xend[1]],
[xstart[2], xend[2]],
[xs[0], xe[0]],
[xs[1], xe[1]],
[xs[2], xe[2]],
color=col,
linewidth=1.,
)

fig.tight_layout()
fig.show()
#
#fig2, ax2 = plt.subplots(1, 3, figsize=(15, 5))
#ax2[0].imshow(np.asarray(to_device(img[:, :, 3], "cpu")))
#if projector.tof:
# ax2[1].imshow(np.asarray(to_device(noise_free_sinogram[:, :, 0, 15], "cpu")))
#else:
# ax2[1].imshow(np.asarray(to_device(noise_free_sinogram[:, :, 0], "cpu")))
#ax2[2].imshow(np.asarray(to_device(sens_img[:, :, 3], "cpu")))
#fig2.tight_layout()
#fig2.show()
#

vmax = 1.2*img.max()
fig2, ax2 = plt.subplots(1, 4, figsize=(16, 4))
ax2[0].imshow(np.asarray(to_device(img[:, :, 1], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys')
if projector.tof:
ax2[1].imshow(np.asarray(to_device(noise_free_sinogram[:, :, 0, 15], "cpu")), cmap = 'Greys')
else:
ax2[1].imshow(np.asarray(to_device(noise_free_sinogram[:, :, 0], "cpu")), cmap = 'Greys')
ax2[2].imshow(np.asarray(to_device(recon[:, :, 1], "cpu")), vmin = 0, vmax = vmax, cmap = 'Greys')
ax2[3].imshow(gaussian_filter(np.asarray(to_device(recon[:, :, 1], "cpu")), 1.5), vmin = 0, vmax = vmax, cmap = 'Greys')
fig2.tight_layout()
fig2.show()
4 changes: 4 additions & 0 deletions python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,10 @@ def tof_parameters(self, value: TOFParameters) -> None:
raise ValueError('tof_parameters must be a TOFParameters object')
self._tof_parameters = value

@property
def img_origin(self) -> npt.NDArray:
return self._img_origin

def _apply(self, x):
"""nonTOF forward projection of input image x including image based resolution model"""

Expand Down

0 comments on commit ffc63f5

Please sign in to comment.