
Commit c5b3205

jereliu authored and copybara-github committed
Removes unnecessary ViT-GP hyper-parameters.
Following [pull #489](google/edward2#489) to `edward2.jax.nn.RandomFeatureGaussianProcess`, some of the special hyper-parameter configs are no longer needed, so we remove them to simplify the model API.

PiperOrigin-RevId: 388484029
1 parent 6fb3245 commit c5b3205

File tree

3 files changed: +31 −45 lines changed

baselines/jft/experiments/jft300m_vit_base16_sngp.py
baselines/jft/sngp.py
baselines/jft/sngp_test.py

baselines/jft/experiments/jft300m_vit_base16_sngp.py

Lines changed: 11 additions & 15 deletions
@@ -40,17 +40,17 @@ def get_config():
 
   pp_common = '|value_range(-1, 1)'
   pp_common += f'|onehot({config.num_classes})'
-  # To use ancestor "smearing", use this line instead:
-  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels") # pylint: disable=line-too-long
+  # To use ancestor 'smearing', use this line instead:
+  # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long
   pp_common += '|keep("image", "labels")'
   config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
   config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
   config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
 
   config.log_training_steps = 50
   config.log_eval_steps = 1000
-  # NOTE: eval is very fast O(seconds) so it's fine to run it often.
-  config.checkpoint_steps = 1000
+  # NOTE: For pretraining, save infrequently to prevent crowding diskspace.
+  config.checkpoint_steps = 517790
 
   # Model section
   config.model = ml_collections.ConfigDict()
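The `pp_train` / `pp_eval` strings above chain preprocessing ops with `|`. As a toy illustration of that composition style only (the repo's real op registry and string parser are not reproduced here, so every helper below is an assumption):

```python
import numpy as np

# Toy stand-ins for '|'-separated preprocessing ops; names/signatures are assumed.
def value_range(lo, hi):
  return lambda ex: dict(ex, image=lo + (hi - lo) * (ex['image'] / 255.0))

def onehot(num_classes):
  return lambda ex: dict(ex, labels=np.eye(num_classes)[ex['label']])

def keep(*keys):
  return lambda ex: {k: ex[k] for k in keys}

def chain(*ops):
  def run(ex):
    for op in ops:
      ex = op(ex)
    return ex
  return run

# Roughly mirrors '|value_range(-1, 1)|onehot(N)|keep("image", "labels")'.
pp_common = chain(value_range(-1, 1), onehot(10), keep('image', 'labels'))
```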
@@ -66,11 +66,11 @@ def get_config():
   config.model.classifier = 'token'  # Or 'gap'
   config.model.representation_size = 768
 
-  # GP layer parameters.
+  # Gaussian process layer parameters.
   config.gp_layer = ml_collections.ConfigDict()
-  config.gp_layer.normalize_input = True
-  config.gp_layer.random_feature_scale = 1.  # 1. or None
-  config.gp_layer.random_feature_stddev = 0.025  # 1. or 0.025
+  # Use momentum for pre-training to prevent numeric error when inverting a
+  # precision matrix accumulated over 300M data.
+  config.gp_layer.covmat_momentum = .999
 
   # Optimizer section
   config.optim_name = 'Adam'
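Why `covmat_momentum` stays: the GP output layer accumulates a random-feature precision matrix over training batches and must invert it later. A minimal sketch of the effect, under simplifying assumptions of my own (not code from this repo or from edward2):

```python
import numpy as np

def update_precision(precision, phi, momentum=0.999):
  """Accumulates the random-feature precision matrix over one batch."""
  batch_term = phi.T @ phi  # (num_features, num_features) outer-product sum.
  if momentum is None:
    return precision + batch_term  # Exact sum: grows without bound.
  return momentum * precision + (1. - momentum) * batch_term  # Stays bounded.

precision = np.eye(64)
for _ in range(1000):  # Pre-training on ~300M examples means many such steps.
  precision = update_precision(precision, np.random.randn(32, 64))
covariance = np.linalg.inv(precision)  # Well-scaled when momentum < 1.
```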
@@ -82,7 +82,8 @@ def get_config():
 
   # TODO(lbeyer): make a mini-language like preprocessings.
   config.lr = ml_collections.ConfigDict()
-  config.lr.base = 8e-4  # LR has to be lower for larger models!
+  # LR has to be lower for GP layer and on larger models.
+  config.lr.base = 4e-4
   config.lr.warmup_steps = 10_000
   config.lr.decay_type = 'linear'
   config.lr.linear_end = 1e-5
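The schedule fields (`base`, `warmup_steps`, `decay_type='linear'`, `linear_end`) are consumed by the training utilities elsewhere in the repo; a minimal sketch of the usual warmup-then-linear-decay reading of them (an illustration of the intent, not the repo's implementation):

```python
def lr_at_step(step, total_steps, base=4e-4, warmup_steps=10_000, linear_end=1e-5):
  """Linear warmup from 0 to `base`, then linear decay down to `linear_end`."""
  if step < warmup_steps:
    return base * step / warmup_steps
  frac = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
  return base + (linear_end - base) * min(frac, 1.0)
```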
@@ -96,9 +97,4 @@ def get_config():
 
 
 def get_sweep(hyper):
-  # lr_grid = [3e-4, 4e-4, 5e-4, 6e-4]
-  # stddev_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
-  return hyper.product([
-      # hyper.sweep('config.lr.base', lr_grid),
-      # hyper.sweep('config.gp_layer.random_feature_stddev', stddev_grid)
-  ])
+  return hyper.product([])
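`get_sweep` now returns an empty product, i.e. a single trial with the defaults above. If a grid is ever needed again, the deleted lines show the pattern; for example, a learning-rate-only sweep would look like this (illustrative only):

```python
def get_sweep(hyper):
  # Example: sweep only the base learning rate over a small grid.
  return hyper.product([
      hyper.sweep('config.lr.base', [3e-4, 4e-4, 5e-4, 6e-4]),
  ])
```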

baselines/jft/sngp.py

Lines changed: 16 additions & 26 deletions
@@ -27,7 +27,6 @@
 from clu import periodic_actions
 import flax
 import flax.jax_utils as flax_utils
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import ml_collections
@@ -72,7 +71,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
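For context on the corrected comment: `accumulate_gradient_with_states` averages `((loss, states), grads)` across micro-batches. A simplified sketch of that contract (illustrative only; the repo's version slices micro-batches with `jax.lax` primitives inside `acc_grad_and_loss`, shown in the next hunk):

```python
import jax

def accumulate_with_states_sketch(loss_and_grad_fn, params, states,
                                  images, labels, accum_steps):
  """Averages ((loss, states), grads) over `accum_steps` micro-batches."""
  if not accum_steps or accum_steps <= 1:
    return loss_and_grad_fn(params, states, images, labels)
  assert images.shape[0] % accum_steps == 0
  step = images.shape[0] // accum_steps
  (loss, states), grads = loss_and_grad_fn(
      params, states, images[:step], labels[:step])
  for i in range(1, accum_steps):
    sl = slice(i * step, (i + 1) * step)
    (loss_i, states), grads_i = loss_and_grad_fn(
        params, states, images[sl], labels[sl])
    loss += loss_i
    grads = jax.tree_map(lambda a, b: a + b, grads, grads_i)
  grads = jax.tree_map(lambda g: g / accum_steps, grads)
  return (loss / accum_steps, states), grads
```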
@@ -102,27 +101,16 @@ def acc_grad_and_loss(i, l_s_g):
 
 
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
 
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameter.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
-
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
+  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
 
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
 
   return gp_layer_kwargs
 
@@ -337,7 +325,7 @@ def representation_fn(params, images, labels, mask, states):
   @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
   def update_fn(opt, states, lr, images, labels, rng):
     """Update step."""
-
+    # TODO(jereliu): Expand to allow precision matrix resetting.
     measurements = {}
 
     if config.get('mixup') and config.mixup.p:
@@ -423,17 +411,17 @@ def decay_fn(v, wd):
         checkpoint['states'],
         checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issue for SNGP since all non-trainable parameters
+    # (random feature, precision matrix, etc) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
 
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -482,6 +470,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
 
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, rngs_loop,
       extra_measurements) = update_fn(
           opt_repl,
@@ -505,8 +494,9 @@ def decay_fn(v, wd):
     # alive while they'll be updated in a future step, creating hard to debug
     # memory errors (see b/160593526). Also, takes device 0's params only.
     # We will also do the same for untrainable parameters (`states`). This is
-    # ok since both `random features` and `predictive covariance` are frozen
-    # or task-specific parameters that are not important for pre-training.
+    # ok since `random features` are frozen throughout pre-training, and
+    # `predictive covariance` are irrelevant for downstream finetuning and
+    # will be discarded anyway.
     opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
     states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
 
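On the updated comment: both `opt_repl` and `states_repl` are replicated across devices by `pmap`, and the lines above pull only device 0's copy to the host. A small sketch of that pattern (assuming the pmapped update keeps replicas in sync):

```python
import jax
import numpy as np

def to_host(replicated_tree):
  # Each device holds an identical replica after the synced update, so taking
  # index 0 and converting to NumPy is enough for host-side checkpointing.
  return jax.tree_map(lambda x: np.array(x[0]), replicated_tree)
```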
baselines/jft/sngp_test.py

Lines changed: 4 additions & 4 deletions
@@ -115,10 +115,10 @@ def get_config(classifier, representation_size):
 class SNGPTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      ('token', 2, 1111.4404296875, 16258.519965277777, 0.16999999806284904),
-      ('token', None, 13992.8515625, 3621.3713107638887, 0.20999999344348907),
-      ('gap', 2, 8779.61328125, 3998.798285590278, 0.12999999895691872),
-      ('gap', None, 11279.3515625, 3212.2536892361113, 0.2199999988079071),
+      ('token', 2, 916.2851, 1954.3369140625, 0.16999999806284904),
+      ('token', None, 290.0307, 915.987548828125, 0.20999999344348907),
+      ('gap', 2, 695.6460, 600.8613823784722, 0.12999999895691872),
+      ('gap', None, 192.9434, 341.7078450520833, 0.2199999988079071),
   )
   def test_sngp_script(self, classifier, representation_size,
                        correct_train_loss, correct_val_loss,
