Skip to content

Commit

Permalink
Merge pull request #500 from odlgroup/issue-266__inplace_raytrafo_test
Browse files Browse the repository at this point in the history
TST: add inplace test for tomo projectors, closes #266
  • Loading branch information
adler-j authored Aug 15, 2016
2 parents 647736a + a980ad1 commit b27ad05
Show file tree
Hide file tree
Showing 18 changed files with 50 additions and 39 deletions.
2 changes: 1 addition & 1 deletion odl/discr/discr_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def _evaluate(self, indices, norm_distances, out=None):
# TODO: determine best summation order from array strides
for lh, w_lo, w_hi in zip(lo_hi, low_weights, high_weights):

# We don't multiply in place to exploit the cheap operations
# We don't multiply in-place to exploit the cheap operations
# in the beginning: sizes grow gradually as following:
# (n, 1, 1, ...) -> (n, m, 1, ...) -> ...
# Hence, it is faster to build up the weight array instead
Expand Down
4 changes: 2 additions & 2 deletions odl/operator/default_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def _call(self, x, out=None):
>>> vec = r3.element([1, 2, 3])
>>> out = r3.element()
>>> op = ScalingOperator(r3, 2.0)
>>> op(vec, out) # In place, Returns out
>>> op(vec, out) # In-place, Returns out
rn(3).element([2.0, 4.0, 6.0])
>>> out
rn(3).element([2.0, 4.0, 6.0])
>>> op(vec) # Out of place
>>> op(vec) # Out-of-place
rn(3).element([2.0, 4.0, 6.0])
"""
if out is None:
Expand Down
10 changes: 5 additions & 5 deletions odl/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ class described in the following.
**Dual-use evaluation:** ``_call(self, x, out=None[, **kwargs])``
Evaluate in place if ``out`` is given, otherwise out of place.
Evaluate in-place if ``out`` is given, otherwise out-of-place.
**Parameters:**
Expand Down Expand Up @@ -434,7 +434,7 @@ def __new__(cls, *args, **kwargs):
# Dual-use _call
cls._call_in_place = cls._call_out_of_place = cls._call
else:
# In-place only _call
# In-place-only _call
cls._call_in_place = cls._call
cls._call_out_of_place = _default_call_out_of_place

Expand Down Expand Up @@ -519,7 +519,7 @@ def _call(self, x, out=None, **kwargs):
- If you just write a quick implementation or are not too
worried about efficiency, it may be easiest to write the
evaluation *out of place*.
evaluation *out-of-place*.
- We recommend advanced and performance-aware users to implement
the *in-place* pattern if the wrapped code supports it.
In-place evaluation is usually significantly faster since it
Expand Down Expand Up @@ -1069,7 +1069,7 @@ def _call(self, x, out=None):
>>> op = odl.IdentityOperator(r3)
>>> x = r3.element([1, 2, 3])
>>> out = r3.element()
>>> OperatorSum(op, op)(x, out) # In place, returns out
>>> OperatorSum(op, op)(x, out) # In-place, returns out
rn(3).element([2.0, 4.0, 6.0])
>>> out
rn(3).element([2.0, 4.0, 6.0])
Expand Down Expand Up @@ -2121,7 +2121,7 @@ def _call(self, x, out=None):
attrs['_call_in_place'] = _call
attrs['_call_out_of_place'] = _call
else:
# In-place only _call
# In-place-only _call

def _call(self, x, out):
return call(x, out)
Expand Down
2 changes: 1 addition & 1 deletion odl/solvers/advanced/chambolle_pock.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def chambolle_pock_solver(op, x, tau, sigma, proximal_primal, proximal_dual,
op : `Operator`
Forward operator, the operator ``K`` in the problem formulation.
x : element in the domain of ``op``
Starting point of the iteration, updated in place.
Starting point of the iteration, updated in-place.
tau : positive `float`
Step size parameter for the update of the primal variable.
Controls the extent to which ``proximal_primal`` maps points
Expand Down
2 changes: 1 addition & 1 deletion odl/solvers/advanced/douglas_rachford.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def douglas_rachford_pd(x, prox_f, prox_cc_g, L, tau, sigma, niter,
Parameters
----------
x : `LinearSpaceVector`
Initial point, updated in place.
Initial point, updated in-place.
prox_f : `callable`
`proximal factory` for the function ``f``.
prox_cc_g : `sequence` of `callable`'s
Expand Down
2 changes: 1 addition & 1 deletion odl/solvers/advanced/forward_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward_backward_pd(x, prox_f, prox_cc_g, L, grad_h, tau, sigma, niter,
Parameters
----------
x : `LinearSpaceVector`
Initial point, updated in place.
Initial point, updated in-place.
prox_f : `callable`
`Proximal factory` for the functional ``f``.
prox_cc_g : `sequence` of `callable`'s
Expand Down
2 changes: 1 addition & 1 deletion odl/solvers/iterative/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def landweber(op, x, rhs, niter=1, omega=1, projection=None, callback=None):
projection : `callable`, optional
Function that can be used to modify the iterates in each iteration,
for example enforcing positivity. The function should take one
argument and modify it in place.
argument and modify it in-place.
callback : `callable`, optional
Object executing code per iteration, e.g. plotting each iterate
Expand Down
2 changes: 1 addition & 1 deletion odl/solvers/scalar/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def steepest_descent(grad, x, niter=1, line_search=1, projection=None,
projection : `callable`, optional
Function that can be used to modify the iterates in each iteration,
for example enforcing positivity. The function should take one
argument and modify it inplace.
argument and modify it in-place.
callback : `callable`, optional
Object executing code per iteration, e.g. plotting each iterate
Expand Down
4 changes: 2 additions & 2 deletions odl/space/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,15 @@ def __init__(self, fset, fcall):
self._call_out_optional = call_out_optional

if not call_has_out:
# Out-of-place only
# Out-of-place-only
self._call_in_place = preload_first_arg(self, 'in-place')(
_default_in_place)
self._call_out_of_place = fcall
elif call_out_optional:
# Dual-use
self._call_in_place = self._call_out_of_place = fcall
else:
# In-place only
# In-place-only
self._call_in_place = fcall
self._call_out_of_place = preload_first_arg(self, 'out-of-place')(
_default_out_of_place)
Expand Down
6 changes: 3 additions & 3 deletions odl/trafos/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ def dft_preprocess_data(arr, shift=True, axes=None, sign='-', out=None):
out[:] = arr

if is_real_dtype(out.dtype) and not shift:
raise ValueError('cannot pre-process real input in place without '
raise ValueError('cannot pre-process real input in-place without '
'shift')

if sign == '-':
Expand Down Expand Up @@ -2176,7 +2176,7 @@ def _call_numpy(self, x):
Result of the transform
"""
# Pre-processing before calculating the DFT
# Note: since the FFT call is out of place, it does not matter if
# Note: since the FFT call is out-of-place, it does not matter if
# preprocess produces real or complex output in the R2C variant.
# There is no significant time difference between (full) R2C and
# C2C DFT in Numpy.
Expand Down Expand Up @@ -2505,7 +2505,7 @@ def _call_pyfftw(self, x, out, **kwargs):
fft_arr /= np.prod(np.take(self.domain.shape, self.axes))

