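Numba is great
What this article is
A list of code patterns that work in plain Python but fail under Numba, each with a workaround.
It is written mainly with AtCoder in mind, but should be useful elsewhere too.
The information applies to roughly Numba 0.48.0, so expect things to change in the future (as of August 2020, AtCoder also ships 0.48.0).
The snippets assume `import numba` and `import numpy as np`.
Articles worth reading first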
qiita.com
ikatakos.com
Things you cannot use in Numba
Iterating over an ndarray with 2 or more dimensions
Not supported.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(5, 2)
    for a in array:
        ...
Error message
Direct iteration is not supported for arrays with dimension > 1. Try using indexing instead.
[1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
<source elided>
array = np.random.rand(5, 2) # 5x2 array
for a in array: # compile error
^
Workaround
Loop with range.
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(5, 2)
    for i in range(len(array)):
        a = array[i]
        ...
Creating empty lists in odd (?) ways
An error is raised when Numba cannot infer the element type.
This happens with dict and set as well as list.
It comes up often and the error message can be hard to read, so it is probably best to make the type explicit from the start.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    lst = [[] for _ in range(10)]
    lst[0].append(0)
Error message
Undecided type $26load_method.11 := <undecided>
[1] During: resolving caller type: $26load_method.11
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
<source elided>
lst = [[] for _ in range(10)] # compile error
lst[0].append(0)
^
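Workaround
Tell Numba the type one way or another.
@numba.njit("void()", cache=True)
def solve():
    lst = [[0] * 0 for _ in range(10)]  # [0] * 0 is an empty list of ints
    lst[0].append(0)
For a dict, one option is to insert an element you do not need.
For a set, something like {0}-{0} is enough for Numba to work out the type.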
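The dict and set hints above can be sketched as follows. This is only an illustration and not code from the original article: the function name `make_empty_containers` is hypothetical, and it is shown as plain Python so that it runs without Numba installed. The intent is that it would be decorated with `@numba.njit`, on the assumption that seeding a container and then emptying it is enough for type inference in your Numba version.

```python
def make_empty_containers():
    # Seed the dict with a throwaway int -> int entry so the key and value
    # types are known, then delete it to end up with an empty dict.
    d = {0: 0}
    del d[0]

    # {0} - {0} is an empty set whose element type (int) is already known.
    s = {0} - {0}

    # Same idea as in the list workaround: an empty list of ints.
    lst = [0] * 0

    # All three containers are empty here and can be filled as usual.
    d[3] = 4
    s.add(5)
    lst.append(6)
    return len(d), len(s), len(lst)
```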
Dict comprehensions
dict and set cannot be built with comprehensions.
A dict also cannot be constructed from a list or similar.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
    inv_fib = {v: i for i, v in enumerate(fib)}
Error message
Use of unsupported opcode (MAP_ADD) found
File "untitled.py", line 7:
def solve():
<source elided>
fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
inv_fib = {v: i for i, v in enumerate(fib)} # compile error
^
Workaround
Insert the items one at a time.
@numba.njit("void()", cache=True)
def solve():
    fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
    inv_fib = {}
    for i, v in enumerate(fib):
        inv_fib[v] = i
Dicts with lists as values
Code that fails
@numba.njit("void()", cache=True)
def solve():
    dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]}
    d = dictionary[3023]
Error message
list(int64) as value is forbidden
[1] During: typing of dict at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
dictionary = {3023: [0, 1, 2], 4006: [3, 4, 5]} # compile error
^
Workaround
If a numpy.ndarray is acceptable as the value, use that.
Otherwise, keep the values in a separate list (or similar) and store their indices in the dict.
@numba.njit("void()", cache=True)
def solve():
    dictionary = {3023: 0, 4006: 1}       # key -> index into container
    container = [[0, 1, 2], [3, 4, 5]]
    d = container[dictionary[3023]]
The third argument of pow
Not supported.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    mod = 10 ** 9 + 7
    inv10 = pow(10, mod-2, mod)
Error message
Invalid use of Function(<built-in function pow>) with argument(s) of type(s): (Literal[int](10), int64, Literal[int](1000000007))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, int32) -> float32
* (float32, int64) -> float32
* (float32, uint64) -> float32
* (float64, int32) -> float64
* (float64, int64) -> float64
* (float64, uint64) -> float64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function pow>)
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
<source elided>
mod = 10 ** 9 + 7
inv10 = pow(10, mod-2, mod) # compile error
^
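Workaround
Write a replacement in advance.
@numba.njit("i8(i8,i8,i8)", cache=True)
def pow_mod(base, exp, mod):
    # Reduce the exponent via Fermat's little theorem. This is valid only when
    # mod is prime and base is not a multiple of mod; drop this line otherwise.
    exp %= mod - 1
    res = 1
    # Binary exponentiation; mod must be small enough that (mod-1)**2 fits in int64.
    while exp:
        if exp & 1:
            res = res * base % mod
        base = base * base % mod
        exp >>= 1
    return res
@numba.njit("void()", cache=True)
def solve():
    mod = 10 ** 9 + 7
    inv10 = pow_mod(10, mod-2, mod)
The built-in sum function
max works, but for some reason sum does not.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(5)
    s = sum(a)
Error message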
Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>
File "untitled.py", line 7:
def solve():
<source elided>
a = np.random.rand(5)
s = sum(a) # compile error
^
Workaround
Use numpy.sum or numpy.ndarray.sum.
For a list, numpy.sum fails as well, so in that case convert the list to a numpy.ndarray or add the elements one by one.
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(5)
    s = np.sum(a)
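For the list case mentioned above, adding the elements one by one can look like the sketch below. The helper name `list_sum` is hypothetical and does not come from the original article. It is shown undecorated so it runs as plain Python, and is meant as a candidate for `@numba.njit`.

```python
def list_sum(values):
    # Accumulate by hand instead of calling the built-in sum().
    # For a list of floats, starting from 0.0 keeps the accumulator's type
    # consistent throughout the loop.
    total = 0
    for v in values:
        total += v
    return total
```

For example, `list_sum([1, 2, 3, 4])` adds the four elements in order.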
The axis argument of numpy.max, numpy.ndarray.max, and so on
numpy.sum accepts axis, but numpy.max does not.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    m = array.max(1)
Error message
[1] During: resolving callee type: BoundFunction(array.max for array(float64, 2d, C))
[2] During: typing of call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
Enable logging at debug level for details.
File "untitled.py", line 7:
def solve():
<source elided>
array = np.random.rand(4, 5)
m = array.max(1)
^
Workaround
Use a for loop.
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    m = np.empty(4, dtype=array.dtype)
    for i in range(4):
        m[i] = array[i].max()  # max over row i
Boolean indexing on an ndarray with 2 or more dimensions
Not supported.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    array[array < 0.5] = 0
Error message
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), Literal[int](0))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
<source elided>
array = np.random.rand(4, 5)
array[array < 0.5] = 0
^
Workaround
Use numpy.where.
@numba.njit("void()", cache=True)
def solve():
    array = np.random.rand(4, 5)
    array = np.where(array < 0.5, 0, array)
Adding a dimension to an ndarray with None
Not supported.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(4, 5)
    a = a[:, None, :]
    assert a.shape == (4, 1, 5)
Error message
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), Tuple(slice<a:b>, none, slice<a:b>))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
All templates rejected with literals.
In definition 9:
All templates rejected without literals.
In definition 10:
All templates rejected with literals.
In definition 11:
All templates rejected without literals.
In definition 12:
TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>)
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 13:
TypeError: unsupported array index type none in Tuple(slice<a:b>, none, slice<a:b>)
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
[2] During: typing of static-get-item at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (7)
File "untitled.py", line 7:
def solve():
<source elided>
a = np.random.rand(4, 5)
a = a[:, None, :] # compile error
^
Workaround
Use reshape or expand_dims.
Note that the second argument of expand_dims cannot be a tuple.
@numba.njit("void()", cache=True)
def solve():
    a = np.random.rand(4, 5)
    a = np.expand_dims(a, 1)
    assert a.shape == (4, 1, 5)
int.bit_length
Not supported.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    b = (998244353).bit_length()
Error message
Unknown attribute 'bit_length' of type Literal[int](998244353)
File "untitled.py", line 6:
def solve():
b = (998244353).bit_length()
^
[1] During: typing of get attribute at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (6)
File "untitled.py", line 6:
def solve():
b = (998244353).bit_length()
^
Workaround
Get by with np.log2 or similar (beware of floating-point error for values of 2^49 - 1 and above; 0 does not work correctly either).
@numba.njit("void()", cache=True)
def solve():
    b = int(np.log2(998244353)) + 1
    print(b)
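If the rounding issue near 2^49 or the behaviour at 0 matters, a shift loop avoids floating point altogether. The following is a minimal sketch and not part of the original article: the name `bit_length` is chosen here for illustration, it assumes a non-negative argument, and it is shown as plain Python, intended to be decorated with `@numba.njit`.

```python
def bit_length(n):
    # Number of bits needed to represent a non-negative integer n.
    # Matches int.bit_length(), including bit_length(0) == 0.
    length = 0
    while n > 0:
        n >>= 1
        length += 1
    return length
```

The loop runs at most 63 times for a non-negative int64, so the cost is usually negligible next to the rest of a solution.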
collections
Mostly unusable.
Workaround
defaultdict -> check by hand whether the key is present
deque -> implement something like a ring buffer yourself
Counter -> count by hand
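The substitutions listed above might look like the sketch below. None of this is from the original article: `count_values` and `bfs_distances` are hypothetical helpers, and the graph is assumed to be given in a CSR-like form, where the neighbours of vertex `v` are `adj[adj_start[v]:adj_start[v + 1]]`. The functions are plain Python so they run without Numba, and are written with `@numba.njit` in mind.

```python
def count_values(values):
    # Counter / defaultdict(int) replacement: test for the key by hand.
    counts = {}
    for v in values:
        if v in counts:
            counts[v] += 1
        else:
            counts[v] = 1
    return counts


def bfs_distances(n, adj_start, adj, source):
    # deque replacement: a fixed-size buffer with head/tail indices that
    # wrap around. Each vertex is enqueued at most once, so n slots suffice.
    dist = [-1] * n
    buf = [0] * n
    head = 0  # index of the next element to pop
    tail = 0  # index of the next free slot
    size = 0  # number of elements currently queued
    buf[tail] = source
    tail = (tail + 1) % n
    size += 1
    dist[source] = 0
    while size > 0:
        v = buf[head]
        head = (head + 1) % n
        size -= 1
        for k in range(adj_start[v], adj_start[v + 1]):
            w = adj[k]
            if dist[w] == -1:
                dist[w] = dist[v] + 1
                buf[tail] = w
                tail = (tail + 1) % n
                size += 1
    return dist
```

Unreachable vertices keep the distance -1. In a compiled version, NumPy arrays would be a natural choice for `buf` and `dist` instead of lists.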
itertools
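Not supported.
Workaround
Either split off the part that uses itertools so that it is not compiled, or prepare substitutes in advance?
@numba.jit("i8[:,:](i8[:],i8)", cache=True)
def combinations(arr, r):
    n = len(arr)
    assert 0 <= r <= n
    # Number of combinations, C(n, r).
    res_length = 1
    for i in range(r):
        res_length = res_length * (n-i) // (1+i)
    res = np.empty((res_length, r), dtype=arr.dtype)
    idxs_arr = np.arange(r)
    for idx_res in range(res_length):
        res[idx_res] = arr[idxs_arr]
        if idx_res == res_length - 1:
            break  # last combination: nothing left to advance
        # Find the rightmost index that can still be incremented.
        i = 1
        while idxs_arr[r-i] == n-i:
            i += 1
        idxs_arr[r-i] += 1
        for j in range(r-i+1, r):
            idxs_arr[j] = idxs_arr[j-1] + 1
    return res
The implementation above is not a generator, so it is inefficient when you only want to explore a few combinations and then stop.
If that bothers you, preparing something like C++'s next_permutation looks like a good option, and it is more general-purpose too.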
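A next_permutation-style helper could be sketched as follows. This is an illustration under assumptions rather than the author's code: the name and the return convention (True while a next permutation exists, False after wrapping back to sorted order) mirror C++'s std::next_permutation, and the function is plain Python that only uses indexing and swaps, with `@numba.njit` in mind.

```python
def next_permutation(a):
    # Rearrange a in place into the next lexicographic permutation.
    # Returns False (and leaves a sorted ascending) if a was the last one.
    n = len(a)
    # Find the rightmost position i with a[i] < a[i + 1].
    i = n - 2
    while i >= 0 and a[i] >= a[i + 1]:
        i -= 1
    if i >= 0:
        # Find the rightmost element greater than a[i] and swap them.
        j = n - 1
        while a[j] <= a[i]:
            j -= 1
        a[i], a[j] = a[j], a[i]
    # Reverse the suffix after position i (the whole sequence if i == -1).
    lo = i + 1
    hi = n - 1
    while lo < hi:
        a[lo], a[hi] = a[hi], a[lo]
        lo += 1
        hi -= 1
    return i >= 0
```

Because it advances one permutation per call, the caller can stop at any point without generating the rest:

```python
a = [1, 2, 3]
while True:
    # ... use a ...
    if not next_permutation(a):
        break
```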
Functions that return a string
str, bin, format, "%d" % 42 and the like cannot be used.
Code that fails
@numba.njit("void()", cache=True)
def solve():
    popcnt = bin(4047).count("1")
Error message
Untyped global name 'bin': cannot determine Numba type of <class 'builtin_function_or_method'>
File "untitled.py", line 6:
def solve():
popcnt = bin(4047).count("1") # compile error
^
Workaround
Do not try to handle strings in Numba.
For popcount, prepare an implementation in advance (reference:
"Python 3でpopcountを計算する - にせねこメモ")
@numba.njit("u8(u8)", cache=True)
def popcount(n):
    # Add up bit counts in parallel: 2-bit groups, then 4, 8, 16, 32, 64.
    n = (n & 0x5555555555555555) + (n>>1 & 0x5555555555555555)
    n = (n & 0x3333333333333333) + (n>>2 & 0x3333333333333333)
    n = (n & 0x0f0f0f0f0f0f0f0f) + (n>>4 & 0x0f0f0f0f0f0f0f0f)
    n = (n & 0x00ff00ff00ff00ff) + (n>>8 & 0x00ff00ff00ff00ff)
    n = (n & 0x0000ffff0000ffff) + (n>>16 & 0x0000ffff0000ffff)
    n = (n & 0x00000000ffffffff) + (n>>32 & 0x00000000ffffffff)
    return n
@numba.njit("void()", cache=True)
def solve():
    popcnt = popcount(4047)
Rewriting variables defined outside the function
Code that fails
array = np.array([1, 2, 3])
@numba.njit("void()", cache=True)
def solve():
    array[0] = 4
Error message
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (readonly array(int32, 1d, C), Literal[int](0), Literal[int](4))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: Cannot modify value of type readonly array(int32, 1d, C)
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179
In definition 9:
TypeError: Cannot modify value of type readonly array(int32, 1d, C)
raised from C:\Users\nagiss\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:179
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at C:/Users/nagiss/PycharmProjects/untitled/untitled.py (8)
File "untitled.py", line 8:
def solve():
array[0] = 4 # compile error
^
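Workaround
Pass it as an argument.
# i4 matches the int32 array in the error above; use i8 if your array is int64.
@numba.njit("void(i4[:])", cache=True)
def solve(array):
    array[0] = 4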
Standard input
Not possible, so give up.
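One way to live with this, in the same spirit as splitting off the itertools part so that it is not compiled, is to read and parse the input in ordinary Python and hand only numbers to the compiled function. The sketch below is an assumption-laden illustration and not from the original article: `read_ints` and `solve` are hypothetical names, the input format (a count followed by that many integers) is made up, and `solve` is left undecorated so the example runs without Numba.

```python
import io


def read_ints(stream):
    # Runs as ordinary Python: Numba never sees the stream or the strings.
    n = int(stream.readline())
    values = [int(tok) for tok in stream.readline().split()]
    assert len(values) == n
    return values


def solve(values):
    # Numeric-only work; this is the part that would get @numba.njit.
    # In a real submission, converting values to a NumPy array before the
    # call is the usual choice.
    best = values[0]
    for v in values:
        if v > best:
            best = v
    return best


# In a contest you would pass sys.stdin; io.StringIO stands in for it here.
sample = io.StringIO("3\n10 30 20\n")
answer = solve(read_ints(sample))
```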
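Functions that return functions
Not possible, so give up.
Recursion in a function defined inside another function
Not possible, so give up.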
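When the recursion is a depth-first search, one possible substitute is an explicit stack. The sketch below is not from the original article: `subtree_sizes` is a hypothetical helper that assumes a connected tree given in a CSR-like form (the neighbours of `v` are `adj[adj_start[v]:adj_start[v + 1]]`), and it is plain Python written with `@numba.njit` in mind.

```python
def subtree_sizes(n, adj_start, adj, root):
    # Iterative DFS: an explicit stack replaces the recursive inner function.
    parent = [-1] * n
    order = [0] * 0  # vertices in the order they are first visited
    stack = [root]
    while len(stack) > 0:
        v = stack.pop()
        order.append(v)
        for k in range(adj_start[v], adj_start[v + 1]):
            w = adj[k]
            if w != parent[v]:
                parent[w] = v
                stack.append(w)
    # Every vertex appears after its parent in order, so a reverse sweep
    # accumulates subtree sizes bottom-up.
    size = [1] * n
    for idx in range(n - 1, 0, -1):
        v = order[idx]
        size[parent[v]] += size[v]
    return size
```

The same two-pass pattern (record a visiting order, then sweep it in reverse) covers many tree DP problems that would otherwise be written recursively.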
Everything else
Assume it is not possible and give up.
Summary
Numba is great