今回の記事では、前回まで扱った内容を実際にNumPyで実装してみます。
最もシンプルな構造を実装してみることでニューラルネットワークに対する理解を深めていきましょう。
NumPyでの実装
今回は、NumPyを使ってニューラルネットワークで3種類のアヤメ(Iris)の品種を分類してみます。このデータセットは機械学習の基本としてはよく使われるもので、がく片の長さとその幅、花びらの長さをその幅をアヤメの3種類の花(「Iris setona」, 「Iris virginica」,「Iris versicolor」)についてそれぞれ50サンプルずつ単位をcmで計測したものとなっています。
今回はそのうちの2つ「Iris setona」と「Iris virginica」を4つのデータから分類してみましょう。
データセットの用意
まず、以下のリンクからアイリス花データをダウンロードしてください。
アイリス花データ
コマンドラインからは以下のコマンドでダウンロードすることができます。
$ wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
このデータを保存したディレクトリまで移動して、Pythonを起動します。今回は、データ読み込みのためにPandasを利用します。
まずはデータを読み込みます。
この実行結果は以下のようになります。
データの番号 |
がく片の長さ |
がく片の幅 |
花びらの長さ |
花びらの幅 |
種類名 |
の順番でデータが並んでいます。
次にこれらを訓練用のデータとうまく学習できたかを確かめるテスト用のデータとに分けます。今回はそれぞれサンプルが50ずつあるので、40組を訓練用データに、10組をテスト用データとして使います。うまく学習できたかどうかを確かめるためには10組のデータをどれだけ精度良く分類できたか評価します。
このように教師あり学習においては訓練用の集合とテスト用の集合を分けることで、訓練用の集合のみに過剰に学習してしまう過学習が起こっていないのかを確認することができます。
過学習すると、訓練用の集合に対しては良い結果を残すのに、テスト用の集合に対しては、学習モデルを使用すると著しく悪い結果となってしまうことがあります。
それではこれらのデータの中身を確認してみましょう。4つデータがありますが4次元空間にプロットすることはできないのでがく片と花びらの大きさで分けてプロットしてみましょう。
これを行うと、以下のグラフがプロットされます。
これを見るとグラフを見た感じ2つのデータだけでも十分線形分類することが可能そうですね。ですが、今回はニューラルネットワークを使ってみるという意味でこれら4種類のデータを元に分類していきたいと思います。
ニューラルネットワークの構築
では2章で扱ったニューラルネットワークを入力数だけ1つ増やした状態で再現してみたいと思います。これを模式図にすると以下のようになります。
ではこれらを実装していきましょう。
update
関数が2パターンありますが、1つ目のupdate()
の方は2章で求めた偏微分の式をそのまま適用したものとなっています。2つ目のupdate_2()
関数はこのような解析的に値を求めるのではなく、実際に少しだけパラメータ(例えば)の値をズラしたときにどれほど損失関数の値が増減するのかを計算して求めています。忘れがちですが微小変化量h
で割ることを忘れないようにしてください。
また、学習率eta
の設定をする必要があります。このように、学習モデルの手動で設定する必要がある値をハイパーパラメータと呼ぶことがあります。このeta
は2章で扱ったをコードに落としたものなので新しい概念ではありません。
では、これらの関数を定義したところで学習を開始してみます。
これの実行結果は以下のとおりです。
update()
、update_2()
のどちらを用いてもうまく学習ができていることがわかりますね。また、パラメーターの値もそれほど変わっていないようです。
データ解析の中で微分を行うときはたいていupdate_2()
関数のようにあるパラメーターを微小量変化させることでどれだけ値が変化したのかを調べることが多いです。
積分をするなら、このデータの値を単純に足し合わせます。
ただ、今回のモデルでは解析的に微分(特に偏微分)の値を数式で表すことができているのでそれを使って学習させてもみました。
まとめ
今回の章では前回、前々回で扱ったニューラルネットワークの構造を使ってアヤメの2品目のデータを使って分類を行ってみました。
一番シンプルなニューラルネットワークで分類をしてみましたが、ニューラルネットワークの魅力的なところはこのニューロンの数を増やしてさらに層を多くすることでより複雑なデータに対しても学習をすすめることができるという点です。
次回以降ではニューラルネットワークの構造を更に複雑にし、計算量が膨大になるのを防ぐための工夫の1つとして誤差逆伝播法について扱っていきたいと思います。
参考