TensorFlowのDataset APIは、バージョン1.2から追加された新しい機能です。Dataset APIを使うことで、TensorFlowの独自のキューによる入力パイプラインの煩わしさを減らし、データセットの加工や入れ替えがスムーズに出来るようになります。

一般的に、データの読み込みではtf.train.string_input_producertf.train.shuffle_batchなどのTensorFlowの提供するAPIを通して大規模なデータを複数スレッドで処理できるように設計されています。[1] しかしながら、もっと気軽にキューの入出力を意識せずに使いたいと考えていた人は多いはずです。

また、訓練データやテストデータの切り替えも処理パイプラインの組み換えに独自の実装を施さなければならず、簡単とは言い難いものでした。

今回の変更点のDataset APIを使うことで、このような煩わしさから解放されます。独自のキューパイプラインから解放され、複雑な前処理も簡単に実装することができるようになります。本記事では、Dataset APIを紹介し、使い方を解説します。

Dataset API

Dataset APIは、tf.contrib.dataモジュールを経由して使用します。以下の3つのメリットをデベロッパーが得られるように設計されています。

  • 複雑な入力パイプラインを簡単にすること
  • 処理パイプラインの再利用性を高めること
  • 大量のデータを処理可能にしながら、様々な形式に対応すること

Dataset APIの2つの概念

Dataset APIには2つの重要な概念があります。DatasetIteratorです。

Dataset

Datasetは、テンソル要素の集合です。入力データセットを抽象化していて、Datasetの中には複数の入力データが含まれています。C#のCollectionやJavaのStream API、Swiftの配列のようにmapflat_mapzipといったデータ変換APIを適用することができます。

Iterator

Iteratorは、Dataset要素を抽出して機械学習モデルにデータを流す繋ぎのインターフェースです。IteratorDataset内での現在位置を管理して、次の要素を取得するための操作を提供します。

3種類のIteratorの使い分け方

Iteratorには、3種類のIteratorがあります。それぞれ特色が違うので、特徴と使い分け方を紹介します。

one-shot

one-shotは、入力データを一巡する標準的なイテレータです。後述する種類のイテレータと違い、パラメータ化をサポートしません。以下のコードのように、入力データを一巡します。iterator.get_next()関数で、次の要素を取得できるオペレータを使ってsess.runすると次の要素が手に入ります。

import tensorflow as tf
dataset = tf.contrib.data.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.InteractiveSession()
for i in range(10):
    print(sess.run(next_element))

このコードをone_shot.pyと名前を付けて実行してみます。すると、0から9までの要素をone_shotイテレータで取得することができました。

$ python one_shot.py
0
1
2
3
4
5
6
7
8
9

rangeで0~9までのDatasetを作成していますが、10を超えてnext_elementを取得しようとすると、どうなるのでしょうか。one-shotイテレータの場合は、tf.errors.OutOfRangeError例外をスローします。

initializable

initializableイテレータは、データセットの作成と処理パイプラインをパラメータ化することができます。

initializableイテレータを作成する場合は、make_one_shot_iteratorの代わりに、make_initializable_iteratorメソッドを使用します。

初期化するパラメータは以下のように、iterator.initializerオペレータを使ってtf.Sessionのfeed_dictに値を指定します。

import tensorflow as tf
range_param = tf.placeholder(tf.int64)
dataset = tf.contrib.data.Dataset.range(range_param)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.InteractiveSession()
sess.run(iterator.initializer, feed_dict={range_param: 10})
for i in range(10):
    print(sess.run(next_element))

sess.run(iterator.initializer, feed_dict={range_param: 30})
for i in range(30):
    print(sess.run(next_element))

reinitializable

reinitializableイテレータは、同じ型と形状を出力するデータセットを組み替えることができます。つまり、訓練データセットやテストデータセットなどのデータセットの切り替えのために使用します。

仮にFizzBuzzの関数を学習させることを考えてみましょう。0~99までの数字を訓練データセットとして、100から200までのテストデータセットを作成して学習させることを考えてみます。この場合、以下のコードのような変換をすると、訓練データとテストデータをreinitializableイテレータを使用して分離することができます。

import tensorflow as tf

