Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions Wrappers/Python/cil/optimisation/functions/AbsFunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import numpy as np
from cil.optimisation.functions import Function
from cil.framework import DataContainer
from cil.framework import DataContainer, ImageGeometry

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataContainer is only needed for type-hints reasons.

from typing import Optional
import warnings
import logging
Expand Down Expand Up @@ -147,10 +147,15 @@ def _take_abs_input(func):
def _take_abs_decorator(self, x: DataContainer, *args, **kwargs):
real_dtype, _ = _get_real_complex_dtype(x)

rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(0)
r.fill(np.abs(x.as_array()).astype(real_dtype))
try:
rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(0)
Comment on lines +151 to +153

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(0)
r = x.geometry.allocate(None, dtype=read_dtype)

r.fill(np.abs(x.as_array()).astype(real_dtype))
except AttributeError as excp:
rgeo = ImageGeometry(*x.shape[::-1])
r = rgeo.allocate(None, dtype=real_dtype)
r.fill(np.abs(x.asarray()))
fval = func(r, *args, **kwargs)
return fval
return _take_abs_decorator
Expand All @@ -167,12 +172,17 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):

real_dtype, complex_dtype = _get_real_complex_dtype(x)


rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(None)
r.fill(np.abs(x.as_array()).astype(real_dtype))
Phi = np.exp((1j*np.angle(x.array)))
try:
rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(None)
Comment on lines +176 to +178

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(None)
r = x.geometry.allocate(None, dtype=real_dtype)

r.fill(np.abs(x.as_array()).astype(real_dtype))
Phi = np.exp((1j*np.angle(x.array)))
except AttributeError as excp:
rgeo = ImageGeometry(*x.shape[::-1])
r = rgeo.allocate(None, dtype=real_dtype)
r.fill(np.abs(x.asarray()))
Phi = np.exp((1j*np.angle(x.asarray())))
out = kwargs.pop('out', None)

fvals = func(r, *args, **kwargs)
Expand All @@ -185,9 +195,14 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):
y = r.copy()
fvals_np = fvals.as_array()
while np.any(fvals_np < 0):
tmp = fvals_np - 0.5*y.as_array() + 0.5*r.as_array()
tmp[tmp < 0] = 0.
y += DataContainter(tmp, y.geometry) - fvals
tmp = r - y
tmp *= 0.5
tmp += fvals
tmparr = tmp.as_array()
tmparr[tmparr < 0] = 0.
tmp.fill(tmparr)
Comment on lines +201 to +203

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check: tmp is a CIL ImageData so this could be simplified as:

Suggested change
tmparr = tmp.as_array()
tmparr[tmparr < 0] = 0.
tmp.fill(tmparr)
tmp.array[tmp.array < 0] = 0.

Or check IndicatorBox

y += tmp
y -= fvals
fvals = func(y, *args, **kwargs)
cts += 1
if cts > 10:
Expand All @@ -196,11 +211,12 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):
break

if out is None:
out_geom = x.geometry.copy()
out = out_geom.allocate(None)
out = x * 0
if np.isreal(x.as_array()).all():
print("Input is real, returning real output")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("Input is real, returning real output")

out.fill( np.real(fvals_np.astype(complex_dtype)*Phi))
else:
print("Input is complex, returning complex output")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("Input is complex, returning complex output")

out.fill( fvals_np.astype(complex_dtype)*Phi)
return out
return _abs_project_decorator
Expand All @@ -209,7 +225,7 @@ def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):
def _get_real_complex_dtype(x: DataContainer):
'''An internal function to find the type of x and set the corresponding real and complex data types '''

x_dtype = x.as_array().dtype
x_dtype = x.dtype

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmwatson does your data container have a dtype member?

if np.issubdtype(x_dtype, np.complexfloating):
complex_dtype = x_dtype
complex_example = np.array([1 + 1j], dtype=x_dtype)
Expand Down
Loading