Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions edward2/jax/nn/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@ class DenseBatchEnsemble(nn.Module):
bias_init: InitializeFn = nn.initializers.zeros

@nn.compact
def __call__(self, inputs):
def __call__(self, inputs,
is_first_be: bool = False,
index: Optional[int] = None):
"""Applies layer to input.

Args:
inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim].
inputs: jnp.ndarray of shape [ens_size, batch_size, ..., input_dim] or
[batch_size, ..., input_dim] if is_first_be.
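is_first_be: If True (and index is None), inputs have no leading ens_size
axis; the same inputs are applied to every ensemble member and the output
gains a leading ens_size axis.
index: Optional index of a single ensemble member. If None, all members are
applied. Otherwise only alpha[index] and gamma[index] are used, a bias of
shape [features] is added, and the output has no ens_size axis.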

Returns:
jnp.ndarray of shape [ens_size * batch_size, ..., features].
jnp.ndarray of shape [ens_size, batch_size, ..., features] if index is None,
otherwise of shape inputs.shape[:-1] + (features,).
"""
dtype = self.dtype or inputs.dtype
inputs = jnp.asarray(inputs, dtype)
Expand All @@ -67,18 +72,39 @@ def __call__(self, inputs):
gamma = self.param('fast_weight_gamma', self.gamma_init,
(self.ens_size, self.features), dtype)

inputs_shape = inputs.shape
inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:])
outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)
if index is None:
if not is_first_be:
outputs = jnp.einsum(
'E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)
else:
# TODO(trandustin): Testing einsum instead of tile.
outputs = jnp.einsum('...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma)

if self.use_bias:
bias = self.param('bias', self.bias_init, (self.ens_size, self.features),
dtype)
bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + (
self.features,)
outputs = outputs + jnp.reshape(bias, bias_shape)
if self.use_bias:
bias = self.param('bias', self.bias_init, (self.ens_size, self.features),
dtype)
bias_shape = (self.ens_size,) + (1,) * (outputs.ndim - 2) + (
self.features,)
# TODO(trandustin): When finetuned from a deterministic upstream ckpt,
# need to enable setting bias to use ens_size=1 version as below so it can
# be set to deterministic's bias. Or ensemble of biases should at least be
# initialized that way.
# bias = self.param('bias', self.bias_init, (self.features,), dtype)
# bias_shape = (1,) * (outputs.ndim - 1) + (self.features,)
outputs = outputs + jnp.reshape(bias, bias_shape)
else:
alphai = alpha[index]
gammai = gamma[index]
outputs = jnp.einsum('...C,C,CD,D->...D', inputs, alphai, kernel, gammai)
if self.use_bias:
# For stochastic BE, we use a bias shared across members for now. This
# makes bias training easier, at the cost of less diversity across
# members, which may be important.
bias = self.param('bias', self.bias_init, (self.features,), dtype)
bias_shape = (1,) * (outputs.ndim - 1) + (self.features,)
outputs = outputs + jnp.reshape(bias, bias_shape)

if self.activation is not None:
outputs = self.activation(outputs) # pylint: disable=not-callable

return jnp.reshape(outputs, inputs_shape[:-1] + (self.features,))
return outputs