This repository was archived by the owner on Jun 19, 2025. It is now read-only.

Commit 459af0f

Freeze layers for transfer learning.

Author: Daniel
Parent: a6f40a3

3 files changed: +45 -2 lines


training/mozilla_voice_stt_training/train.py

28 additions, 2 deletions

@@ -317,8 +317,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
         # Retain tower's avg losses
         tower_avg_losses.append(avg_loss)
 
+        train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+        # Filter out layers if we want to freeze some
+        if FLAGS.freeze_source_layers > 0:
+            filter_vars = []
+            if FLAGS.freeze_source_layers <= 5:
+                filter_vars.append("layer_1")
+            if FLAGS.freeze_source_layers <= 4:
+                filter_vars.append("layer_2")
+            if FLAGS.freeze_source_layers <= 3:
+                filter_vars.append("layer_3")
+            if FLAGS.freeze_source_layers <= 2:
+                filter_vars.append("lstm")
+            if FLAGS.freeze_source_layers <= 1:
+                filter_vars.append("layer_5")
+
+            new_train_vars = list(train_vars)
+            for fv in filter_vars:
+                for tv in train_vars:
+                    if fv in tv.name:
+                        new_train_vars.remove(tv)
+            train_vars = new_train_vars
+            msg = "Tower {} - Training only variables: {}"
+            print(msg.format(i, [v.name for v in train_vars]))
+        else:
+            print("Tower {} - Training all layers".format(i))
+
         # Compute gradients for model parameters using tower's mini-batch
-        gradients = optimizer.compute_gradients(avg_loss)
+        gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
 
         # Retain tower's gradients
         tower_gradients.append(gradients)

@@ -654,7 +681,6 @@ def __call__(self, progress, data, **kwargs):
 
         print('-' * 80)
 
-
     except KeyboardInterrupt:
         pass
     log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
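Note that the freeze_source_layers conditions count from the top of the network, mirroring drop_source_layers: a value of 1 leaves only the output layer trainable, 2 also trains the penultimate layer, and so on. A minimal sketch of that selection logic, using plain strings in place of TensorFlow variables (the helper name and the "layer_6" output-scope name are illustrative assumptions, not code from this commit):

# Sketch only: reproduces the filter_vars selection above with plain strings.
ALL_LAYERS = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]

def frozen_prefixes(freeze_source_layers):
    """Name prefixes excluded from training for a given flag value."""
    if freeze_source_layers <= 0:
        return []
    thresholds = [("layer_1", 5), ("layer_2", 4), ("layer_3", 3),
                  ("lstm", 2), ("layer_5", 1)]
    return [name for name, limit in thresholds if freeze_source_layers <= limit]

for n in range(3):
    frozen = frozen_prefixes(n)
    trainable = [layer for layer in ALL_LAYERS if layer not in frozen]
    print("freeze_source_layers={} -> trainable: {}".format(n, trainable))
# freeze_source_layers=0 -> all six layers trainable
# freeze_source_layers=1 -> trainable: ['layer_6']
# freeze_source_layers=2 -> trainable: ['layer_5', 'layer_6']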

training/mozilla_voice_stt_training/util/checkpoints.py

15 additions, 0 deletions

@@ -45,6 +45,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
                   'tensors. Missing variables: {}'.format(missing_var_names))
         sys.exit(1)
 
+    if FLAGS.load_frozen_graph:
+        # After training with "freeze_source_layers" the Adam tensors for the
+        # frozen layers no longer exist because they were never used.
+        # Therefore we have to initialize them again to continue training from such checkpoints.
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt:
+                if 'Adam' in v.name:
+                    init_vars.add(v)
+                else:
+                    msg = "Tried to load a frozen checkpoint but there was a missing " \
+                          "variable other than the Adam tensors: {}"
+                    log_error(msg.format(v))
+                    sys.exit(1)
+        load_vars -= init_vars
+
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
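For context: tf.train.AdamOptimizer creates its per-variable slot tensors (the first- and second-moment accumulators) only for the variables it actually updates, so a checkpoint written after frozen-layer training simply contains no Adam slots for the frozen layers. A minimal sketch of that behavior under TF1 semantics (variable names are illustrative; this is not code from this repository):

import tensorflow.compat.v1 as tf  # TF1-style graph mode
tf.disable_v2_behavior()

w_frozen = tf.get_variable("layer_1/w", shape=[2], initializer=tf.zeros_initializer())
w_train = tf.get_variable("layer_6/w", shape=[2], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(w_frozen + w_train)

# Restricting var_list, as in the train.py change above, means Adam only
# creates moment slots for layer_6/w; layer_1/w gets none.
opt = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = opt.minimize(loss, var_list=[w_train])

adam_vars = [v.op.name for v in tf.global_variables() if "Adam" in v.op.name]
print(adam_vars)  # expect slots under layer_6/w only, e.g. layer_6/w/Adam, layer_6/w/Adam_1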

training/mozilla_voice_stt_training/util/flags.py

2 additions, 0 deletions

@@ -93,6 +93,8 @@ def create_flags():
     # Transfer Learning
 
     f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output == 2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'use the same value as drop_source_layers to freeze the other layers')
+    f.DEFINE_boolean('load_frozen_graph', False, 'needed to load a graph checkpoint which was trained with the "freeze_source_layers" flag; allows initialization of the missing optimization tensors')
 
     # Exporting
 
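Taken together, the flags support a two-stage transfer-learning workflow: first fine-tune with the lower layers frozen, then optionally continue with all layers trainable, passing --load_frozen_graph so the missing Adam tensors get re-initialized. A hypothetical invocation (the module entry point and --checkpoint_dir flag are assumed from the usual training setup, and the data-set and source-checkpoint flags are omitted; only the three transfer-learning flags come from this commit):

# Stage 1: drop the output layer from the source model and freeze the rest.
python -m mozilla_voice_stt_training.train \
    --checkpoint_dir ckpts/finetune \
    --drop_source_layers 1 \
    --freeze_source_layers 1

# Stage 2: resume from the stage-1 checkpoint with every layer trainable;
# --load_frozen_graph re-initializes the Adam tensors that are missing for
# the previously frozen layers.
python -m mozilla_voice_stt_training.train \
    --checkpoint_dir ckpts/finetune \
    --load_frozen_graph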
