
Commit fadaf2a
Author: Daniel
Message: Freeze layers for transfer learning.
Parent: 4e55d63

File tree: 3 files changed (+45, -2 lines)

training/deepspeech_training/train.py

Lines changed: 28 additions & 2 deletions
@@ -322,8 +322,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                     # Retain tower's avg losses
                     tower_avg_losses.append(avg_loss)
 
+                    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+                    # Filter out layers if we want to freeze some
+                    if FLAGS.freeze_source_layers > 0:
+                        filter_vars = []
+                        if FLAGS.freeze_source_layers <= 5:
+                            filter_vars.append("layer_1")
+                        if FLAGS.freeze_source_layers <= 4:
+                            filter_vars.append("layer_2")
+                        if FLAGS.freeze_source_layers <= 3:
+                            filter_vars.append("layer_3")
+                        if FLAGS.freeze_source_layers <= 2:
+                            filter_vars.append("lstm")
+                        if FLAGS.freeze_source_layers <= 1:
+                            filter_vars.append("layer_5")
+
+                        new_train_vars = list(train_vars)
+                        for fv in filter_vars:
+                            for tv in train_vars:
+                                if fv in tv.name:
+                                    new_train_vars.remove(tv)
+                        train_vars = new_train_vars
+                        msg = "Tower {} - Training only variables: {}"
+                        print(msg.format(i, [v.name for v in train_vars]))
+                    else:
+                        print("Tower {} - Training all layers".format(i))
+
                     # Compute gradients for model parameters using tower's mini-batch
-                    gradients = optimizer.compute_gradients(avg_loss)
+                    gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
 
                     # Retain tower's gradients
                     tower_gradients.append(gradients)

@@ -671,7 +698,6 @@ def __call__(self, progress, data, **kwargs):
 
             print('-' * 80)
 
-
     except KeyboardInterrupt:
         pass
     log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
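Note: the hunk above implements freezing by filtering the TRAINABLE_VARIABLES collection on layer-name substrings and passing only the remaining variables to optimizer.compute_gradients via var_list. Below is a minimal, self-contained sketch of the same technique, assuming a TF1-style graph like the one DeepSpeech used at the time; the two-layer toy network and its scope names are illustrative only, and it is written against tf.compat.v1 so it also runs under TensorFlow 2.

# Minimal sketch (not DeepSpeech code): freeze "layer_1" by excluding its
# variables from the var_list handed to compute_gradients.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 1])

with tf.variable_scope("layer_1"):
    hidden = tf.layers.dense(x, 8, activation=tf.nn.relu)
with tf.variable_scope("layer_2"):
    output = tf.layers.dense(hidden, 1)

loss = tf.reduce_mean(tf.square(output - y))

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
frozen_prefixes = ["layer_1"]  # plays the role of filter_vars above
train_vars = [v for v in train_vars
              if not any(p in v.name for p in frozen_prefixes)]

optimizer = tf.train.AdamOptimizer(1e-3)
# Only the kept variables receive gradients, so "layer_1" stays frozen.
# Adam slot variables are also created only for var_list entries, which is
# why frozen layers end up without Adam tensors in later checkpoints.
gradients = optimizer.compute_gradients(loss, var_list=train_vars)
apply_op = optimizer.apply_gradients(gradients)

Building the filtered list with a comprehension has the same effect as the commit's copy-then-remove loop: the original collection is never mutated while it is being iterated.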

training/deepspeech_training/util/checkpoints.py

Lines changed: 15 additions & 0 deletions
@@ -46,6 +46,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                       'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)
 
+    if FLAGS.load_frozen_graph:
+        # After training with "freeze_source_layers", the Adam tensors of the frozen
+        # layers no longer exist in the checkpoint because they were never used.
+        # Therefore we have to initialize them again to continue training on such checkpoints.
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt:
+                if 'Adam' in v.name:
+                    init_vars.add(v)
+                else:
+                    msg = "Tried to load a frozen checkpoint but there was a missing " \
+                          "variable other than the Adam tensors: {}"
+                    log_error(msg.format(v))
+                    sys.exit(1)
+        load_vars -= init_vars
+
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
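Note: the hunk above splits the graph variables into a set restored from the checkpoint and a set re-initialized from scratch, tolerating missing Adam moment tensors (never created for frozen layers) and aborting on any other missing variable. The sketch below shows that partitioning logic in isolation; the function name split_restore_vars and the sample variable names are hypothetical stand-ins for the TF variable objects used in checkpoints.py.

# Hedged sketch of the restore/initialize split, using plain strings instead of
# TF variables. Function name and sample data are hypothetical.
def split_restore_vars(graph_var_names, ckpt_var_names):
    """Return (load, init) name sets; fail if a non-Adam variable is missing."""
    load, init = set(graph_var_names), set()
    for name in graph_var_names:
        if name not in ckpt_var_names:
            if 'Adam' in name:
                # Optimizer slots of frozen layers were never created or saved,
                # so they must be freshly initialized when training resumes.
                init.add(name)
            else:
                raise ValueError(
                    "Missing non-Adam variable in frozen checkpoint: {}".format(name))
    return load - init, init


if __name__ == "__main__":
    graph_vars = {"layer_1/weights", "layer_1/weights/Adam", "layer_5/weights"}
    ckpt_vars = {"layer_1/weights", "layer_5/weights"}  # no Adam slot saved for layer_1
    load_vars, init_vars = split_restore_vars(graph_vars, ckpt_vars)
    print(sorted(load_vars), sorted(init_vars))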

training/deepspeech_training/util/flags.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,8 @@ def create_flags():
     # Transfer Learning
 
     f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'use the same value as drop_source_layers to freeze the remaining (non-dropped) source layers')
+    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with the "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')
 
     # Exporting
 
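Note: the two new flags mirror drop_source_layers: passing the same integer N re-initializes the top N layers while freezing the layers below them, and load_frozen_graph lets later runs restore checkpoints whose frozen layers carry no Adam tensors. The helper below is an illustration only (not part of the commit) of how a freeze_source_layers value maps onto the layer-name substrings filtered in train.py above.

# Illustration of how a freeze_source_layers value maps to frozen layer prefixes,
# following the thresholds in the train.py hunk; not part of the commit.
def frozen_layer_prefixes(freeze_source_layers):
    thresholds = [(5, "layer_1"), (4, "layer_2"), (3, "layer_3"),
                  (2, "lstm"), (1, "layer_5")]
    if freeze_source_layers <= 0:
        return []
    return [name for limit, name in thresholds if freeze_source_layers <= limit]


# freeze_source_layers=1 (like drop_source_layers=1, which re-initializes only
# the output layer) freezes everything below the output layer:
print(frozen_layer_prefixes(1))  # ['layer_1', 'layer_2', 'layer_3', 'lstm', 'layer_5']
print(frozen_layer_prefixes(5))  # ['layer_1']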