# Post-processing in IFT = pre-processing in FT. In-place for
# C2C and HC2R. For C2R, this is out of place and discards the
# C2C and HC2R. For C2R, this is out-of-place and discards the
# imaginary part.
self._postprocess(fft_arr, out=out)
return out
Expand Down
8 changes: 4 additions & 4 deletions test/discr/lp_discr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def test_ufunc(fn_impl, ufunc):
assert isinstance(data_vector.ufunc,
odl.util.ufuncs.DiscreteLpUFuncs)

# Out of place:
# Out-of-place:
np_result = ufunc(*in_arrays)
vec_fun = getattr(data_vector.ufunc, name)
odl_result = vec_fun(*in_vectors)
Expand All @@ -776,20 +776,20 @@ def test_ufunc(fn_impl, ufunc):
for i in range(n_out):
assert isinstance(odl_result[i], space.element_type)

# In place:
# In-place:
np_result = ufunc(*(in_arrays + out_arrays))
vec_fun = getattr(data_vector.ufunc, name)
odl_result = vec_fun(*(in_vectors + out_vectors))
assert all_almost_equal(np_result, odl_result)

# Test inplace actually holds:
# Test in-place actually holds:
if n_out == 1:
assert odl_result is out_vectors[0]
elif n_out > 1:
for i in range(n_out):
assert odl_result[i] is out_vectors[i]

# Test out of place with np data
# Test out-of-place with np data
np_result = ufunc(*in_arrays)
vec_fun = getattr(data_vector.ufunc, name)
odl_result = vec_fun(*in_arrays[1:])
Expand Down
4 changes: 2 additions & 2 deletions test/operator/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,12 @@ def test_linear_adjoint():
xvec = Aop.range.element(x)
outvec = Aop.domain.element()

# Using inplace adjoint
# Using in-place adjoint
Aop.adjoint(xvec, outvec)
np.dot(A.T, x, out)
assert all_almost_equal(out, outvec)

# Using out of place method
# Using out-of-place method
assert all_almost_equal(Aop.adjoint(xvec), np.dot(A.T, x))


Expand Down
4 changes: 2 additions & 2 deletions test/space/fspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_fspace_vector_eval_real():
with pytest.raises(TypeError): # ValueError: invalid vectorized input
f_vec_oop(points[0])

# In-place only
# In-place-only
out_arr = np.empty((5,), dtype='float64')
out_mg = np.empty((2, 3), dtype='float64')

