NumPyのndarrayは、NumPy操作をするための多次元配列です。多次元の構造はndarray.shapeで確認することができます。

多次元構造を操作するために必須となるのが、軸(axis)を正しく理解することです。多次元の配列構造を処理するために、NumPyの関数の引数には、axisを指定することが出来る場面が多々あります。

他の記事において、要素の合計を計算するnp.sum関数、要素の平均を計算するnp.average関数、最大の要素を探すnp.amax関数などを紹介してきましたが、それぞれの関数の中で引数としてaxisがあります。 これは配列の軸にあたるものですがどの軸がどの次元に対応しているのか分かりにくいことが多いです。 そのため、本記事では、

  • 次元数とはなにか
  • 軸(axis)とは何か
  • 関数の引数としてaxisを指定すると何が起きているのか

について解説していきます。

ndarrayの次元数(ndim)とは何か

NumPyの多次元配列であるndarrayは、.shapeでその構造を把握することができます。

In [1]: import numpy as np

In [2]: a = np.array([[1, 2, 3], [4, 5, 6]])

In [3]: a.shape
Out[3]: (2, 3)

上記のコードの場合、shapeを見ると2×3の構造をしていることが分かります。shapeについての詳しい解説は以下の記事を参考にしてください。

NumPyのndarrayのインスタンス変数shapeの意味 /features/numpy-shape.html

ndimは多次元配列が何次元の構造をしているのかを意味しています。つまり、shapeの要素の数なのでlen(arr.shape)ということになります。

In [4]: a.ndim
Out[4]: 2

axisについて

axisは名前の通り、座標軸のようなものです。どの軸かを指定するための方法として、axisshapeのインデックスに対応します。

3×2の行列を考えてみます。

In [1]: import numpy as np

In [2]: a = np.arange(6).reshape((3, 2))

In [3]: a
Out[3]:
array([[0, 1],
       [2, 3],
       [4, 5]])

In [4]: a.shape
Out[4]: (3, 2)

上記のaという多次元配列は3×2の行列なので、shape(3, 2)です。NumPyのndarrayはネストした配列になっていますが、ネストの上位順にshapeの要素は並びます。

NumPyの2D形状

3×2の行列の場合、プリミティブな要素だけ入っている配列は、上図のように列方向になります。そして、より上位の配列なのは行方向です。つまり、軸順は(行方向, 列方向)となります。

3次元に拡張してみましょう。この3×2の行列を複数持つ配列を新しく作ることになるので、新しい軸はshapeの先頭に相当します。分かりやすくするために、作成したaを2つ含む配列を作成してみます。

In [5]: b = np.array([a, a])

In [6]: b.shape
Out[6]: (2, 3, 2)

In [7]: b
Out[7]:
array([[[0, 1],
        [2, 3],
        [4, 5]],

       [[0, 1],
        [2, 3],
        [4, 5]]])

この場合、以下の図のように3×2の多次元配列を持つ配列をつくったので、新しく出来た最上位のaxisが0になります。

NumPyの3D形状

関数の引数としてのaxis

NumPyには、axisを引数にとる関数が少なくありません。ndarray.sumでは、axisを指定して合計を計算することができます。その結果出力されるshapeは指定した軸方向に次元削減されることになります。

上記のb配列を例にすると、


b.shape == (2, 3, 2)

ですが、sum関数の出力shapeは以下のようになります。


b.sum(axis=0).shape == (3, 2)
b.sum(axis=1).shape == (2, 2)
b.sum(axis=2).shape == (2, 3)

図示しながら確認してみましょう。axis=0を引数に取った場合、下図のaxis=0方向の矢印に向かって要素が足し合わされます。

0axis

0軸方向には、同じ値が2つ並んでいるので、各要素が2倍されるはずです。

In [8]: b.sum(axis=0)
Out[8]:
array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]])
In [9]: b.sum(axis=0).shape
Out[9]: (3, 2)

期待通りの結果になりました。axis=1にして確認してみます。この場合は、下図のように行方向に足し合わされることになります。

1axis

どうなるか結果は想像できましたか?実際に確認してみます。

In [10]: b.sum(axis=1)
Out[10]:
array([[6, 9],
       [6, 9]])

In [11]: b.sum(axis=1).shape
Out[11]: (2, 2)

行方向の要素が次元毎に足し合わされて、期待通りの結果になりました。axis=2の場合も確認してみます。下図のように列方向に足し合わされるはずです。

2axis

In [12]: b.sum(axis=2)
Out[12]:
array([[1, 5, 9],
       [1, 5, 9]])

In [13]: b.sum(axis=2).shape
Out[13]: (2, 3)

予想どおりでしたか?列方向に各要素が足し合わされて次元削減されました。

NumPyのaxisshapeのインデックスです。是非NumPyの操作方法を覚えて使いこなしてください。