Skip to content

Commit

Permalink
BUG: fix issue with axis_labels in slicing and show
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Jan 31, 2018
1 parent ac712de commit 23a5ba7
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions odl/discr/lp_discr.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,22 +615,26 @@ def __getitem__(self, indices):
res_part = self.partition[indices_in]
_, collapsed_axes_in, _, _ = simulate_slicing(
self.shape_in, indices_in)
remaining_axes_in = [i for i in range(len(self.shape_in))
remaining_axes_in = [i for i in range(self.ndim_in)
if i not in collapsed_axes_in]
res_part = res_part.byaxis[remaining_axes_in]

# Determine new fspace
sliced_shape_out, _, _, _ = simulate_slicing(
sliced_shape_out, collapsed_axes_out, _, _ = simulate_slicing(
self.shape_out, indices_out)
res_fspace = FunctionSpace(
res_part.set,
out_dtype=(self.fspace.scalar_out_dtype, sliced_shape_out))

remaining_axes_out = [i for i in range(self.ndim_out)
if i not in collapsed_axes_out]
remaining_axes = (remaining_axes_out +
[self.ndim_out + i for i in remaining_axes_in])
# Further attributes for the new space
res_interp = [self.interp_byaxis[i] for i in range(len(self.shape_in))
if i in remaining_axes_in]
res_labels = [self.axis_labels[i] for i in range(len(self.shape_in))
res_interp = [self.interp_byaxis[i] for i in range(self.ndim_in)
if i in remaining_axes_in]
res_labels = [self.axis_labels[i] for i in range(self.ndim)
if i in remaining_axes]

# Create new space
return DiscreteLp(res_fspace, res_part, res_tspace, interp=res_interp,
Expand Down Expand Up @@ -1516,7 +1520,8 @@ def show(self, title=None, method='', coords=None, indices=None,

squeezed_axes = [axis for axis in range(self.ndim)
if is_int(indices[axis])]
axis_labels = [self.space.axis_labels[axis] for axis in squeezed_axes]
axis_labels = [self.space.axis_labels[ax]
for ax in range(self.ndim) if ax not in squeezed_axes]

# Extend partition (trivially) to output axes
part_out = uniform_partition_fromgrid(
Expand Down

0 comments on commit 23a5ba7

Please sign in to comment.