Skip to content

Commit

Permalink
Object-oriented interface, with fit and apply methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Apr 13, 2020
1 parent a3bfa1a commit 18f92d9
Showing 1 changed file with 68 additions and 52 deletions.
120 changes: 68 additions & 52 deletions dmriprep/utils/register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Linear affine registration tools for motion correction.
"""
import attr

import numpy as np
import nibabel as nb
from dipy.align.metrics import CCMetric, EMMetric, SSDMetric
Expand Down Expand Up @@ -72,7 +74,7 @@ def c_of_mass(
):
transform = transform_centers_of_mass(static, static_affine, moving, moving_affine)
transformed = transform.transform(moving)
return transformed, transform.affine
return transform


def translation(
Expand All @@ -89,7 +91,7 @@ def translation(
starting_affine=starting_affine,
)

return translation.transform(moving), translation.affine
return translation


def rigid(
Expand All @@ -105,12 +107,13 @@ def rigid(
moving_affine,
starting_affine=starting_affine,
)
return rigid.transform(moving), rigid.affine
return rigid


def affine(
moving, static, static_affine, moving_affine, reg, starting_affine, params0=None
):
def affine(moving, static, static_affine, moving_affine, reg, starting_affine,
params0=None):
"""
"""
transform = AffineTransform3D()
affine = reg.optimize(
static,
Expand All @@ -122,49 +125,62 @@ def affine(
starting_affine=starting_affine,
)

return affine.transform(moving), affine.affine


def affine_registration(
moving,
static,
nbins,
sampling_prop,
metric,
pipeline,
level_iters,
sigmas,
factors,
params0,
):

"""
Find the affine transformation between two 3D images.
Parameters
----------
"""
# Define the Affine registration object we'll use with the chosen metric:
use_metric = affine_metric_dict[metric](nbins, sampling_prop)
affreg = AffineRegistration(
metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors
)

if not params0:
starting_affine = np.eye(4)
else:
starting_affine = params0

# Go through the selected transformation:
for func in pipeline:
transformed, starting_affine = func(
np.asarray(moving.dataobj),
np.asarray(static.dataobj),
static.affine,
moving.affine,
affreg,
starting_affine,
params0,
)
return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine
return affine


@attr.s(slots=True, frozen=True)
class AffineRegistration():
def __init__(self):
nbins = attr.ib(default=32)
sampling_prop = attr.ib(default=1.0)
metric = attr.ib(default="MI")
level_iters = attr.ib(default=[10000, 1000, 100])
sigmas = attr.ib(defaults=[3, 1, 0.0])
factors = attr.ib(defaults=[4, 2, 1])
pipeline = attr.ib(defaults=[c_of_mass, translation, rigid, affine])

def fit(self, static, moving, params0=None):
"""
static, moving : nib.Nifti1Image class images
"""
if params0 is None:
starting_affine = np.eye(4)
else:
starting_affine = params0

use_metric = affine_metric_dict[self.metric](self.nbins,
self.sampling_prop)
affreg = AffineRegistration(
metric=use_metric,
level_iters=self.level_iters,
sigmas=self.sigmas,
factors=self.factors)

# Go through the selected transformation:
for func in self.pipeline:
transform = func(
np.asarray(moving.dataobj),
np.asarray(static.dataobj),
static.affine,
moving.affine,
affreg,
starting_affine,
params0,
)
starting_affine = transform.affine

self.static_affine_ = static.affine
self.moving_affine_ = moving.affine
self.affine_ = starting_affine
self.reg_ = AffineMap(starting_affine,
static.shape, static.affine,
moving.shape, moving.affine)

def apply(self, moving):
"""
"""
data = moving.get_fdata()
assert np.all(moving.affine, self.moving_affine_)
return nb.Nifti1Image(np.array(self.reg_.transform(data)),
self.static_affine_)

0 comments on commit 18f92d9

Please sign in to comment.