diff --git a/edward2/jax/nn/dense.py b/edward2/jax/nn/dense.py
index 8aa54d24..11be6207 100644
--- a/edward2/jax/nn/dense.py
+++ b/edward2/jax/nn/dense.py
@@ -47,14 +47,19 @@ class DenseBatchEnsemble(nn.Module):
   bias_init: InitializeFn = nn.initializers.zeros
 
   @nn.compact
-  def __call__(self, inputs):
+  def __call__(self, inputs,
+               is_first_be: bool = False,
+               index: Optional[int] = None):
     """Applies layer to input.
 
     Args:
-      inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim].
+      inputs: jnp.ndarray of shape [ens_size, batch_size, ..., input_dim] or
+        [batch_size, ..., input_dim] if is_first_be.
+      is_first_be: whether inputs lack the leading ens_size axis.
+      index: if set, applies only the ensemble member at this index.
 
     Returns:
-      jnp.ndarray of shape [ens_size * batch_size, ..., features].
+      jnp.ndarray of shape [ens_size, batch_size, ..., features].
     """
     dtype = self.dtype or inputs.dtype
     inputs = jnp.asarray(inputs, dtype)
@@ -67,18 +72,39 @@ def __call__(self, inputs):
     gamma = self.param('fast_weight_gamma', self.gamma_init,
                        (self.ens_size, self.features), dtype)
 
-    inputs_shape = inputs.shape
-    inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:])
-    outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)
-
-    if self.use_bias:
-      bias = self.param('bias', self.bias_init, (self.ens_size, self.features),
-                        dtype)
-      bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + (
-          self.features,)
-      outputs = outputs + jnp.reshape(bias, bias_shape)
+    if index is None:
+      if not is_first_be:
+        outputs = jnp.einsum(
+            'E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)
+      else:
+        # TODO(trandustin): Testing einsum instead of tile.
+        outputs = jnp.einsum(
+            '...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)
+
+      if self.use_bias:
+        bias = self.param('bias', self.bias_init,
+                          (self.ens_size, self.features), dtype)
+        bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + (
+            self.features,)
+        # TODO(trandustin): When finetuning from a deterministic upstream ckpt,
+        # the bias needs to support the ens_size=1 version below so it can be
+        # set to the deterministic checkpoint's bias. Or the ensemble of biases
+        # should at least be initialized that way.
+        # bias = self.param('bias', self.bias_init, (self.features,), dtype)
+        # bias_shape = (1,) * (outputs.ndim - 1) + (self.features,)
+        outputs = outputs + jnp.reshape(bias, bias_shape)
+    else:
+      alphai = alpha[index]
+      gammai = gamma[index]
+      outputs = jnp.einsum('...C,C,CD,D->...D', inputs, alphai, kernel, gammai)
+      if self.use_bias:
+        # For stochastic BE, we use a shared bias across members for now. This
+        # makes bias training easier, although it gives up member diversity
+        # that may be important.
+        bias = self.param('bias', self.bias_init, (self.features,), dtype)
+        bias_shape = (1,) * (outputs.ndim - 1) + (self.features,)
+        outputs = outputs + jnp.reshape(bias, bias_shape)
 
     if self.activation is not None:
       outputs = self.activation(outputs)  # pylint: disable=not-callable
-    return jnp.reshape(outputs, inputs_shape[:-1] + (self.features,))
+    return outputs
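
For reference, a minimal usage sketch of the new call signature, assuming edward2's JAX API exposes this layer as ed.nn.DenseBatchEnsemble; the shapes follow the einsum paths above. Note the index path creates a shared (features,)-shaped bias rather than the per-member (ens_size, features) bias, so it is initialized separately here.

import jax
import jax.numpy as jnp
import edward2.jax as ed  # assumed import path for the edward2 JAX API

ens_size, batch_size, input_dim, features = 4, 2, 16, 8
layer = ed.nn.DenseBatchEnsemble(features=features, ens_size=ens_size)

# First BE layer: inputs carry no ensemble axis, so is_first_be=True broadcasts
# them across all members via the '...C,EC,CD,ED->E...D' einsum.
x = jnp.ones([batch_size, input_dim])
params = layer.init(jax.random.PRNGKey(0), x, is_first_be=True)
y = layer.apply(params, x, is_first_be=True)
assert y.shape == (ens_size, batch_size, features)

# Later BE layers: inputs already carry the leading ens_size axis.
x_ens = jnp.ones([ens_size, batch_size, input_dim])
assert layer.apply(params, x_ens).shape == (ens_size, batch_size, features)

# Single-member path: index selects one pair of fast weights and a shared bias.
params_one = layer.init(jax.random.PRNGKey(0), x, index=1)
assert layer.apply(params_one, x, index=1).shape == (batch_size, features)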