为什么我用 numba 速度不升反降?

2018-04-04 14:25:14 +08:00
 dwjgwsm

看了这篇文章 https://zhuanlan.zhihu.com/p/24168485 试了一下里面的 ma_numba 函数

import time

@numba.jit

def ma_numba(data, ma_length):

ma = []
data_window = data[:ma_length]
test_data = data[ma_length:]

for new_tick in test_data:
    data_window.pop(0)
    data_window.append(new_tick)
    sum_tick = 0
    for tick in data_window:
        sum_tick += tick
    ma.append(sum_tick/ma_length)


a = np.arange(10000)
t1 = time.time()
b = list(a)
bb = ma_numba(b, 5)
t2 = time.time()
print(t2 - t1)


不用 numba,大概耗时 0.03-0.04 秒,用了 numba,耗时 0.7-0.8 秒......奇了怪了,难道是我的姿势不对?
4718 次点击
所在节点    Python
17 条回复
neoblackcap
2018-04-04 14:29:12 +08:00
np 不是本身就是 c 写的吗?你用在这里大概是 jit 也没抵消类型转换啊之类的开销吧。
要不你用个 pyflame 看看哪里开销大?
dwjgwsm
2018-04-04 14:36:28 +08:00
第一,a = np.arange(10000) 这一句是排除在耗时计算之外的.
第二,b = list(a) 这一句是都被计入耗时之内的.所以对比是不存在这个问题的
dwjgwsm
2018-04-04 14:39:06 +08:00
对比时就是简单地把 @numba.jit 这一句注释掉和不注释掉
ipwx
2018-04-04 14:53:24 +08:00
你这代码本来就不科学啊。data_window.pop 你这是想干嘛啊?还有 sum_tick 有你这种写法嘛?好好的 O(n) 算法你给写成 O(n*k) ?

In [1]: import numpy as np

In [2]: import numba

In [3]: def moving_average(data, k):
...: partial_sum = sum(data[:k])
...: ret = [partial_sum / k]
...: for old_d, new_d in zip(data[:-k], data[k:]):
...: partial_sum = partial_sum - old_d + new_d
...: ret.append(partial_sum / k)
...: return ret
...:

In [4]: numba_moving_average = numba.jit(moving_average)

In [5]: arr = np.arange(10000)


In [6]: arr_list = list(arr)

In [7]: %timeit moving_average(arr_liset)

In [8]: %timeit moving_average(arr_list, 5)
3.8 ms ± 9.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]: %timeit numba_moving_average(arr_list, 5)
722 µs ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
dwjgwsm
2018-04-04 15:20:03 +08:00
不对啊,我这结果还是 numba 耗时长啊

t1 = time.time()
b=numba_moving_average(a,5)
t2 = time.time()
c=moving_average(a,5)
t3 = time.time()
print(t2-t1)
print(t3 - t2)

结果:

0.7720441818237305
0.008000373840332031
enenaaa
2018-04-04 15:39:04 +08:00
我也发现了这个情况。numba 和 numpy、cython 混用时耗时不降反升。推测是多种格式数据通过解释器互转效率低下。
ipwx
2018-04-04 15:47:39 +08:00
@dwjgwsm

In [8]: arr_list = list(np.arange(100000))

In [10]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

In [11]: (t2 - t1, t3 - t2)
Out[11]: (0.0019309520721435547, 0.23806500434875488)

In [12]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

In [13]: (t2 - t1, t3 - t2)
Out[13]: (0.0016407966613769531, 0.005582094192504883)

In [14]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(100)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(100)]; t3 = time.time()

In [15]: (t2 - t1, t3 - t2)
Out[15]: (0.18658995628356934, 0.12822914123535156)

In [16]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(1000)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(1000)]; t3 = time.time()

In [17]: (t2 - t1, t3 - t2)
Out[17]: (1.3983790874481201, 1.3098900318145752)
dwjgwsm
2018-04-04 16:00:42 +08:00
你这个结果也不乐观.看来还是混用不行. 后面再去折腾一下 cython 看看
necomancer
2018-04-04 16:50:22 +08:00
我觉得 7# 说很清楚了吧,一般没有用 time.time() - start 来测试的,除非你程序大概跑在分钟级,data 大个一百万倍再说吧,timeit 是比较合适的测时间的工具。

还有,我想吐槽这个专栏,4# (同一人哎)说得更清楚,这个专栏是来逗比的么……写个移动平均当例子把 o(n) 弄成 o(n*k),这蛋疼的 pop(0)

更吐槽的是,还说第一反应上 NumPy,还 numpy_right ……
为啥不用 np.convolve(data, np.ones(500)/500,mode='valid') 试试?
20.6 ms ± 93.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 是渣渣 i7-3687u 的结果,一样的 data size (100000), 他的 cython 版本说单次时间最快也就 0.0098s 也就是 9.8ms ,这货认真的? numpy 对他来说用法仅限于 a.mean() 和方便的索引了是么……
necomancer
2018-04-04 16:53:57 +08:00
仔细看一下连 cython 里都还有 pop(0)……这个大哥仗着自己是 i7-6700k 就日了天了么……
DSaAAiC
2018-04-04 17:09:16 +08:00
你的代码跑 10000 遍使用了 numba.jit 是 6.175 秒,不使用 numba.jit 是 72.809 秒。numba 的 jit 技术还是起到作用了。
DSaAAiC
2018-04-04 17:09:53 +08:00
@necomancer 大佬都是从哪里知道这些偏僻的 numpy 函数的,系统地看文档吗?
necomancer
2018-04-04 17:23:16 +08:00
顺便再扯一嘴,用 convole 还是一般带窗口的,像这个方窗的情况
```
def maa(data, n):
ret = np.cumsum(data)
ret[n:] = ret[n:] - ret[:-n]
return ret[n-1:]/n
```
渣渣本上也只要 990 微秒。这些不好好看 NumPy 的同学弄得好像 python 咋折腾都很低效……
necomancer
2018-04-04 17:29:08 +08:00
@DSaAAiC 我不是大佬……而且这个问题不能算生僻吧,移动平均,尤其是带有窗口函数的移动平均,遇到得应该还是很多的。我其实看到“移动平均”第一反应是“这其实是个卷积的问题”,当然这么想问题也会复杂化,卷积是 o(n*k),当然一些大窗口体系还能用更快的 fftconvole ……扯远了,知道 NumPy 里都有啥好玩儿的需要一定的数学基础吧,我感觉,把遇到的问题能比较“数学地”进行描述,NumPy/SciPy 总会有惊喜。一般来说都是 Google 一下 问题+scipy 就会看到好玩儿的函数在下面贴着。
liyuanji1002
2018-04-05 02:41:06 +08:00
不知道在哪看的了, 说是 jit 启动需要花费一点时间. 可能你这段代码的计算规模还是低了点~ 试试把规模再翻几十倍看看如何.
dwjgwsm
2018-04-05 10:56:22 +08:00
@liyuanji1002 算了,运算量整太大了,脱离实际需求也没有意义了.反正优化方案里面已经 pass 掉 numba 了
xgdgsc
2018-04-05 11:49:43 +08:00
nopython=True, cache=True 看看

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/444291

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX