NumPyのndarrayは、np.where
関数に条件式を指定することで、目的の要素のインデックスを取得することができます。
ヒストグラムのインデックスを取得したいときや、しきい値を設けて値を制限したいときなどに便利なので、覚えておくと役に立つはずです。
np.where
np.where
は、条件を満たす要素のインデックスを返す関数です。APIドキュメントは以下のようになっています。
numpy.where(condition[, x, y])
params:
パラメータ名 | 型 | 概要 |
---|---|---|
condition |
array_like (配列に相当するもの) もしくは bool値 |
条件もしくはbool値を指定します。 |
x, y |
array_like (配列に相当するもの) |
(省略可能)condition で指定された条件に対してTrueならx を、Falseならy を返します。x, y のshape は元の配列と揃えるようにしましょう。(これは任意入力ですがどちらも指定する必要があります) |
returns:
抽出されたndarrayの要素のindexが返されます。元のndarrayが二次元のときは次元ごとのindexが記された一次元配列が2つ返されます。
x, y
が指定されていれば、要素がx
またはy
に変換されたndarrayが返されます。
arr[a < 10]
のように、インデックス部分に条件を指定することで、目的の要素を取得することができました。np.where
を使うことで、値ではなくインデックスを取得することができます。
条件の指定
基本的な使い方は、第一引数に条件のみを指定する方法です。以下のように第一引数に条件式を指定することで、条件を満たす要素のインデックスを取得することができます。
In [1]: import numpy as np
In [2]: a = np.arange(20, 0, -2) # まずは1次元配列を生成
In [3]: a
Out[3]: array([20, 18, 16, 14, 12, 10, 8, 6, 4, 2])
In [4]: np.where(a < 10) # 10未満のindexを取得
Out[4]: (array([6, 7, 8, 9]),)
In [5]: a[np.where(a < 10)]
Out[5]: array([8, 6, 4, 2]) # 10未満の要素だけのindexとなっていることが確認できます。
続いて多次元配列でも試してみます。
In [1]: import numpy as np
In [2]: a = np.arange(12).reshape((3, 4)) # 3×4の二次元配列にする
In [3]: a
Out[3]:
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
In [4]: np.where(a % 2 == 0) # 偶数だけ取り出してみる。
Out[4]: (array([0, 0, 1, 1, 2, 2]), array([0, 2, 0, 2, 0, 2]))
一瞬困惑するかもしれませんが、同様にインデックスが取得されています。上記の例では、行と列のインデックスが取り出されており、対応する(0, 0)や(0, 2)は偶数になっていることが分かります。
np.whereを使った三項演算子
np.where
はプログラムの三項演算子のような使い方ができます。
第一引数に抽出したい要素の条件を指定し、第二引数以降で条件を満たすとき値と満たさないときの値を指定することができます。
x, y
を上手く使うことで、配列の一部の要素だけ変換したいときに便利になります。
条件のあとに条件を満たす場合と満たさない場合にどのような値を返すか指定しておけば、その値を要素とする新しい配列を返します。
In [5]: np.where( a%2 == 0, 'even', 'odd') # 偶数ならeven,奇数ならoddと返す。
Out[5]:
array(['even', 'odd', 'even', 'odd', 'even', 'odd', 'even', 'odd', 'even',
'odd', 'even', 'odd'],
dtype='<U4')
In [6]: np.where( a%2 == 0, 'even') # Trueの時だけ値を設定するとエラーが返ってくる。
---------------------------------------------------------------------------
(エラーメッセージが表示される)
ValueError: either both or neither of x and y should be given
In [6]: np.where( a%2 == 0, 'even', 'odd') # 偶数ならeven,奇数ならoddと返す。
In [7]: b = np.reshape(a, (3, 4))
In [8]: c = b ** 2
In [9]: c
Out[9]:
array([[ 0, 1, 4, 9],
[ 16, 25, 36, 49],
[ 64, 81, 100, 121]])
In [10]: np.where(b % 2 == 0, b, c) # 奇数のところだけcの要素に取り替える。
Out[10]:
array([[ 0, 1, 2, 9],
[ 4, 25, 6, 49],
[ 8, 81, 10, 121]])
最後にbroadcastingを紹介します。最後の引数に、配列やタプルのようなイテレーション可能な値を指定すると、繰り返したときにインデックスアクセスしたときの値が使用されます。
In [14]: np.where( b%2 == 0, b, (10, 8, 6, 4)) # broadcastingが適用され、(10, 8, 6, 4)が繰り返されたものが使われている。
Out[14]:
array([[ 0, 8, 2, 4],
[ 4, 8, 6, 4],
[ 8, 8, 10, 4]])