 from clu import periodic_actions
 import flax
 import flax.jax_utils as flax_utils
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import ml_collections
@@ -72,7 +71,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
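For context, a minimal sketch (hypothetical names, not part of this change) of the contract described in the comment above: `loss_and_grad_fn` takes a state argument and returns `((loss, states), grads)`, which is exactly the structure `accumulate_gradient_with_states` accumulates over `accum_steps` micro-batches.

import jax
import jax.numpy as jnp

def make_loss_and_grad_fn(apply_fn):
  """Builds a loss_and_grad_fn with the ((loss, states), grads) signature."""
  def loss_fn(params, states, images, labels):
    # apply_fn is assumed to return (logits, updated_states).
    logits, new_states = apply_fn(params, states, images)
    loss = -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    return loss, new_states
  # has_aux=True makes value_and_grad return ((loss, new_states), grads).
  return jax.value_and_grad(loss_fn, has_aux=True)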
@@ -102,27 +101,16 @@ def acc_grad_and_loss(i, l_s_g):
 
 
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
 
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameter.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
-
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
+  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
 
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
 
   return gp_layer_kwargs
 
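A rough usage sketch of the simplified helper above (the config value is made up for illustration): a negative `covmat_momentum` is mapped to `None`, which the Gaussian process layer is expected to interpret as an exact, non-moving-average precision matrix update.

import ml_collections

gp_config = ml_collections.ConfigDict()
gp_config.covmat_momentum = -1.  # Negative value -> momentum=None.

gp_layer_kwargs = get_gp_kwargs(gp_config)
# gp_layer_kwargs == {'covmat_kwargs': {'momentum': None}}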
@@ -337,7 +325,7 @@ def representation_fn(params, images, labels, mask, states):
   @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
   def update_fn(opt, states, lr, images, labels, rng):
     """Update step."""
-
+    # TODO(jereliu): Expand to allow precision matrix resetting.
     measurements = {}
 
     if config.get('mixup') and config.mixup.p:
@@ -423,17 +411,17 @@ def decay_fn(v, wd):
                                    checkpoint['states'],
                                    checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issues for SNGP since all non-trainable parameters
+    # (random feature, precision matrix, etc.) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
 
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -482,6 +470,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
 
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
       (opt_repl, states_repl, loss_value, rngs_loop,
        extra_measurements) = update_fn(
            opt_repl,
@@ -505,8 +494,9 @@ def decay_fn(v, wd):
       # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
       # We will also do the same for untrainable parameters (`states`). This is
-      # ok since both `random features` and `predictive covariance` are frozen
-      # or task-specific parameters that are not important for pre-training.
+      # ok since `random features` are frozen throughout pre-training, and
+      # `predictive covariance` is irrelevant for downstream fine-tuning and
+      # will be discarded anyway.
       opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
       states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
 
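As a side note on the pattern above: `np.array(x[0])` takes device 0's copy of each replicated leaf and pulls it to host memory. Assuming a standard pmap-replicated pytree, an equivalent sketch using `flax.jax_utils` would be:

import jax
import flax.jax_utils as flax_utils

# unreplicate selects the first entry along the device axis for every leaf;
# device_get then copies the result to host (NumPy) memory.
opt_cpu = jax.device_get(flax_utils.unreplicate(opt_repl))
states_cpu = jax.device_get(flax_utils.unreplicate(states_repl))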