NumPyのflatten関数は、多次元配列を1次元に変換する関数です。関数型プログラミングに慣れている方は、flattenという名前の関数でネストされたリストが1次元になるので、どのような動作をするのか想像しやすいかもしれません。

本記事では、NumPyのflatten関数についての使い方を確かめながら、パフォーマンスについても言及します。実装のパフォーマンスが段違いに上がることを紹介するので、知っておくと役に立つので抑えておくとよいでしょう。

ndarray.flatten

まずは、flatten関数についてのAPIドキュメントについて見ましょう。

np.ndarray.flatten(order = ‘C’)

params:

パラメータ名 概要
order {‘C’,’F’,’A’,’K’}のいずれか (省略可能)初期値’C’
配列のデータの並べ方を指定します。

returns:

元の配列を1次元配列に直した配列のコピーを返します。

この関数は基本的には引数を指定しないで使う場合が多いでしょう。orderはFortranのような順序の方法を指定する場合に使われるものであまり使用されません。

reshapeより汎用性に欠けますが、引数を特に指定する必要がなく、flattenを使用することでどのような変換をしたかが一目でわかるので、1次元配列に変換するときはnp.ndarray.flatten関数を使うことをおすすめします。

2次元配列を1次元配列に変換

コードの使用例を見てみます。

In [1]: import numpy as np

In [2]: a = np.arange(10).reshape(2,5) # 2×5の2次元配列を生成

In [3]: a
Out[3]:
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])

In [4]: b = a.flatten() # 1次元配列に変形したものをbに代入

In [5]: a # a自体は変化していない
Out[5]:
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])

In [6]: b # bには変形されたものが代入されている
Out[6]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [7]: a.shape # shapeも確かめておく
Out[7]: (2, 5)

In [8]: b.shape # bのshapeは1次元配列なので1つの数字しか表示されない
Out[8]: (10,)

3次元配列を1次元配列に変換

元の多次元配列が3次元配列の場合も確認してみます。

In [9]: c = np.arange(12).reshape(2,2,3) # 3次元配列でも確かめる

In [10]: c
Out[10]:
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

In [11]: d = c.flatten() # dに1次元配列に変形したものを代入

In [12]: c
Out[12]:
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

In [13]: d
Out[13]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [14]: c.shape # cのshapeとdのshapeをこちらでも確認しておく
Out[14]: (2, 2, 3)

In [15]: d.shape
Out[15]: (12,)

パフォーマンス

同様の動作をする関数にnp.ravel関数があります。この関数は、flatten関数とは違い、コピーを作成しません。大きなデータで破壊的な変更をしても問題ない場合は、こちらの関数を使用することでパフォーマンスの向上が見込めます。

多次元配列を作って、パフォーマンスを確認してみます。

In [1]: import numpy as np

In [2]: arr = np.repeat(5, 10000).reshape(250, 40)

In [3]: %timeit arr.flatten()
The slowest run took 40.41 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.79 µs per loop

In [4]: %timeit np.ravel(arr)
The slowest run took 11.91 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.26 µs per loop

np.ravelの方が、flatten関数よりも高速に動作していることが確認できました。NumPyでは、使用方法を考えることでパフォーマンスを向上することが出来ます。

np.ravel関数については、以下の記事で詳細に解説しています。

flattenよりも高速に配列を一次元化するnumpy.ravel関数の使い方 /features/numpy-ravel.html