線形代数の基礎としてまず学習する内容の1つとして逆行列があります。逆行列は機械学習の分野で扱うのはもちろんのこと、他の分野でも頻繁に使われる概念です。

NumPyにも逆行列を求める関数が実装されています。

今回は逆行列についてのおさらいと、その逆行列を求める関数であるlinalg.inv関数についてまとめました。

逆行列

逆行列というのは行列A \\\\ があったとき

A^{-1}

と表されます。この関数は別の表現で以下のように表すこともできます。

AA^{-1} = I

このとき I は単位行列です。元の行列に掛け合わせると単位行列になるような行列をその行列の逆行列と呼びます。

逆行列の求め方

まずは2×2の正方行列についてみていきます。

A\ \ = \ \left( \begin{array}{cc} a_{11} & a_{12} \\ a_{21} & a_{22} \\ \end{array} \right )

このような行列A \\\\ があるとすると、

A^{-1}\ \ = \ \frac{1}{detA} \left ( \begin{array}{cc} a_{22} & -a_{21} \\ -a_{12} & a_{11} \\ \end{array} \right )

が逆行列となります。対角成分を入れ替えて、右上と左下に-1をかければ逆行列ができます。

3×3以上の行列の逆行列の求め方は大きく分けて2つ存在します。

余因子を用いる

行列A \\\\ の余因子をB_{ij} \\\\ とすると、 B_{ij} \\\\ は、i行目とj列目をA \\\\ から取り除いてできた新しい行列\tilde{B} \\\\ の 行列式det\tilde{B} \\\\ となります。

この余因子を要素とする行列に対して、もとの行列式で割ったものが逆行列となります。

すなわち、

A^{-1}_{ij}\ \ =\ \frac{B_{ij}}{detA}

と表すことができます。3×3の行列の逆行列を求めるなら、合計10回行列式を求める必要があるということです。

行基本変形を用いる(掃き出し法)

この方法は、(A\ \ I) \\\\ と単位行列を横に並べた新たな行列を作り、A \\\\ が単位行列となるように行基本変形を行っていきます。 最終的なI \\\\ の部分の形がA^{-1} \\\\ となります。

行基本変形というのは

  • 特定の行を低数倍する
  • ある行とある行とを入れ替える
  • 1つの行に対して他の行の定数倍したものを足す

という3つの操作のことをいいます。

逆行列について詳しい解説が知りたいという方は以下のサイトを参照してみて下さい。

逆行列を求める2鳥の方法と例題 | 高校数学の美しい物語
【行列式編】逆行列の求め方を画像付きで解説! –おぐえもん.com

np.linalg.inv()

逆行列の概要がわかったところで、実際に使っていきましょう。まずはnp.linalg.inv関数です。この関数のAPIドキュメントは以下のとおりです。

numpy.linalg.inv(a)

args:

パラメータ名 概要
a (…,M,M)の形状をもつ配列 逆行列を求めたい行列をここに指定します。

returns:

指定された行列の逆行列 (shape=(…,M,M)) が返されます。

引数として指定すべきなのは1つだけなので使い方はとてもシンプルです。1つ注意すべきなのは、この関数では正方行列の逆行列しか求められないということです。

サンプルコード

それでは実際に扱っていきましょう。
2×2の行列から求めてみましょう。

In [1]: import numpy as np

In [2]: a = np.random.randint(-9, 10, size=(2 ,2)) # まずは2×2の行列から

In [3]: a
Out[3]:
array([[-4,  2],
       [ 7,  2]])

In [4]: np.linalg.inv(a) # 逆行列を求める。
Out[4]:
array([[-0.09090909,  0.09090909],
       [ 0.31818182,  0.18181818]])

In [7]: np.dot(a, np.linalg.inv(a)) # 積をとって単位行列となるか確かめてみる。
Out[7]:
array([[  1.00000000e+00,  -5.55111512e-17],
       [  1.11022302e-16,   1.00000000e+00]])

次は3×3の行列を見ましょう。

In [8]: b = np.random.randint(-10, 10, size=(3,3)) # 次は3×3の行列。

In [9]: b
Out[9]:
array([[ 8, -6,  6],
       [-7, -1, -8],
       [-7,  1,  3]])

In [10]: c = np.linalg.inv(b)

In [11]: c
Out[11]:
array([[-0.00988142, -0.04743083, -0.10671937],
       [-0.15217391, -0.13043478, -0.04347826],
       [ 0.02766798, -0.06719368,  0.09881423]])

In [12]: np.dot(b,c) # 積を
Out[12]:
array([[  1.00000000e+00,   0.00000000e+00,   0.00000000e+00],
       [  1.11022302e-16,   1.00000000e+00,   0.00000000e+00],
       [ -5.55111512e-17,   1.11022302e-16,   1.00000000e+00]])

In [13]: np.dot(c,b) # 逆順にしても結果はほとんど変わらない。
Out[13]:
array([[  1.00000000e+00,   1.38777878e-17,   1.66533454e-16],
       [  8.32667268e-17,   1.00000000e+00,   1.38777878e-16],
       [  0.00000000e+00,   5.55111512e-17,   1.00000000e+00]])

複数の行列の逆行列をまとめて計算してみましょう。
4つの3×3行列の逆行列をまとめて求めてみます。

In [14]: d = np.random.randint(-10, 10, size=(4,3,3)) # 4つの3×3行列

