This repository was archived by the owner on Dec 29, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
This repository was archived by the owner on Dec 29, 2022. It is now read-only.
seq2seq checkpoint restore for transfer learning #356
Copy link
Copy link
Open
Description
I am using code built on top of train.py and infer.py (files unchanged) from the seq2seq tutorial. I want to do transfer learning/loading from checkpoints but am unfamiliar with the tf.contrib.learn.Estimator and seq2seq.contrib.experiment environment.
I basically want to incorporate the checkpoint load step from infer.py into training:
# Saver over all variables in the (already-built) inference graph.
saver = tf.train.Saver()

# An explicit --checkpoint_path wins; otherwise fall back to the most
# recent checkpoint found in the model directory.
checkpoint_path = (FLAGS.checkpoint_path
                   or tf.train.latest_checkpoint(FLAGS.model_dir))

def session_init_op(_scaffold, sess):
  """Scaffold init_fn: restores variables into the freshly created session."""
  saver.restore(sess, checkpoint_path)
  tf.logging.info("Restored model from %s", checkpoint_path)
How/where along the pipeline should I be inserting the script?
def create_experiment(output_dir):
  """Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.

  Returns:
    A PatchedExperiment wired with the training/eval input functions,
    hooks, and metrics configured via FLAGS.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False, num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric

  # Optionally warm-start training from an existing checkpoint (transfer
  # learning). A Saver cannot be built here: the Estimator only constructs
  # the model graph later, inside model_fn, and there is no live session at
  # this point (the previous code referenced an undefined `sess`). Instead
  # the restore is deferred into a SessionRunHook: begin() runs after the
  # graph is built but before it is finalized (safe to add Saver ops), and
  # after_create_session() runs once the session exists.
  # NOTE(review): assumes FLAGS.checkpoint_path / FLAGS.model_dir flags are
  # defined for this training script — confirm against the flag definitions.
  restore_path = FLAGS.checkpoint_path
  if not restore_path:
    restore_path = tf.train.latest_checkpoint(FLAGS.model_dir)
  if restore_path:
    class _RestoreHook(tf.train.SessionRunHook):
      """Restores model variables from `restore_path` at session creation."""

      def begin(self):
        # Graph is built but not yet finalized here.
        self._saver = tf.train.Saver()

      def after_create_session(self, session, coord):
        self._saver.restore(session, restore_path)
        tf.logging.info("Restored model from %s", restore_path)

    train_hooks.append(_RestoreHook())

  # PatchedExperiment: presumably a patched tf.contrib.learn Experiment
  # defined elsewhere in this file — verify its definition for specifics.
  experiment = PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
Metadata
Metadata
Assignees
Labels
No labels