sess = tf.InteractiveSession()

def encode_fizz_buzz(x):
    return tf.case({
        tf.equal(x % 3, 0): lambda: (x, tf.constant(0)),
        tf.equal(x % 5, 0): lambda: (x, tf.constant(1)),
        tf.equal(x % 15, 0): lambda: (x, tf.constant(2))},
        default=lambda: (x, tf.constant(3)))

train_dataset = tf.contrib.data.Dataset.range(100).map(encode_fizz_buzz)
test_dataset = tf.contrib.data.Dataset.range(100,
        200).map(encode_fizz_buzz)
iterator = tf.contrib.data.Iterator.from_structure(train_dataset.output_types,
        train_dataset.output_shapes)
next_element = iterator.get_next()

train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

for i in range(20): # 20エポック学習させる
    sess.run(train_init_op)
    for i in range(100):
        train_X, train_Y = sess.run(next_element)
        print(train_X, train_Y)
        # 訓練データの学習

    sess.run(test_init_op)
    for i in range(100):
        test_X, test_Y = sess.run(next_element)
        print(test_X, test_Y)
        # テストデータのloss, accuracyの確認

reinitializableイテレータは、型と形状を指定して、make_initializer関数で初期化オペレーションを生成します。

DatasetのTransformerの種類と使い方

Datasetは、関数合成のように処理パイプラインを組み合わせて使用します。そのため、複数の変換関数を覚える必要があります。基本的な使用方法を紹介します。

基本変換

map

mapは、各入力要素を関数適用するために使用します。上述の例だと、rangeで生成した数値をfizzbuzzの学習データに変換するために使用しています。

train_dataset = Dataset.range(100).map(encode_fizz_buzz)

以下のコードのように、lambda式を使用することもできます。

dataset = Dataset.range(100).map(lambda x: tf.square(x))

flat_map

flat_mapは、ネストしたDatasetの出力をネストを解消しながら各要素に関数適用することができます。

Dataset.range(100).flat_map(lambda x: Dataset.range(x, x+2))
# 0, 1, 1, 2, 2, 3, 3, 4....

zip

zipはPythonのビルトイン関数のzipと同様に、複数のDatasetを1つにまとめることができます。

num = Dataset.range(100)
encode = num.map(lambda x: tf.case({
        tf.equal(x % 3, 0): lambda: tf.constant(0),
        tf.equal(x % 5, 0): lambda: tf.constant(1),
        tf.equal(x % 15, 0): lambda: tf.constant(2)},
        default=lambda: tf.constant(3)))
dataset = Dataset.zip((num, encode))
# (0, 0), (1, 3), (2, 3), (3, 0), (4, 3)...

group_by_window

group_by_windowは、条件でグループ化して窓サイズで分割するために使用します。

Dataset.range(100).group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)
# [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]...

データセット作成便利関数

repeat

repeatは、データセットの要素を繰り返すことができます。データセットを作成する際には、one-shotイテレータを使う場合にはエポック数を指定することがあります。引数を指定しないと、永久にリピートします。

train_dataset = Dataset.range(100).map(encode_fizz_buzz).repeat(num_epochs)

shuffle

shuffleは、データセットの要素の順番をシャッフルするために使用します。バッチ毎のデータ分布が偏ると上手く学習できなくなることがあるので、シャッフル化しておくと良いでしょう。引数にはシャッフルするバッファサイズを指定します。

Dataset.range(100).map(lambda x: x * 2).shuffle(20)
# 26, 32, 10, 44, 6, 34...

batch

batchはバッチサイズ毎に分割するために使用します。

Dataset.range(100).map(lambda x: x * 2).shuffle(20).batch(4)
# [38,  4,  0, 32], [14, 10, 26, 34], [46, 20, 54, 44], [22, 36,  6, 64]...

unbatch

unbatchはバッチを分解することができます。

Dataset.range(100).batch(4).unbatch()
# 1, 2, 3, 4, 5, 6, 7,...

padded_batch

padded_batchはゼロパディングしながら、分割します。