In [15]: d
Out[15]:
array([[[  1,   1,  -7],
        [  1,  -8,  -9],
        [  2, -10,  -3]],

       [[  2,   3,  -9],
        [  7,   7,   1],
        [ -5,   8,   7]],

       [[ -8,   5,   5],
        [ -7,   4,  -8],
        [ -1,   4,   4]],

       [[  7,   0,   8],
        [  1,  -5,   9],
        [ -3,   4,  -9]]])

In [16]: e = np.linalg.inv(d) # 逆行列を求める。

In [17]: e
Out[17]:
array([[[ 0.53658537, -0.59349593,  0.52845528],
        [ 0.12195122, -0.08943089, -0.01626016],
        [-0.04878049, -0.09756098,  0.07317073]],

       [[-0.04560623,  0.10344828, -0.07341491],
        [ 0.06006674,  0.03448276,  0.07230256],
        [-0.10122358,  0.03448276,  0.00778643]],

       [[-0.14814815, -0.        ,  0.18518519],
        [-0.11111111,  0.08333333,  0.30555556],
        [ 0.07407407, -0.08333333, -0.00925926]],

       [[-0.36      , -1.28      , -1.6       ],
        [ 0.72      ,  1.56      ,  2.2       ],
        [ 0.44      ,  1.12      ,  1.4       ]]])

In [18]: np.dot(d,e) # 積をとってみる。
Out[18]:
array([[[[  1.00000000e+00,   0.00000000e+00,   0.00000000e+00],
         [  7.23025584e-01,  -1.03448276e-01,  -5.56173526e-02],
         [ -7.77777778e-01,   6.66666667e-01,   5.55555556e-01],
         [ -2.72000000e+00,  -7.56000000e+00,  -9.20000000e+00]],

        [[  5.55111512e-17,   1.00000000e+00,  -2.22044605e-16],
         [  3.84872080e-01,  -4.82758621e-01,  -7.21913237e-01],
         [  7.40740741e-02,   8.33333333e-02,  -2.17592593e+00],
         [ -1.00800000e+01,  -2.38400000e+01,  -3.18000000e+01]],

        [[  1.11022302e-16,   0.00000000e+00,   1.00000000e+00],
         [ -3.88209121e-01,  -2.41379310e-01,  -8.93214683e-01],
         [  5.92592593e-01,  -5.83333333e-01,  -2.65740741e+00],
         [ -9.24000000e+00,  -2.15200000e+01,  -2.94000000e+01]]],


       [[[  1.87804878e+00,  -5.77235772e-01,   3.49593496e-01],
         [  1.00000000e+00,   0.00000000e+00,  -1.38777878e-17],
         [ -1.29629630e+00,   1.00000000e+00,   1.37037037e+00],
         [ -2.52000000e+00,  -7.96000000e+00,  -9.20000000e+00]],

        [[  4.56097561e+00,  -4.87804878e+00,   3.65853659e+00],
         [  0.00000000e+00,   1.00000000e+00,  -7.80625564e-18],
         [ -1.74074074e+00,   5.00000000e-01,   3.42592593e+00],
         [  2.96000000e+00,   3.08000000e+00,   5.60000000e+00]],

        [[ -2.04878049e+00,   1.56910569e+00,  -2.26016260e+00],
         [  1.11022302e-16,  -1.11022302e-16,   1.00000000e+00],
         [  3.70370370e-01,   8.33333333e-02,   1.45370370e+00],
         [  1.06400000e+01,   2.67200000e+01,   3.54000000e+01]]],


       [[[ -3.92682927e+00,   3.81300813e+00,  -3.94308943e+00],
         [  1.59065628e-01,  -4.82758621e-01,   9.87764182e-01],
         [  1.00000000e+00,   0.00000000e+00,  -1.38777878e-17],
         [  8.68000000e+00,   2.36400000e+01,   3.08000000e+01]],

        [[ -2.87804878e+00,   4.57723577e+00,  -4.34959350e+00],
         [  1.36929922e+00,  -8.62068966e-01,   7.40823137e-01],
         [  0.00000000e+00,   1.00000000e+00,  -1.11022302e-16],
         [  1.88000000e+00,   6.24000000e+00,   8.80000000e+00]],

        [[ -2.43902439e-01,  -1.54471545e-01,  -3.00813008e-01],
         [ -1.19021135e-01,   1.72413793e-01,   3.93770857e-01],
         [  0.00000000e+00,   0.00000000e+00,   1.00000000e+00],
         [  5.00000000e+00,   1.20000000e+01,   1.60000000e+01]]],


       [[[  3.36585366e+00,  -4.93495935e+00,   4.28455285e+00],
         [ -1.12903226e+00,   1.00000000e+00,  -4.51612903e-01],
         [ -4.44444444e-01,  -6.66666667e-01,   1.22222222e+00],
         [  1.00000000e+00,   0.00000000e+00,   1.77635684e-15]],

        [[ -5.12195122e-01,  -1.02439024e+00,   1.26829268e+00],
         [ -1.25695217e+00,   2.41379310e-01,  -3.64849833e-01],
         [  1.07407407e+00,  -1.16666667e+00,  -1.42592593e+00],
         [  4.44089210e-16,   1.00000000e+00,  -1.77635684e-15]],

        [[ -6.82926829e-01,   2.30081301e+00,  -2.30894309e+00],
         [  1.28809789e+00,  -4.82758621e-01,   4.39377086e-01],
         [ -6.66666667e-01,   1.08333333e+00,   7.50000000e-01],
         [ -4.44089210e-16,   1.77635684e-15,   1.00000000e+00]]]])

参考