diff --git a/odl/discr/lp_discr.py b/odl/discr/lp_discr.py index f564b134c5a..e39850bf126 100644 --- a/odl/discr/lp_discr.py +++ b/odl/discr/lp_discr.py @@ -30,7 +30,7 @@ is_floating_dtype, is_numeric_dtype, is_int, dtype_str, array_str, signature_string, indent, npy_printoptions, normalized_scalar_param_list, safe_int_conv, normalized_nodes_on_bdry, - normalized_index_expression, simulate_slicing) + normalized_index_expression, simulate_slicing, normalized_axis_indices) __all__ = ('DiscreteLp', 'DiscreteLpElement', 'uniform_discr_frompartition', 'uniform_discr_fromspace', @@ -640,6 +640,119 @@ def __getitem__(self, indices): return DiscreteLp(res_fspace, res_part, res_tspace, interp=res_interp, axis_labels=res_labels) + @property + def byaxis(self): + """Object to index along (input and output) axes. + + Examples + -------- + Indexing with integers or slices: + + >>> space = odl.uniform_discr([0, 0, 0], [1, 2, 3], (5, 10, 15), + ... dtype=(float, (2, 3))) + >>> space.byaxis[0] + uniform_discr(0.0, 1.0, 5) + >>> space.byaxis_in[1] + uniform_discr(0.0, 2.0, 10) + >>> space.byaxis_in[1:] + uniform_discr([ 0., 0.], [ 2., 3.], (10, 15)) + + Lists can be used to stack spaces arbitrarily: + + >>> space.byaxis_in[[2, 1, 2]] + uniform_discr([ 0., 0., 0.], [ 3., 2., 3.], (15, 10, 15)) + """ + space = self + + class DiscreteLpByaxis(object): + + """Helper class for indexing by axes.""" + + def __getitem__(self, indices): + """Return ``self[indices]``. + + Parameters + ---------- + indices : index expression + Object used to index the space. + + Returns + ------- + space : `DiscreteLp` + The resulting space after indexing along axes, with + roughly the same properties except possibly weighting. + """ + indices = normalized_axis_indices(indices, space.ndim) + idcs_in = [i - space.ndim_out for i in indices + if i >= space.ndim_out] + + fspace = space.fspace.byaxis[indices] + part = space.partition.byaxis[idcs_in] + tspace = space.tspace.byaxis[indices] + + interp = tuple(space.interp_byaxis[int(i)] + for i in idcs_in) + labels = tuple(space.axis_labels[int(i)] + for i in indices) + + return DiscreteLp(fspace, part, tspace, interp, + axis_labels=labels) + + def __repr__(self): + """Return ``repr(self)``.""" + return repr(space) + '.byaxis' + + return DiscreteLpByaxis() + + @property + def byaxis_out(self): + """Object to index along output (tensor component) dimensions. + + Examples + -------- + Indexing with integers or slices: + + >>> space = odl.uniform_discr(0, 1, 5, dtype=(float, (2, 3, 4))) + >>> space.byaxis_out[0] + uniform_discr(0.0, 1.0, 5, dtype=('float64', (2,))) + >>> space.byaxis_out[1:] + uniform_discr(0.0, 1.0, 5, dtype=('float64', (3, 4))) + + Lists can be used to stack spaces arbitrarily: + + >>> space.byaxis_out[[2, 1, 2]] + uniform_discr(0.0, 1.0, 5, dtype=('float64', (4, 3, 4))) + """ + space = self + + class DiscreteLpByaxisOut(object): + + """Helper class for indexing by output axes.""" + + def __getitem__(self, indices): + """Return ``self[indices]``. + + Parameters + ---------- + indices : index expression + Object used to index the output axes. + + Returns + ------- + space : `DiscreteLp` + The resulting space with indexed output components and + otherwise same properties (except possibly weighting). + """ + idcs_out = normalized_axis_indices(indices, space.ndim_out) + idcs_in = tuple(range(space.ndim_out, space.ndim)) + return space.byaxis[idcs_out + idcs_in] + + def __repr__(self): + """Return ``repr(self)``.""" + return repr(space) + '.byaxis_out' + + return DiscreteLpByaxisOut() + @property def byaxis_in(self): """Object to index along input (domain) dimensions. @@ -681,23 +794,10 @@ def __getitem__(self, indices): The resulting space with indexed domain and otherwise same properties (except possibly weighting). """ - fspace = space.fspace.byaxis_in[indices] - part = space.partition.byaxis[indices] - tspace = space.tspace.byaxis[indices] - - try: - iter(indices) - except TypeError: - interp = space.interp_byaxis[indices] - labels = space.axis_labels[indices] - else: - interp = tuple(space.interp_byaxis[int(i)] - for i in indices) - labels = tuple(space.axis_labels[int(i)] - for i in indices) - - return DiscreteLp(fspace, part, tspace, interp, - axis_labels=labels) + indices = normalized_axis_indices(indices, space.ndim_in) + idcs_out = list(range(space.ndim_out)) + idcs_in = [i + space.ndim_out for i in indices] + return space.byaxis[idcs_out + idcs_in] def __repr__(self): """Return ``repr(self)``.""" diff --git a/odl/space/fspace.py b/odl/space/fspace.py index 720624ff6d3..829c774fc6e 100644 --- a/odl/space/fspace.py +++ b/odl/space/fspace.py @@ -21,8 +21,9 @@ is_real_dtype, is_complex_floating_dtype, dtype_repr, dtype_str, complex_dtype, real_dtype, signature_string, is_valid_input_array, is_valid_input_meshgrid, - out_shape_from_array, out_shape_from_meshgrid, vectorize, broadcast_to, - writable_array) + out_shape_from_array, out_shape_from_meshgrid, vectorize, + writable_array, is_int, normalized_axis_indices) +from odl.util.npy_compat import broadcast_to from odl.util.utility import preload_first_arg, getargspec @@ -869,6 +870,75 @@ def f_conj(x, **kwargs): else: return self.element(f_conj) + @property + def byaxis(self): + """Object to index along output and input dimensions. + + Note that the output dimensions come first, such that a ``F[0]`` + for a vector-valued function gives the first vector component + function. + + See Also + -------- + byaxis_out : index along output axes only + byaxis_in : index along input axes only + + Examples + -------- + Indexing with integers or slices: + + >>> domain = odl.IntervalProd(0, 1) + >>> fspace = odl.FunctionSpace(domain, out_dtype=(float, (2, 3, 4))) + >>> fspace.byaxis[1:] + FunctionSpace(IntervalProd(0.0, 1.0), out_dtype=('float64', (3, 4))) + """ + space = self + + class FspaceByaxis(object): + + """Helper class for indexing by axes.""" + + def __getitem__(self, indices): + """Return ``self[indices]``. + + Parameters + ---------- + indices : index expression + Object used to select axes. This can be either an int, + slice or sequence of integers. + + Returns + ------- + space : `FunctionSpace` + The resulting space with roughly the same properties. + """ + ndim_out = len(space.out_shape) + ndim = ndim_out + space.domain.ndim + + if is_int(indices): + indices = [indices] + elif isinstance(indices, slice): + indices = list(range(ndim))[indices] + + if any(not is_int(i) for i in indices): + raise TypeError('sequence may only contain integers, ' + 'got {}'.format(indices)) + + indices = [i + ndim if i < 0 else i for i in indices] + idcs_out = [i for i in indices if i < ndim_out] + idcs_in = [i - ndim_out for i in indices if i >= ndim_out] + + domain = space.domain[idcs_in] + out_shape = tuple(space.out_shape[i] for i in idcs_out) + out_dtype = (space.scalar_out_dtype, out_shape) + return FunctionSpace(domain, out_dtype) + + def __repr__(self): + """Return ``repr(self)``.""" + return repr(space) + '.byaxis' + + return FspaceByaxis() + @property def byaxis_out(self): """Object to index along output dimensions. @@ -892,6 +962,10 @@ def byaxis_out(self): >>> fspace.byaxis_out[[2, 1, 2]] FunctionSpace(IntervalProd(0.0, 1.0), out_dtype=('float64', (4, 3, 4))) + + See Also + -------- + byaxis_in : index along input axes """ space = self @@ -918,15 +992,11 @@ def __getitem__(self, indices): IndexError If this is a space of scalar-valued functions. """ - try: - iter(indices) - except TypeError: - newshape = space.out_shape[indices] - else: - newshape = tuple(space.out_shape[int(i)] for i in indices) - - dtype = (space.scalar_out_dtype, newshape) - return FunctionSpace(space.domain, out_dtype=dtype) + ndim_out = len(space.out_shape) + ndim = ndim_out + space.domain.ndim + idcs_out = normalized_axis_indices(indices, ndim_out) + idcs_in = tuple(range(ndim_out, ndim)) + return space.byaxis[idcs_out + idcs_in] def __repr__(self): """Return ``repr(self)``.""" @@ -955,6 +1025,10 @@ def byaxis_in(self): >>> fspace.byaxis_in[[2, 1, 2]] FunctionSpace(IntervalProd([ 0., 0., 0.], [ 3., 2., 3.])) + + See Also + -------- + byaxis_out : index along output axes """ space = self diff --git a/odl/util/normalize.py b/odl/util/normalize.py index 420d8475536..6e1ba0ad277 100644 --- a/odl/util/normalize.py +++ b/odl/util/normalize.py @@ -16,7 +16,7 @@ __all__ = ('normalized_scalar_param_list', 'normalized_index_expression', 'normalized_nodes_on_bdry', 'normalized_axes_tuple', - 'safe_int_conv') + 'normalized_axis_indices', 'safe_int_conv') def normalized_scalar_param_list(param, length, param_conv=None, @@ -352,6 +352,49 @@ def normalized_nodes_on_bdry(nodes_on_bdry, length): return out_list +def normalized_axis_indices(indices, ndim): + """Turn the given indices into a tuple of indices in the valid range. + + This helper is intended for index normalization when indexing along + an object with ``ndim`` axes along the axes, e.g., in `DiscreteLp.byaxis`. + A slice is turned into a corresponding tuple of integers, and + a single integer is wrapped into a tuple. Negative indices are + incremented by ``ndim``. + + Parameters + ---------- + indices : slice, int or sequence of int + Object for indexing along axes. + ndim : positive int + Number of available axes determining the valid axis range. + + Returns + ------- + norm_idcs : tuple of int + The normalized indices that all satisfy ``0 <= i < ndim``. + + Raises + ------ + ValueError + If a given sequence contains non-integers. + """ + indices_in = indices + + if isinstance(indices, slice): + indices = list(range(ndim))[indices] + + try: + iter(indices) + except TypeError: + indices = [indices] + + if any(not is_int(i) for i in indices): + raise ValueError('only slice, int or sequence of int is allowed, ' + 'got {}'.format(indices_in)) + + return tuple(i + ndim if i < 0 else i for i in indices) + + def normalized_axes_tuple(axes, ndim): """Return a tuple of ``axes`` converted to positive integers. @@ -360,7 +403,7 @@ def normalized_axes_tuple(axes, ndim): Parameters ---------- - axes : int or sequence of ints + axes : int or sequence of int Single integer or integer sequence of arbitrary length. Duplicate entries are not allowed. All entries must fulfill ``-ndim <= axis <= ndim - 1``. @@ -369,7 +412,7 @@ def normalized_axes_tuple(axes, ndim): Returns ------- - axes_list : tuple of ints + axes_list : tuple of int The converted tuple of axes. Examples