- Dataset API
- Dataset APIの2つの概念
- 3種類のIteratorの使い分け方
- DatasetのTransformerの種類と使い方
- 知っておくと役に立つDataset APIの使用方法
- まとめ
- 参考
TensorFlowのDataset APIは、バージョン1.2から追加された新しい機能です。Dataset APIを使うことで、TensorFlowの独自のキューによる入力パイプラインの煩わしさを減らし、データセットの加工や入れ替えがスムーズに出来るようになります。
一般的に、データの読み込みではtf.train.string_input_producer
やtf.train.shuffle_batch
などのTensorFlowの提供するAPIを通して大規模なデータを複数スレッドで処理できるように設計されています。[1] しかしながら、もっと気軽にキューの入出力を意識せずに使いたいと考えていた人は多いはずです。
また、訓練データやテストデータの切り替えも処理パイプラインの組み換えに独自の実装を施さなければならず、簡単とは言い難いものでした。
今回の変更点のDataset APIを使うことで、このような煩わしさから解放されます。独自のキューパイプラインから解放され、複雑な前処理も簡単に実装することができるようになります。本記事では、Dataset APIを紹介し、使い方を解説します。
Dataset API
Dataset APIは、tf.contrib.data
モジュールを経由して使用します。以下の3つのメリットをデベロッパーが得られるように設計されています。
- 複雑な入力パイプラインを簡単にすること
- 処理パイプラインの再利用性を高めること
- 大量のデータを処理可能にしながら、様々な形式に対応すること
Dataset APIの2つの概念
Dataset APIには2つの重要な概念があります。Dataset
とIterator
です。
Dataset
Dataset
は、テンソル要素の集合です。入力データセットを抽象化していて、Dataset
の中には複数の入力データが含まれています。C#のCollection
やJavaのStream API、Swiftの配列のようにmap
やflat_map
、zip
といったデータ変換APIを適用することができます。
Iterator
Iterator
は、Dataset
の要素を抽出して機械学習モデルにデータを流す繋ぎのインターフェースです。Iterator
でDataset
内での現在位置を管理して、次の要素を取得するための操作を提供します。
3種類のIteratorの使い分け方
Iterator
には、3種類のIterator
があります。それぞれ特色が違うので、特徴と使い分け方を紹介します。
one-shot
one-shotは、入力データを一巡する標準的なイテレータです。後述する種類のイテレータと違い、パラメータ化をサポートしません。以下のコードのように、入力データを一巡します。iterator.get_next()
関数で、次の要素を取得できるオペレータを使ってsess.run
すると次の要素が手に入ります。
import tensorflow as tf
dataset = tf.contrib.data.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.InteractiveSession()
for i in range(10):
print(sess.run(next_element))
このコードをone_shot.pyと名前を付けて実行してみます。すると、0から9までの要素をone_shotイテレータで取得することができました。
$ python one_shot.py
0
1
2
3
4
5
6
7
8
9
rangeで0~9までのDatasetを作成していますが、10を超えてnext_elementを取得しようとすると、どうなるのでしょうか。one-shotイテレータの場合は、tf.errors.OutOfRangeError
例外をスローします。
initializable
initializableイテレータは、データセットの作成と処理パイプラインをパラメータ化することができます。
initializableイテレータを作成する場合は、make_one_shot_iterator
の代わりに、make_initializable_iterator
メソッドを使用します。
初期化するパラメータは以下のように、iterator.initializerオペレータを使ってtf.Sessionのfeed_dictに値を指定します。
import tensorflow as tf
range_param = tf.placeholder(tf.int64)
dataset = tf.contrib.data.Dataset.range(range_param)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.InteractiveSession()
sess.run(iterator.initializer, feed_dict={range_param: 10})
for i in range(10):
print(sess.run(next_element))
sess.run(iterator.initializer, feed_dict={range_param: 30})
for i in range(30):
print(sess.run(next_element))
reinitializable
reinitializableイテレータは、同じ型と形状を出力するデータセットを組み替えることができます。つまり、訓練データセットやテストデータセットなどのデータセットの切り替えのために使用します。
仮にFizzBuzzの関数を学習させることを考えてみましょう。0~99までの数字を訓練データセットとして、100から200までのテストデータセットを作成して学習させることを考えてみます。この場合、以下のコードのような変換をすると、訓練データとテストデータをreinitializableイテレータを使用して分離することができます。
import tensorflow as tf
sess = tf.InteractiveSession()
def encode_fizz_buzz(x):
return tf.case({
tf.equal(x % 3, 0): lambda: (x, tf.constant(0)),
tf.equal(x % 5, 0): lambda: (x, tf.constant(1)),
tf.equal(x % 15, 0): lambda: (x, tf.constant(2))},
default=lambda: (x, tf.constant(3)))
train_dataset = tf.contrib.data.Dataset.range(100).map(encode_fizz_buzz)
test_dataset = tf.contrib.data.Dataset.range(100,
200).map(encode_fizz_buzz)
iterator = tf.contrib.data.Iterator.from_structure(train_dataset.output_types,
train_dataset.output_shapes)
next_element = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)
for i in range(20): # 20エポック学習させる
sess.run(train_init_op)
for i in range(100):
train_X, train_Y = sess.run(next_element)
print(train_X, train_Y)
# 訓練データの学習
sess.run(test_init_op)
for i in range(100):
test_X, test_Y = sess.run(next_element)
print(test_X, test_Y)
# テストデータのloss, accuracyの確認
reinitializableイテレータは、型と形状を指定して、make_initializer
関数で初期化オペレーションを生成します。
DatasetのTransformerの種類と使い方
Datasetは、関数合成のように処理パイプラインを組み合わせて使用します。そのため、複数の変換関数を覚える必要があります。基本的な使用方法を紹介します。
基本変換
map
mapは、各入力要素を関数適用するために使用します。上述の例だと、range
で生成した数値をfizzbuzzの学習データに変換するために使用しています。
train_dataset = Dataset.range(100).map(encode_fizz_buzz)
以下のコードのように、lambda式を使用することもできます。
dataset = Dataset.range(100).map(lambda x: tf.square(x))
flat_map
flat_mapは、ネストしたDatasetの出力をネストを解消しながら各要素に関数適用することができます。
Dataset.range(100).flat_map(lambda x: Dataset.range(x, x+2))
# 0, 1, 1, 2, 2, 3, 3, 4....
zip
zipはPythonのビルトイン関数のzip
と同様に、複数のDatasetを1つにまとめることができます。
num = Dataset.range(100)
encode = num.map(lambda x: tf.case({
tf.equal(x % 3, 0): lambda: tf.constant(0),
tf.equal(x % 5, 0): lambda: tf.constant(1),
tf.equal(x % 15, 0): lambda: tf.constant(2)},
default=lambda: tf.constant(3)))
dataset = Dataset.zip((num, encode))
# (0, 0), (1, 3), (2, 3), (3, 0), (4, 3)...
group_by_window
group_by_windowは、条件でグループ化して窓サイズで分割するために使用します。
Dataset.range(100).group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)
# [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]...
データセット作成便利関数
repeat
repeatは、データセットの要素を繰り返すことができます。データセットを作成する際には、one-shotイテレータを使う場合にはエポック数を指定することがあります。引数を指定しないと、永久にリピートします。
train_dataset = Dataset.range(100).map(encode_fizz_buzz).repeat(num_epochs)
shuffle
shuffleは、データセットの要素の順番をシャッフルするために使用します。バッチ毎のデータ分布が偏ると上手く学習できなくなることがあるので、シャッフル化しておくと良いでしょう。引数にはシャッフルするバッファサイズを指定します。
Dataset.range(100).map(lambda x: x * 2).shuffle(20)
# 26, 32, 10, 44, 6, 34...
batch
batchはバッチサイズ毎に分割するために使用します。
Dataset.range(100).map(lambda x: x * 2).shuffle(20).batch(4)
# [38, 4, 0, 32], [14, 10, 26, 34], [46, 20, 54, 44], [22, 36, 6, 64]...
unbatch
unbatchはバッチを分解することができます。
Dataset.range(100).batch(4).unbatch()
# 1, 2, 3, 4, 5, 6, 7,...
padded_batch
padded_batchはゼロパディングしながら、分割します。
Dataset.range(100).batch(4).padded_batch(4, [5])
# array([[ 0, 1, 2, 3, 0],
# [ 4, 5, 6, 7, 0],
# [ 8, 9, 10, 11, 0],
# [12, 13, 14, 15, 0]]),
# array([[16, 17, 18, 19, 0],
# [20, 21, 22, 23, 0],
# [24, 25, 26, 27, 0],
# [28, 29, 30, 31, 0]]),
# array([[32, 33, 34, 35, 0],
# [36, 37, 38, 39, 0],
# [40, 41, 42, 43, 0],
# [44, 45, 46, 47, 0]]),...
生成関数
range
rangeはPythonのビルトイン関数と同様に、範囲内のイテレータを作成するために使用します。
Dataset.range(100)
# 1, 2, 3, 4, 5,...,99
enumerate
enumerateもPythonのビルトイン関数と同様に、インデックスの番号を付けて返します。
Dataset.range(100, 200).enumerate()
# (0, 100), (1, 101), (2, 102), (3, 103), (4, 104), (5, 105)
知っておくと役に立つDataset APIの使用方法
実践でDataset APIを使いこなすためには、各ファイルフォーマットからの読み込み方や加工方法を知る必要があります。CSVファイルやTFRecordの読み込み方やEstimatorとの併用を含めて使い方を紹介します。
CSVファイルから前処理をする
以下のCSVファイルから画像ファイルの名前とラベルを取得して、カテゴリ分類する例を考えてみましょう。最初の一行目に各カラムの説明が書いてあり、次の行からデータの中身が入力されています。
filename,label
001.jpg,rock
002.jpg,river
003.jpg,rock
003.jpg,sky
004.jpg,person
005.jpg,sky
006.jpg,animal
007.jpg,river
008.jpg,rock
009.jpg,rock
010.jpg,sky
CSVファイルを読むには、CSVファイルが小さなデータセットであればメモリに入れてfrom_tensors
かfrom_tensor_slices
しても問題ないですが、大きい場合にも対応可能なようにdata.TextLineDataset
を使用します。
import tensorflow as tf
import tensorflow.contrib.data as data
sess = tf.InteractiveSession()
categories = ['rock', 'river', 'sky', 'person', 'animal']
def to_index(label):
return categories.index(label)
def parse_csv(line):
[filename, category] = line.decode('utf-8').split(',')
return filename.encode('utf-8'), to_index(category)
def read_image(filename, label):
contents = tf.read_file(filename)
return tf.image.decode_image(contents), label
dataset = data.TextLineDataset("/tmp/test.csv")\
.skip(1)\
.map(lambda x: tf.py_func(parse_csv, [x], [tf.string, tf.int64]))\
.map(read_image)\
.shuffle(4)\
.batch(4)
iterator = dataset.make_one_shot_iterator()
next_elem = iterator.get_next()
print(sess.run(next_elem))
tf.py_func
を使用することで、上記のようにPythonの関数を呼び出すことができるようになります。
TFRecordを読み込む実用例
TFRecordを作成すると、GPUとCPUを並列に動作させることができるので高速になります。TFRecordもtf.contrib.data
モジュールの関数で読み込むことができます。以下のようにパース関数を書いてmap
で変換します。
TFRecordについては、以下の記事を参考にしてください。
TensorFlowのデータフォーマットTFRecordの書き込みと読み込み方法 /tensorflow/2017/10/07/tfrecord.html
import tensorflow as tf
import tensorflow.contrib.data as data
sess = tf.InteractiveSession()
def parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features['image'], parsed_features['label']
def read_image(image_raw, label):
image = tf.decode_raw(image_raw, tf.uint8)
return image, label
dataset = data.TFRecordDataset("./test.tfrecords")\
.map(parse_function)\
.map(read_image)\
.shuffle(4)\
.batch(4)
iterator = dataset.make_one_shot_iterator()
next_elem = iterator.get_next()
print(sess.run(next_elem))
まとめ
TensorFlowのDataset APIはデータセットを簡単に読み込めるようにした新しいモジュールです。これまでの読み込み関数は廃止になる可能性もありますし、このモジュールの関数が大幅に変更になるかもしれません。
個人的にはこれまでのAPIと比較すると、使いやすい印象があります。
さらに使いやすく、ハイパフォーマンスになると良いですね。
参考
[1] Reading data
[2] Using the Dataset API for TensorFlow Input Pipelines