
I need to shift a 3D array by a 3D displacement vector for an algorithm. At the moment I'm using this (admittedly very ugly) method:

shiftedArray = np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0),
                               shift[1], axis=1),
                       shift[2], axis=2)

This works, but it means I'm calling np.roll three times! According to my profiling, 58% of my algorithm's time is spent in these calls.

From the docs of numpy.roll:

Parameters:
shift : int

axis : int, optional

There's no mention of an array-like shift in the parameters... So I can't do a multidimensional roll?

I thought I could just call something like this (it sounds like a NumPy thing to do):

np.roll(arrayToShift, shiftVector3D, axis=(0, 1, 2))

Maybe with a flattened version of my array, reshaped afterwards? But then how would I compute the shift vector, and would that shift really be equivalent?
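A quick check (added here, not part of the original question) suggests the flattening idea is not equivalent: rolling the flattened array by the row-major linear offset lets elements that fall off the end of one row spill into the next row, whereas a per-axis roll wraps each row onto itself. The small 4x5 array and the shift (1, 3) are arbitrary choices for illustration:

```python
import numpy as np

x = np.arange(20).reshape(4, 5)
shift = (1, 3)

# Per-axis roll: each axis wraps around independently.
per_axis = np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1)

# Flattened roll by the equivalent row-major linear offset.
offset = shift[0] * x.shape[1] + shift[1]
flat = np.roll(x.ravel(), offset).reshape(x.shape)

print(per_axis[0])                     # [17 18 19 15 16]
print(flat[0])                         # [12 13 14 15 16]
print(np.array_equal(per_axis, flat))  # False
```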

I'm surprised to find no easy solution for this, as I thought it would be a fairly common thing to do (okay, not that common, but...).

So how can we (relatively) efficiently shift an ndarray by an N-dimensional vector?


Note: This question was asked in 2015, back when numpy's roll method did not support this feature.

Answers

In theory, using scipy.ndimage.interpolation.shift, as described by @Ed Smith, should work, but because of a bug (https://github.com/scipy/scipy/issues/1323), it doesn't give a result equivalent to multiple calls of np.roll.


UPDATE: "Multi-roll" capability was added to numpy.roll in numpy version 1.12.0. Here's a two-dimensional example, in which the first axis is rolled one position and the second axis is rolled three positions:

In [7]: x = np.arange(20).reshape(4,5)

In [8]: x
Out[8]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [9]: np.roll(x, [1, 3], axis=(0, 1))
Out[9]: 
array([[17, 18, 19, 15, 16],
       [ 2,  3,  4,  0,  1],
       [ 7,  8,  9,  5,  6],
       [12, 13, 14, 10, 11]])
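Applied to the 3-D case from the question, a single call replaces the three nested rolls. A minimal sketch, assuming NumPy 1.12 or later for the tuple-valued shift/axis (and 1.17 or later for `default_rng`); the array shape and shift values are arbitrary placeholders:

```python
import numpy as np

rng = np.random.default_rng(0)
arrayToShift = rng.standard_normal((6, 7, 8))
shift = (10, -15, 20)  # shifts larger than an axis, or negative, also wrap

# The original approach: three separate rolls.
nested = np.roll(np.roll(np.roll(arrayToShift, shift[0], axis=0),
                         shift[1], axis=1),
                 shift[2], axis=2)

# One call that rolls all three axes at once.
single = np.roll(arrayToShift, shift, axis=(0, 1, 2))

print(np.array_equal(nested, single))  # True
```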

This makes the code below obsolete. I'll leave it there for posterity.


The code below defines a function I call multiroll that does what you want. Here's an example in which it is applied to an array with shape (500, 500, 500):

In [64]: x = np.random.randn(500, 500, 500)

In [65]: shift = [10, 15, 20]

Use multiple calls to np.roll to generate the expected result:

In [66]: yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)

Generate the shifted array using multiroll:

In [67]: ymulti = multiroll(x, shift)

Verify that we got the expected result:

In [68]: np.all(yroll3 == ymulti)
Out[68]: True

For an array of this size, making three calls to np.roll is almost three times slower than a single call to multiroll:

In [69]: %timeit yroll3 = np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1), shift[2], axis=2)
1 loops, best of 3: 1.34 s per loop

In [70]: %timeit ymulti = multiroll(x, shift)
1 loops, best of 3: 474 ms per loop
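On NumPy 1.12 or later, a similar comparison can be made against the built-in multi-axis roll using the standard-library timeit module. This is only a sketch: the 100x100x100 array is chosen to keep it quick, and no particular speedup is implied, since timings depend on the machine and NumPy version:

```python
import timeit
import numpy as np

x = np.random.randn(100, 100, 100)
shift = (10, 15, 20)

def nested_roll():
    # Three separate rolls, as in the question.
    return np.roll(np.roll(np.roll(x, shift[0], axis=0), shift[1], axis=1),
                   shift[2], axis=2)

def single_roll():
    # One multi-axis roll (NumPy >= 1.12).
    return np.roll(x, shift, axis=(0, 1, 2))

# Sanity check before timing: both produce the same array.
assert np.array_equal(nested_roll(), single_roll())

for name, fn in [("nested", nested_roll), ("single", single_roll)]:
    # Best of 3 repeats, each running fn 5 times; report per-call time.
    best = min(timeit.repeat(fn, number=5, repeat=3)) / 5
    print(f"{name}: {best * 1e3:.1f} ms per call")
```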

Here's the definition of multiroll:

from itertools import product
import numpy as np


def multiroll(x, shift, axis=None):
    """Roll an array along each axis.

    Parameters
    ----------
    x : array_like
        Array to be rolled.
    shift : sequence of int
        Number of indices by which to shift each axis.
    axis : sequence of int, optional
        The axes to be rolled.  If not given, all axes are assumed, and
        len(shift) must equal the number of dimensions of x.

    Returns
    -------
    y : numpy array, with the same type and size as x
        The rolled array.

    Notes
    -----
    The length of x along each axis must be positive.  The function
    does not handle arrays that have axes with length 0.

    See Also
    --------
    numpy.roll

    Examples
    --------
    Here's a two-dimensional array:

    >>> x = np.arange(20).reshape(4,5)
    >>> x 
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])

    Roll the first axis one step and the second axis three steps:

    >>> multiroll(x, [1, 3])
    array([[17, 18, 19, 15, 16],
           [ 2,  3,  4,  0,  1],
           [ 7,  8,  9,  5,  6],
           [12, 13, 14, 10, 11]])

    That's equivalent to:

    >>> np.roll(np.roll(x, 1, axis=0), 3, axis=1)
    array([[17, 18, 19, 15, 16],
           [ 2,  3,  4,  0,  1],
           [ 7,  8,  9,  5,  6],
           [12, 13, 14, 10, 11]])

    Not all the axes must be rolled.  The following uses
    the `axis` argument to roll just the second axis:

    >>> multiroll(x, [2], axis=[1])
    array([[ 3,  4,  0,  1,  2],
           [ 8,  9,  5,  6,  7],
           [13, 14, 10, 11, 12],
           [18, 19, 15, 16, 17]])

    which is equivalent to:

    >>> np.roll(x, 2, axis=1)
    array([[ 3,  4,  0,  1,  2],
           [ 8,  9,  5,  6,  7],
           [13, 14, 10, 11, 12],
           [18, 19, 15, 16, 17]])

    """
    x = np.asarray(x)
    if axis is None:
        if len(shift) != x.ndim:
            raise ValueError("The array has %d axes, but len(shift) is %d. "
                             "When 'axis' is not given, a shift must be "
                             "provided for all axes." % (x.ndim, len(shift)))
        axis = range(x.ndim)
    else:
        axis = tuple(axis)
        # Without this check, zip() below would silently drop extra values.
        if len(shift) != len(axis):
            raise ValueError("len(shift) (%d) must equal len(axis) (%d)."
                             % (len(shift), len(axis)))
        # axis does not have to contain all the axes.  Here we append the
        # missing axes to axis, and for each missing axis, append 0 to shift.
        missing_axes = set(range(x.ndim)) - set(axis)
        num_missing = len(missing_axes)
        axis = axis + tuple(missing_axes)
        shift = tuple(shift) + (0,)*num_missing

    # Use mod to convert all shifts to be values between 0 and the length
    # of the corresponding axis.
    shift = [s % x.shape[ax] for s, ax in zip(shift, axis)]

    # Reorder the values in shift to correspond to axes 0, 1, ..., x.ndim-1.
    shift = np.take(shift, np.argsort(axis))

    # Create the output array, and copy the shifted blocks from x to y.
    y = np.empty_like(x)
    src_slices = [(slice(n-shft, n), slice(0, n-shft))
                  for shft, n in zip(shift, x.shape)]
    dst_slices = [(slice(0, shft), slice(shft, n))
                  for shft, n in zip(shift, x.shape)]
    src_blks = product(*src_slices)
    dst_blks = product(*dst_slices)
    for src_blk, dst_blk in zip(src_blks, dst_blks):
        y[dst_blk] = x[src_blk]

    return y
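To see why multiroll works: along a single axis of length n, a roll by s (after reducing s modulo n) is just two slice copies. The last s elements go to the front, and the first n - s elements follow them. multiroll builds these two (source, destination) slice pairs for every axis and uses itertools.product to copy all 2**ndim block combinations. A 1-D sketch of the per-axis step, using a hypothetical helper `roll1d` that is not part of the answer above:

```python
import numpy as np

def roll1d(x, s):
    """Hypothetical helper: roll a 1-D array by s with two slice copies,
    the same decomposition multiroll applies along every axis."""
    x = np.asarray(x)
    n = x.shape[0]
    s %= n                     # normalize so that 0 <= s < n
    y = np.empty_like(x)
    y[:s] = x[n - s:]          # tail of x wraps around to the front
    y[s:] = x[:n - s]          # head of x moves right by s
    return y

a = np.arange(7)
for s in range(-10, 11):
    assert np.array_equal(roll1d(a, s), np.roll(a, s))
print(roll1d([1, 2, 3], 1))    # [3 1 2]
```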