diff --git a/Wrappers/Python/cil/optimisation/functions/AbsFunction.py b/Wrappers/Python/cil/optimisation/functions/AbsFunction.py index b922eb36ae..5ae55d1006 100644 --- a/Wrappers/Python/cil/optimisation/functions/AbsFunction.py +++ b/Wrappers/Python/cil/optimisation/functions/AbsFunction.py @@ -26,7 +26,7 @@ import numpy as np from cil.optimisation.functions import Function -from cil.framework import DataContainer +from cil.framework import DataContainer, ImageGeometry from typing import Optional import warnings import logging @@ -147,10 +147,15 @@ def _take_abs_input(func): def _take_abs_decorator(self, x: DataContainer, *args, **kwargs): real_dtype, _ = _get_real_complex_dtype(x) - rgeo = x.geometry.copy() - rgeo.dtype = real_dtype - r = rgeo.allocate(0) - r.fill(np.abs(x.as_array()).astype(real_dtype)) + try: + rgeo = x.geometry.copy() + rgeo.dtype = real_dtype + r = rgeo.allocate(0) + r.fill(np.abs(x.as_array()).astype(real_dtype)) + except AttributeError as excp: + rgeo = ImageGeometry(*x.shape[::-1]) + r = rgeo.allocate(None, dtype=real_dtype) + r.fill(np.abs(x.asarray())) fval = func(r, *args, **kwargs) return fval return _take_abs_decorator @@ -167,12 +172,17 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs): real_dtype, complex_dtype = _get_real_complex_dtype(x) - - rgeo = x.geometry.copy() - rgeo.dtype = real_dtype - r = rgeo.allocate(None) - r.fill(np.abs(x.as_array()).astype(real_dtype)) - Phi = np.exp((1j*np.angle(x.array))) + try: + rgeo = x.geometry.copy() + rgeo.dtype = real_dtype + r = rgeo.allocate(None) + r.fill(np.abs(x.as_array()).astype(real_dtype)) + Phi = np.exp((1j*np.angle(x.array))) + except AttributeError as excp: + rgeo = ImageGeometry(*x.shape[::-1]) + r = rgeo.allocate(None, dtype=real_dtype) + r.fill(np.abs(x.asarray())) + Phi = np.exp((1j*np.angle(x.asarray()))) out = kwargs.pop('out', None) fvals = func(r, *args, **kwargs) @@ -185,9 +195,14 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs): y = r.copy() fvals_np = fvals.as_array() while np.any(fvals_np < 0): - tmp = fvals_np - 0.5*y.as_array() + 0.5*r.as_array() - tmp[tmp < 0] = 0. - y += DataContainter(tmp, y.geometry) - fvals + tmp = r - y + tmp *= 0.5 + tmp += fvals + tmparr = tmp.as_array() + tmparr[tmparr < 0] = 0. + tmp.fill(tmparr) + y += tmp + y -= fvals fvals = func(y, *args, **kwargs) cts += 1 if cts > 10: @@ -196,11 +211,12 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs): break if out is None: - out_geom = x.geometry.copy() - out = out_geom.allocate(None) + out = x * 0 if np.isreal(x.as_array()).all(): + print("Input is real, returning real output") out.fill( np.real(fvals_np.astype(complex_dtype)*Phi)) else: + print("Input is complex, returning complex output") out.fill( fvals_np.astype(complex_dtype)*Phi) return out return _abs_project_decorator @@ -209,7 +225,7 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs): def _get_real_complex_dtype(x: DataContainer): '''An internal function to find the type of x and set the corresponding real and complex data types ''' - x_dtype = x.as_array().dtype + x_dtype = x.dtype if np.issubdtype(x_dtype, np.complexfloating): complex_dtype = x_dtype complex_example = np.array([1 + 1j], dtype=x_dtype)