Expand Down Expand Up @@ -287,7 +287,7 @@ def test_fspace_vector_eval_complex():
with pytest.raises(TypeError): # ValueError: invalid vectorized input
f_vec_oop(points[0])

# In-place only
# In-place-only
out_arr = np.empty((5,), dtype='complex128')
out_mg = np.empty((2, 3), dtype='complex128')

Expand Down
6 changes: 3 additions & 3 deletions test/space/ntuples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,7 +1718,7 @@ def test_ufuncs(fn, ufunc):
in_vectors = vectors[1:n_args]
out_vectors = vectors[n_args:]

# Out of place:
# Out-of-place:
np_result = npufunc(*in_arrays)
vec_fun = getattr(data_vector.ufunc, name)
odl_result = vec_fun(*in_vectors)
Expand All @@ -1731,13 +1731,13 @@ def test_ufuncs(fn, ufunc):
for i in range(n_out):
assert isinstance(odl_result[i], fn.element_type)

# In place:
# In-place:
np_result = npufunc(*(in_arrays + out_arrays))
vec_fun = getattr(data_vector.ufunc, name)
odl_result = vec_fun(*(in_vectors + out_vectors))
assert all_almost_equal(np_result, odl_result)

# Test inplace actually holds:
# Test in-place actually holds:
if n_out == 1:
assert odl_result is out_vectors[0]
elif n_out > 1:
Expand Down
2 changes: 1 addition & 1 deletion test/space/pspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_power_lincomb():
assert all_almost_equal(z, expected)


def test_power_inplace_modify():
def test_power_in_place_modify():
H = odl.rn(2)
HxH = odl.ProductSpace(H, 2)

Expand Down
15 changes: 13 additions & 2 deletions test/tomo/operators/ray_trafo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,14 @@ def projector(request):
raise ValueError('geom not valid')


def test_projector(projector):
@pytest.fixture(scope="module",
params=[True, False],
ids=[' in-place ', ' out-of-place '])
def in_place(request):
return request.param


def test_projector(projector, in_place):
"""Test discrete Ray transform forward projection."""

# TODO: this needs to be improved
Expand All @@ -166,7 +173,11 @@ def test_projector(projector):
vol = projector.domain.one()

# Calculate projection
proj = projector(vol)
if in_place:
proj = projector.range.zero()
projector(vol, out=proj)
else:
proj = projector(vol)

# We expect maximum value to be along diagonal
expected_max = projector.domain.partition.extent()[0] * np.sqrt(2)
Expand Down
10 changes: 5 additions & 5 deletions test/trafos/fourier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ def test_dft_preprocess_data(sign):
correct_arr.append((1 + 1j) * (1 - 2 * ((i + j + k) % 2)))

arr = np.ones(shape, dtype='complex64') * (1 + 1j)
preproc = dft_preprocess_data(arr, shift=True, sign=sign) # out of place
dft_preprocess_data(arr, shift=True, out=arr, sign=sign) # in place
preproc = dft_preprocess_data(arr, shift=True, sign=sign) # out-of-place
dft_preprocess_data(arr, shift=True, out=arr, sign=sign) # in-place

assert all_almost_equal(preproc.ravel(), correct_arr)
assert all_almost_equal(arr.ravel(), correct_arr)
Expand Down Expand Up @@ -387,10 +387,10 @@ def test_dft_preprocess_data_halfcomplex(sign):
correct_arr.append(1 - 2 * ((i + j + k) % 2))

arr = np.ones(shape, dtype='float64')
preproc = dft_preprocess_data(arr, shift=True, sign=sign) # out of place
preproc = dft_preprocess_data(arr, shift=True, sign=sign) # out-of-place
out = np.empty_like(arr)
dft_preprocess_data(arr, shift=True, out=out, sign=sign) # in place
dft_preprocess_data(arr, shift=True, out=arr, sign=sign) # in place
dft_preprocess_data(arr, shift=True, out=out, sign=sign) # in-place
dft_preprocess_data(arr, shift=True, out=arr, sign=sign) # in-place
assert all_almost_equal(preproc.ravel(), correct_arr)
assert all_almost_equal(arr.ravel(), correct_arr)
assert all_almost_equal(out.ravel(), correct_arr)
Expand Down
4 changes: 2 additions & 2 deletions test/util/vectorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def simple_func(x):
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]]

# Out of place
# Out-of-place
out = simple_func(arr)
assert isinstance(out, np.ndarray)
assert out.dtype == np.dtype('int')
Expand Down Expand Up @@ -298,7 +298,7 @@ def simple_func(x):
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]]

# Out of place
# Out-of-place
out = simple_func(arr)
assert isinstance(out, np.ndarray)
assert is_int_dtype(out.dtype)
Expand Down

0 comments on commit b27ad05

Please sign in to comment.