Skip to content

Commit

Permalink
ENH: complete and improve byaxis in fspace and DiscreteLp
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Jan 31, 2018
1 parent 23a5ba7 commit f3d8880
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 32 deletions.
136 changes: 118 additions & 18 deletions odl/discr/lp_discr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
is_floating_dtype, is_numeric_dtype, is_int,
dtype_str, array_str, signature_string, indent, npy_printoptions,
normalized_scalar_param_list, safe_int_conv, normalized_nodes_on_bdry,
normalized_index_expression, simulate_slicing)
normalized_index_expression, simulate_slicing, normalized_axis_indices)

__all__ = ('DiscreteLp', 'DiscreteLpElement',
'uniform_discr_frompartition', 'uniform_discr_fromspace',
Expand Down Expand Up @@ -640,6 +640,119 @@ def __getitem__(self, indices):
return DiscreteLp(res_fspace, res_part, res_tspace, interp=res_interp,
axis_labels=res_labels)

@property
def byaxis(self):
"""Object to index along (input and output) axes.
Examples
--------
Indexing with integers or slices:
>>> space = odl.uniform_discr([0, 0, 0], [1, 2, 3], (5, 10, 15),
... dtype=(float, (2, 3)))
>>> space.byaxis[0]
uniform_discr(0.0, 1.0, 5)
>>> space.byaxis_in[1]
uniform_discr(0.0, 2.0, 10)
>>> space.byaxis_in[1:]
uniform_discr([ 0., 0.], [ 2., 3.], (10, 15))
Lists can be used to stack spaces arbitrarily:
>>> space.byaxis_in[[2, 1, 2]]
uniform_discr([ 0., 0., 0.], [ 3., 2., 3.], (15, 10, 15))
"""
space = self

class DiscreteLpByaxis(object):

"""Helper class for indexing by axes."""

def __getitem__(self, indices):
"""Return ``self[indices]``.
Parameters
----------
indices : index expression
Object used to index the space.
Returns
-------
space : `DiscreteLp`
The resulting space after indexing along axes, with
roughly the same properties except possibly weighting.
"""
indices = normalized_axis_indices(indices, space.ndim)
idcs_in = [i - space.ndim_out for i in indices
if i >= space.ndim_out]

fspace = space.fspace.byaxis[indices]
part = space.partition.byaxis[idcs_in]
tspace = space.tspace.byaxis[indices]

interp = tuple(space.interp_byaxis[int(i)]
for i in idcs_in)
labels = tuple(space.axis_labels[int(i)]
for i in indices)

return DiscreteLp(fspace, part, tspace, interp,
axis_labels=labels)

def __repr__(self):
"""Return ``repr(self)``."""
return repr(space) + '.byaxis'

return DiscreteLpByaxis()

@property
def byaxis_out(self):
"""Object to index along output (tensor component) dimensions.
Examples
--------
Indexing with integers or slices:
>>> space = odl.uniform_discr(0, 1, 5, dtype=(float, (2, 3, 4)))
>>> space.byaxis_out[0]
uniform_discr(0.0, 1.0, 5, dtype=('float64', (2,)))
>>> space.byaxis_out[1:]
uniform_discr(0.0, 1.0, 5, dtype=('float64', (3, 4)))
Lists can be used to stack spaces arbitrarily:
>>> space.byaxis_out[[2, 1, 2]]
uniform_discr(0.0, 1.0, 5, dtype=('float64', (4, 3, 4)))
"""
space = self

class DiscreteLpByaxisOut(object):

"""Helper class for indexing by output axes."""

def __getitem__(self, indices):
"""Return ``self[indices]``.
Parameters
----------
indices : index expression
Object used to index the output axes.
Returns
-------
space : `DiscreteLp`
The resulting space with indexed output components and
otherwise same properties (except possibly weighting).
"""
idcs_out = normalized_axis_indices(indices, space.ndim_out)
idcs_in = tuple(range(space.ndim_out, space.ndim))
return space.byaxis[idcs_out + idcs_in]

def __repr__(self):
"""Return ``repr(self)``."""
return repr(space) + '.byaxis_out'

return DiscreteLpByaxisOut()

@property
def byaxis_in(self):
"""Object to index along input (domain) dimensions.
Expand Down Expand Up @@ -681,23 +794,10 @@ def __getitem__(self, indices):
The resulting space with indexed domain and otherwise
same properties (except possibly weighting).
"""
fspace = space.fspace.byaxis_in[indices]
part = space.partition.byaxis[indices]
tspace = space.tspace.byaxis[indices]

try:
iter(indices)
except TypeError:
interp = space.interp_byaxis[indices]
labels = space.axis_labels[indices]
else:
interp = tuple(space.interp_byaxis[int(i)]
for i in indices)
labels = tuple(space.axis_labels[int(i)]
for i in indices)

return DiscreteLp(fspace, part, tspace, interp,
axis_labels=labels)
indices = normalized_axis_indices(indices, space.ndim_in)
idcs_out = list(range(space.ndim_out))
idcs_in = [i + space.ndim_out for i in indices]
return space.byaxis[idcs_out + idcs_in]

def __repr__(self):
"""Return ``repr(self)``."""
Expand Down
96 changes: 85 additions & 11 deletions odl/space/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
is_real_dtype, is_complex_floating_dtype, dtype_repr, dtype_str,
complex_dtype, real_dtype, signature_string,
is_valid_input_array, is_valid_input_meshgrid,
out_shape_from_array, out_shape_from_meshgrid, vectorize, broadcast_to,
writable_array)
out_shape_from_array, out_shape_from_meshgrid, vectorize,
writable_array, is_int, normalized_axis_indices)
from odl.util.npy_compat import broadcast_to
from odl.util.utility import preload_first_arg, getargspec


