Tensorflow Estimatorでモデルの保存と読み出し

スポンサーリンク
Python

みなさんこんにちは!

短いですが個人的に答えにたどり着くのに時間が掛かったことを書き留めておきます。

モデルの保存:引数model_dirに保存される

DNNClassifierの場合、

tf.estimator.DNNClassifier(
    hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column=None,
    label_vocabulary=None, optimizer='Adagrad', activation_fn=tf.nn.relu,
    dropout=None, config=None, warm_start_from=None,
    loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, batch_norm=False
)

(マニュアル)https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier

引数にあるmodel_dirにはパスを指定し、こちらがモデルの保存先になります。

保存等のアクションは明示する必要はなく、trainやevaluateを通して学習や評価をすると自動でcheckpointと呼ばれるモデルの状態が保存されます。

モデルの読み出し:引数model_dirにcheckpoint保存先を指定

model_dirDirectory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model.
model_dirの概要

こちらに説明があるように、一度指定していたmodel_dirを次回も指定すれば保存されたモデル(checkpoint)が読み込まれ保存済みモデルとして予測等ができるようになります。

まとめ

例えばXGBoostですと明示的に保存/読み込みをする関数を書かないと保存されませんので、おそらく同じように関数が用意されているものかと思いそのように調べていたため時間が掛かってしまいました。

他の方の役に立てば嬉しいです。

コメント