diff --git a/edward2/jax/nn/heteroscedastic_lib.py b/edward2/jax/nn/heteroscedastic_lib.py
index 6cb7f412..c46e37c0 100644
--- a/edward2/jax/nn/heteroscedastic_lib.py
+++ b/edward2/jax/nn/heteroscedastic_lib.py
@@ -65,6 +65,7 @@ class MCSoftmaxDenseFA(nn.Module):
   share_samples_across_batch: bool = False
   logits_only: bool = False
   return_locs: bool = False
+  return_unaveraged_logits: bool = False
   eps: float = 1e-7
   tune_temperature: bool = False
   temperature_lower_bound: Optional[float] = None
@@ -251,9 +252,10 @@ def _compute_mc_samples(self, inputs, scale, num_samples):

     # [B, S, dim] -> [B, S, K]
     latents = self._compute_loc_param(latents)  # pylint: disable=assignment-from-none
-    samples = jax.nn.softmax(latents / self.get_temperature())
+    scaled_latents = latents / self.get_temperature()
+    samples = jax.nn.softmax(scaled_latents)

-    return jnp.mean(samples, axis=1)
+    return jnp.mean(samples, axis=1), jax.nn.log_softmax(scaled_latents)

   @nn.compact
   def __call__(self, inputs, training=True):
@@ -278,7 +280,8 @@ def __call__(self, inputs, training=True):
     else:
       total_mc_samples = self.test_mc_samples

-    probs_mean = self._compute_mc_samples(inputs, scale, total_mc_samples)
+    probs_mean, unaveraged_logits = self._compute_mc_samples(
+        inputs, scale, total_mc_samples)

     probs_mean = jnp.clip(probs_mean, a_min=self.eps)
     log_probs = jnp.log(probs_mean)
@@ -288,8 +291,12 @@ def __call__(self, inputs, training=True):
     logits = self._compute_loc_param(inputs)  # pylint: disable=assignment-from-none

     if self.logits_only:
+      if self.return_unaveraged_logits:
+        return logits, unaveraged_logits
       return logits

+    if self.return_unaveraged_logits:
+      return logits, log_probs, probs_mean, unaveraged_logits
     return logits, log_probs, probs_mean
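
For reviewers: a minimal self-contained sketch (plain JAX, not the edward2 module API) of what the modified `_compute_mc_samples` computes when `return_unaveraged_logits=True`. The function name and tensor shapes below are illustrative only; the returned `unaveraged_logits` are the per-sample log-softmax values before the Monte Carlo average is taken.

```python
import jax
import jax.numpy as jnp


def mc_softmax_with_unaveraged_logits(latents, temperature=1.0):
  """latents: [batch, num_samples, num_classes] MC samples of the logits."""
  scaled_latents = latents / temperature
  samples = jax.nn.softmax(scaled_latents)                 # [B, S, K]
  probs_mean = jnp.mean(samples, axis=1)                   # [B, K], averaged over S
  unaveraged_logits = jax.nn.log_softmax(scaled_latents)   # [B, S, K], not averaged
  return probs_mean, unaveraged_logits


# Illustrative shapes: batch=4, num_samples=8, num_classes=10.
key = jax.random.PRNGKey(0)
latents = jax.random.normal(key, (4, 8, 10))
probs_mean, unaveraged = mc_softmax_with_unaveraged_logits(latents)
print(probs_mean.shape, unaveraged.shape)  # (4, 10) (4, 8, 10)
```

Presumably the point of exposing the unaveraged per-sample log-probabilities is to let callers compute per-sample losses or diagnostics without rerunning the MC sampling.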