From 988f9a1dda66f3a914928ddb05824661bef43ab4 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Mon, 21 Feb 2022 14:10:18 -0800 Subject: [PATCH] Add slight speed improvement by avoiding tiling and reshapes. We implicitly tile with the BE's first einsum op. Instead of reshape, we carry around the additional ensemble size axis, and where all layers operate with multiple batch dimensions. PiperOrigin-RevId: 430084350 --- edward2/jax/nn/dense.py | 52 ++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/edward2/jax/nn/dense.py b/edward2/jax/nn/dense.py index 8aa54d24..11be6207 100644 --- a/edward2/jax/nn/dense.py +++ b/edward2/jax/nn/dense.py @@ -47,14 +47,19 @@ class DenseBatchEnsemble(nn.Module): bias_init: InitializeFn = nn.initializers.zeros @nn.compact - def __call__(self, inputs): + def __call__(self, inputs, + is_first_be: bool = False, + index: Optional[int] = None): """Applies layer to input. Args: - inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim]. + inputs: jnp.ndarray of shape [ens_size, batch_size, ..., input_dim] or + [batch_size, ..., input_dim] if is_first_be. + is_first_be: + index: Returns: - jnp.ndarray of shape [ens_size * batch_size, ..., features]. + jnp.ndarray of shape [ens_size, batch_size, ..., features]. """ dtype = self.dtype or inputs.dtype inputs = jnp.asarray(inputs, dtype) @@ -67,18 +72,39 @@ def __call__(self, inputs): gamma = self.param('fast_weight_gamma', self.gamma_init, (self.ens_size, self.features), dtype) - inputs_shape = inputs.shape - inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:]) - outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma) + if index is None: + if not is_first_be: + outputs = jnp.einsum( + 'E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma) + else: + # TODO(trandustin): Testing einsum instead of tile. + outputs = jnp.einsum('...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma) - if self.use_bias: - bias = self.param('bias', self.bias_init, (self.ens_size, self.features), - dtype) - bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + ( - self.features,) - outputs = outputs + jnp.reshape(bias, bias_shape) + if self.use_bias: + bias = self.param('bias', self.bias_init, (self.ens_size, self.features), + dtype) + bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + ( + self.features,) + # TODO(trandustin): When finetuned from a deterministic upstream ckpt, + # need to enable setting bias to use ens_size=1 version as below so it can + # be set to deterministic's bias. Or ensemble of biases should at least be + # initialized that way. + # bias = self.param('bias', self.bias_init, (self.features,), dtype) + # bias_shape = (1,) * (outputs.ndim - 1) + (self.features,) + outputs = outputs + jnp.reshape(bias, bias_shape) + else: + alphai = alpha[index] + gammai = gamma[index] + outputs = jnp.einsum('...C,C,CD,D->...D', inputs, alphai, kernel, gammai) + if self.use_bias: + # For stochastic BE, we use a shared bias across members for now. This + # makes bias training easier although less diversity that may be + # important. + bias = self.param('bias', self.bias_init, (self.features,), dtype) + bias_shape = (1,) * (outputs.ndim - 1) + (self.features,) + outputs = outputs + jnp.reshape(bias, bias_shape) if self.activation is not None: outputs = self.activation(outputs) # pylint: disable=not-callable - return jnp.reshape(outputs, inputs_shape[:-1] + (self.features,)) + return outputs