Dataset.range(100).batch(4).padded_batch(4, [5])
# array([[ 0,  1,  2,  3,  0],
#        [ 4,  5,  6,  7,  0],
#        [ 8,  9, 10, 11,  0],
#        [12, 13, 14, 15,  0]]),
# array([[16, 17, 18, 19,  0],
#        [20, 21, 22, 23,  0],
#        [24, 25, 26, 27,  0],
#        [28, 29, 30, 31,  0]]),
# array([[32, 33, 34, 35,  0],
#        [36, 37, 38, 39,  0],
#        [40, 41, 42, 43,  0],
#        [44, 45, 46, 47,  0]]),...

生成関数

range

rangeはPythonのビルトイン関数と同様に、範囲内のイテレータを作成するために使用します。

Dataset.range(100)
# 1, 2, 3, 4, 5,...,99

enumerate

enumerateもPythonのビルトイン関数と同様に、インデックスの番号を付けて返します。

Dataset.range(100, 200).enumerate()
# (0, 100), (1, 101), (2, 102), (3, 103), (4, 104), (5, 105)

知っておくと役に立つDataset APIの使用方法

実践でDataset APIを使いこなすためには、各ファイルフォーマットからの読み込み方や加工方法を知る必要があります。CSVファイルやTFRecordの読み込み方やEstimatorとの併用を含めて使い方を紹介します。

CSVファイルから前処理をする

以下のCSVファイルから画像ファイルの名前とラベルを取得して、カテゴリ分類する例を考えてみましょう。最初の一行目に各カラムの説明が書いてあり、次の行からデータの中身が入力されています。

filename,label
001.jpg,rock
002.jpg,river
003.jpg,rock
003.jpg,sky
004.jpg,person
005.jpg,sky
006.jpg,animal
007.jpg,river
008.jpg,rock
009.jpg,rock
010.jpg,sky

CSVファイルを読むには、CSVファイルが小さなデータセットであればメモリに入れてfrom_tensorsfrom_tensor_slicesしても問題ないですが、大きい場合にも対応可能なようにdata.TextLineDatasetを使用します。

import tensorflow as tf
import tensorflow.contrib.data as data

sess = tf.InteractiveSession()

categories = ['rock', 'river', 'sky', 'person', 'animal']

def to_index(label):
    return categories.index(label)

def parse_csv(line):
    [filename, category] = line.decode('utf-8').split(',')
    return filename.encode('utf-8'), to_index(category)

def read_image(filename, label):
    contents = tf.read_file(filename)
    return tf.image.decode_image(contents), label

dataset = data.TextLineDataset("/tmp/test.csv")\
        .skip(1)\
        .map(lambda x: tf.py_func(parse_csv, [x], [tf.string, tf.int64]))\
        .map(read_image)\
        .shuffle(4)\
        .batch(4)

iterator = dataset.make_one_shot_iterator()
next_elem = iterator.get_next()

print(sess.run(next_elem))

tf.py_funcを使用することで、上記のようにPythonの関数を呼び出すことができるようになります。

TFRecordを読み込む実用例

TFRecordを作成すると、GPUとCPUを並列に動作させることができるので高速になります。TFRecordもtf.contrib.dataモジュールの関数で読み込むことができます。以下のようにパース関数を書いてmapで変換します。

import tensorflow as tf
import tensorflow.contrib.data as data

sess = tf.InteractiveSession()

def parse_function(example_proto):
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['image'], parsed_features['label']

def read_image(image_raw, label):
    image = tf.decode_raw(image_raw, tf.uint8)
    return image, label

dataset = data.TFRecordDataset("./test.tfrecords")\
        .map(parse_function)\
        .map(read_image)\
        .shuffle(4)\
        .batch(4)

iterator = dataset.make_one_shot_iterator()
next_elem = iterator.get_next()

print(sess.run(next_elem))

まとめ

TensorFlowのDataset APIはデータセットを簡単に読み込めるようにした新しいモジュールです。これまでの読み込み関数は廃止になる可能性もありますし、このモジュールの関数が大幅に変更になるかもしれません。

個人的にはこれまでのAPIと比較すると、使いやすい印象があります。

さらに使いやすく、ハイパフォーマンスになると良いですね。

参考

[1] Reading data
[2] Using the Dataset API for TensorFlow Input Pipelines