问题:NumPy 性能:uint8 与浮点数以及乘法与除法?

我刚刚注意到,仅将乘法更改为除法,我的脚本的执行时间几乎减半。

为了调查这个,我写了一个小例子:

import numpy as np                                                                                                                                                                                
import timeit

# uint8 array
arr1 = np.random.randint(0, high=256, size=(100, 100), dtype=np.uint8)

# float32 array
arr2 = np.random.rand(100, 100).astype(np.float32)
arr2 *= 255.0


def arrmult(a):
    """ 
    mult, read-write iterator
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) * 0.5

def arrmult2(a):
    """ 
    mult, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
        b[i, j] = (b[i, j] + 5) * 0.5

def arrmult3(a):
    """
    mult, vectorized
    """
    b = a.copy()
    b = (b + 5) * 0.5

def arrdiv(a):
    """ 
    div, read-write iterator 
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) / 2

def arrdiv2(a):
    """ 
    div, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
           b[i, j] = (b[i, j] + 5)  / 2                                                                                 

def arrdiv3(a):                                                                                                     
    """                                                                                                             
    div, vectorized                                                                                                 
    """                                                                                                             
    b = a.copy()                                                                                                    
    b = (b + 5) / 2                                                                                               




def print_time(name, t):                                                                                            
    print("{: <10}: {: >6.4f}s".format(name, t))                                                                    

timeit_iterations = 100                                                                                             

print("uint8 arrays")                                                                                               
print_time("arrmult", timeit.timeit("arrmult(arr1)", "from __main__ import arrmult, arr1", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr1)", "from __main__ import arrmult2, arr1", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr1)", "from __main__ import arrmult3, arr1", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr1)", "from __main__ import arrdiv, arr1", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr1)", "from __main__ import arrdiv2, arr1", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr1)", "from __main__ import arrdiv3, arr1", number=timeit_iterations))

print("\nfloat32 arrays")                                                                                           
print_time("arrmult", timeit.timeit("arrmult(arr2)", "from __main__ import arrmult, arr2", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr2)", "from __main__ import arrmult2, arr2", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr2)", "from __main__ import arrmult3, arr2", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr2)", "from __main__ import arrdiv, arr2", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr2)", "from __main__ import arrdiv2, arr2", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr2)", "from __main__ import arrdiv3, arr2", number=timeit_iterations))

这将打印以下时间:

uint8 arrays
arrmult   : 2.2004s
arrmult2  : 3.0589s
arrmult3  : 0.0014s
arrdiv    : 1.1540s
arrdiv2   : 2.0780s
arrdiv3   : 0.0027s

float32 arrays
arrmult   : 1.2708s
arrmult2  : 2.4120s
arrmult3  : 0.0009s
arrdiv    : 1.5771s
arrdiv2   : 2.3843s
arrdiv3   : 0.0009s

我一直认为乘法在计算上比除法便宜。然而,对于uint8,一个除法似乎几乎是两倍的效率。这是否与* 0.5必须计算浮点数中的乘法然后将结果转换回整数的事实有关?

至少对于浮点数乘法似乎比除法更快。这通常是真的吗?

为什么uint8中的乘法比float32中的更广泛?我认为 8 位无符号整数的计算速度应该比 32 位浮点数快得多?!

有人可以“揭秘”吗?

编辑:为了获得更多数据,我已经包含了矢量化函数(如建议的那样)并添加了索引迭代器。矢量化函数要快得多,因此不能真正具有可比性。然而,如果向量化函数的timeit_iterations设置得更高,结果证明uint8float32的乘法速度更快。我想这更令人困惑?!

也许乘法实际上总是比除法快,但 for 循环中的主要性能漏洞不是算术运算,而是循环本身。尽管这并不能解释为什么循环对于不同的操作表现不同。

EDIT2:就像@jotasi 已经说过的那样,我们正在寻找divisionmultiplicationint(或uint8)与float(或float32)的完整解释。此外,解释向量化方法和迭代器的不同趋势会很有趣,因为在向量化的情况下,除法似乎更慢,而在迭代器的情况下它更快。

解答

问题在于您的假设,即您测量除法或乘法所需的时间,这是不正确的。您正在测量除法或乘法所需的开销。

人们真的必须查看确切的代码来解释每种效果,这些效果可能因版本而异。这个答案只能给出一个想法,必须考虑什么。

问题是一个简单的int在 python 中一点也不简单:它是一个必须在垃圾收集器中注册的真实对象,它的大小随着它的值而增长——你必须付出的一切:例如 8bit需要整数 24 字节内存! python-floats 也类似。

另一方面,一个 numpy 数组由简单的 c 样式整数/浮点数组成,没有开销,您可以节省大量内存,但在访问 numpy-array 的元素期间会为此付出代价。a[i]意味着:必须构造一个 python 整数,在垃圾收集器中注册,并且只能使用它 - 有很多的开销。

考虑这段代码:

li1=[x%256 for x in xrange(10**4)]
arr1=np.array(li1, np.uint8)

def arrmult(a):    
    for i in xrange(len(a)):
        a[i]*=5;

arrmult(li1)arrmult(arr1)快 25,因为列表中的整数已经是 python-ints 并且不必创建!创建对象需要大部分计算时间 - 几乎可以忽略其他所有内容。


让我们看一下您的代码,首先是乘法:

def arrmult2(a):
    ...
    b[i, j] = (b[i, j] + 5) * 0.5

在 uint8 的情况下,必须发生以下情况(为简单起见,我忽略了 +5):

1.必须创建一个python-int

  1. 必须将其转换为浮点数(python-float 创建),以便能够进行浮点乘法

  2. 并转换回 python-int 或/和 uint8

对于 float32,要做的工作更少(乘法成本不高):1. 创建了一个 python-float 2. 回滚了 float32。

所以浮动版本应该更快,而且确实如此。


现在让我们看一下分区:

def arrdiv2(a):
    ...
    b[i, j] = (b[i, j] + 5)  / 2 

这里的陷阱:所有操作都是整数操作。因此,与乘法相比,无需转换为 python-float,因此与乘法相比,我们的开销更少。在您的情况下,对于 unint8 的除法比乘法“更快”。

但是,对于 float32,除法和乘法同样快/慢,因为在这种情况下几乎没有任何变化——我们仍然需要创建一个 python-float。


现在是矢量化版本:它们与 c 风格的“原始”float32s/uint8s 一起工作,无需转换(及其成本!)到引擎盖下的相应 python 对象。要获得有意义的结果,您应该增加迭代次数(现在运行时间太短,无法肯定地说)。

  1. float32 的除法和乘法可能具有相同的运行时间,因为我希望 numpy 通过乘以0.5来代替除以 2(但要确保必须查看代码)。

  2. uint8 的乘法应该更慢,因为每个 uint8 整数必须在乘以 0.5 之前转换为浮点数,然后再转换回 uint8。

3.对于uint8的情况,numpy不能通过乘以0.5来代替除以2,因为它是整数除法。对于许多架构,整数除法比浮点乘法慢 - 这是最慢的矢量化操作。


PS:我不会过多地谈论成本乘法与除法 - 有太多其他事情会对性能产生更大的影响。例如创建不必要的临时对象,或者如果 numpy-array 很大并且不适合缓存,那么内存访问将成为瓶颈 - 您将根本看不到乘法和除法之间的区别。

Logo

学AI,认准AI Studio!GPU算力,限时免费领,邀请好友解锁更多惊喜福利 >>>

更多推荐