Expand Down Expand Up @@ -869,6 +870,75 @@ def f_conj(x, **kwargs):
else:
return self.element(f_conj)

@property
def byaxis(self):
"""Object to index along output and input dimensions.
Note that the output dimensions come first, such that a ``F[0]``
for a vector-valued function gives the first vector component
function.
See Also
--------
byaxis_out : index along output axes only
byaxis_in : index along input axes only
Examples
--------
Indexing with integers or slices:
>>> domain = odl.IntervalProd(0, 1)
>>> fspace = odl.FunctionSpace(domain, out_dtype=(float, (2, 3, 4)))
>>> fspace.byaxis[1:]
FunctionSpace(IntervalProd(0.0, 1.0), out_dtype=('float64', (3, 4)))
"""
space = self

class FspaceByaxis(object):

"""Helper class for indexing by axes."""

def __getitem__(self, indices):
"""Return ``self[indices]``.
Parameters
----------
indices : index expression
Object used to select axes. This can be either an int,
slice or sequence of integers.
Returns
-------
space : `FunctionSpace`
The resulting space with roughly the same properties.
"""
ndim_out = len(space.out_shape)
ndim = ndim_out + space.domain.ndim

if is_int(indices):
indices = [indices]
elif isinstance(indices, slice):
indices = list(range(ndim))[indices]

if any(not is_int(i) for i in indices):
raise TypeError('sequence may only contain integers, '
'got {}'.format(indices))

indices = [i + ndim if i < 0 else i for i in indices]
idcs_out = [i for i in indices if i < ndim_out]
idcs_in = [i - ndim_out for i in indices if i >= ndim_out]

domain = space.domain[idcs_in]
out_shape = tuple(space.out_shape[i] for i in idcs_out)
out_dtype = (space.scalar_out_dtype, out_shape)
return FunctionSpace(domain, out_dtype)

def __repr__(self):
"""Return ``repr(self)``."""
return repr(space) + '.byaxis'

return FspaceByaxis()

@property
def byaxis_out(self):
"""Object to index along output dimensions.
Expand All @@ -892,6 +962,10 @@ def byaxis_out(self):
>>> fspace.byaxis_out[[2, 1, 2]]
FunctionSpace(IntervalProd(0.0, 1.0), out_dtype=('float64', (4, 3, 4)))
See Also
--------
byaxis_in : index along input axes
"""
space = self

Expand All @@ -918,15 +992,11 @@ def __getitem__(self, indices):
IndexError
If this is a space of scalar-valued functions.
"""
try:
iter(indices)
except TypeError:
newshape = space.out_shape[indices]
else:
newshape = tuple(space.out_shape[int(i)] for i in indices)

dtype = (space.scalar_out_dtype, newshape)
return FunctionSpace(space.domain, out_dtype=dtype)
ndim_out = len(space.out_shape)
ndim = ndim_out + space.domain.ndim
idcs_out = normalized_axis_indices(indices, ndim_out)
idcs_in = tuple(range(ndim_out, ndim))
return space.byaxis[idcs_out + idcs_in]

def __repr__(self):
"""Return ``repr(self)``."""
Expand Down Expand Up @@ -955,6 +1025,10 @@ def byaxis_in(self):
>>> fspace.byaxis_in[[2, 1, 2]]
FunctionSpace(IntervalProd([ 0., 0., 0.], [ 3., 2., 3.]))
See Also
--------
byaxis_out : index along output axes
"""
space = self

Expand Down
49 changes: 46 additions & 3 deletions odl/util/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

__all__ = ('normalized_scalar_param_list', 'normalized_index_expression',
'normalized_nodes_on_bdry', 'normalized_axes_tuple',
'safe_int_conv')
'normalized_axis_indices', 'safe_int_conv')


def normalized_scalar_param_list(param, length, param_conv=None,
Expand Down Expand Up @@ -352,6 +352,49 @@ def normalized_nodes_on_bdry(nodes_on_bdry, length):
return out_list


def normalized_axis_indices(indices, ndim):
"""Turn the given indices into a tuple of indices in the valid range.
This helper is intended for index normalization when indexing along
an object with ``ndim`` axes along the axes, e.g., in `DiscreteLp.byaxis`.
A slice is turned into a corresponding tuple of integers, and
a single integer is wrapped into a tuple. Negative indices are
incremented by ``ndim``.
Parameters
----------
indices : slice, int or sequence of int
Object for indexing along axes.
ndim : positive int
Number of available axes determining the valid axis range.
Returns
-------
norm_idcs : tuple of int
The normalized indices that all satisfy ``0 <= i < ndim``.
Raises
------
ValueError
If a given sequence contains non-integers.
"""
indices_in = indices

if isinstance(indices, slice):
indices = list(range(ndim))[indices]

try:
iter(indices)
except TypeError:
indices = [indices]

if any(not is_int(i) for i in indices):
raise ValueError('only slice, int or sequence of int is allowed, '
'got {}'.format(indices_in))

return tuple(i + ndim if i < 0 else i for i in indices)


def normalized_axes_tuple(axes, ndim):
"""Return a tuple of ``axes`` converted to positive integers.
Expand All @@ -360,7 +403,7 @@ def normalized_axes_tuple(axes, ndim):
Parameters
----------
axes : int or sequence of ints
axes : int or sequence of int
Single integer or integer sequence of arbitrary length.
Duplicate entries are not allowed. All entries must fulfill
``-ndim <= axis <= ndim - 1``.
Expand All @@ -369,7 +412,7 @@ def normalized_axes_tuple(axes, ndim):
Returns
-------
axes_list : tuple of ints
axes_list : tuple of int
The converted tuple of axes.
Examples
Expand Down

0 comments on commit f3d8880

Please sign in to comment.