NumPyのflatten
関数は、多次元配列を1次元に変換する関数です。関数型プログラミングに慣れている方は、flatten
という名前の関数でネストされたリストが1次元になるので、どのような動作をするのか想像しやすいかもしれません。
本記事では、NumPyのflatten
関数についての使い方を確かめながら、パフォーマンスについても言及します。実装のパフォーマンスが段違いに上がることを紹介するので、知っておくと役に立つので抑えておくとよいでしょう。
ndarray.flatten
まずは、flatten
関数についてのAPIドキュメントについて見ましょう。
np.ndarray.flatten(order = ‘C’)
params:
パラメータ名 | 型 | 概要 |
---|---|---|
order |
{‘C’,’F’,’A’,’K’}のいずれか | (省略可能)初期値’C’ 配列のデータの並べ方を指定します。 |
returns:
元の配列を1次元配列に直した配列のコピーを返します。
この関数は基本的には引数を指定しないで使う場合が多いでしょう。order
はFortranのような順序の方法を指定する場合に使われるものであまり使用されません。
reshape
より汎用性に欠けますが、引数を特に指定する必要がなく、flatten
を使用することでどのような変換をしたかが一目でわかるので、1次元配列に変換するときはnp.ndarray.flatten
関数を使うことをおすすめします。
2次元配列を1次元配列に変換
コードの使用例を見てみます。
In [1]: import numpy as np
In [2]: a = np.arange(10).reshape(2,5) # 2×5の2次元配列を生成
In [3]: a
Out[3]:
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
In [4]: b = a.flatten() # 1次元配列に変形したものをbに代入
In [5]: a # a自体は変化していない
Out[5]:
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
In [6]: b # bには変形されたものが代入されている
Out[6]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
In [7]: a.shape # shapeも確かめておく
Out[7]: (2, 5)
In [8]: b.shape # bのshapeは1次元配列なので1つの数字しか表示されない
Out[8]: (10,)
3次元配列を1次元配列に変換
元の多次元配列が3次元配列の場合も確認してみます。
In [9]: c = np.arange(12).reshape(2,2,3) # 3次元配列でも確かめる
In [10]: c
Out[10]:
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
In [11]: d = c.flatten() # dに1次元配列に変形したものを代入
In [12]: c
Out[12]:
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
In [13]: d
Out[13]: array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
In [14]: c.shape # cのshapeとdのshapeをこちらでも確認しておく
Out[14]: (2, 2, 3)
In [15]: d.shape
Out[15]: (12,)
パフォーマンス
同様の動作をする関数にnp.ravel
関数があります。この関数は、flatten
関数とは違い、コピーを作成しません。大きなデータで破壊的な変更をしても問題ない場合は、こちらの関数を使用することでパフォーマンスの向上が見込めます。
多次元配列を作って、パフォーマンスを確認してみます。
In [1]: import numpy as np
In [2]: arr = np.repeat(5, 10000).reshape(250, 40)
In [3]: %timeit arr.flatten()
The slowest run took 40.41 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.79 µs per loop
In [4]: %timeit np.ravel(arr)
The slowest run took 11.91 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.26 µs per loop
np.ravel
の方が、flatten
関数よりも高速に動作していることが確認できました。NumPyでは、使用方法を考えることでパフォーマンスを向上することが出来ます。
np.ravel
関数については、以下の記事で詳細に解説しています。
flattenよりも高速に配列を一次元化するnumpy.ravel関数の使い方 /features/numpy-ravel.html