diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 686c6ebde..d6565df68 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2683,8 +2683,11 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + log_p = -jnp.log(self.high - self.low) + is_in_support = (value >= self.low) & (value < self.high) shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) - return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) + log_p = jnp.broadcast_to(log_p, shape) + return jnp.where(is_in_support, log_p, -jnp.inf) def cdf(self, value: ArrayLike) -> ArrayLike: cdf = (value - self.low) / (self.high - self.low) diff --git a/test/test_distributions.py b/test/test_distributions.py index 7edc19226..b98993425 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4486,3 +4486,106 @@ def test_interval_censored_validate_sample( censored_dist.log_prob(value) else: censored_dist.log_prob(value) # Should not raise + + +def test_uniform_log_prob_outside_support(): + d = dist.Uniform(0, 1) + assert_allclose(d.log_prob(-0.5), -jnp.inf) + assert_allclose(d.log_prob(1.5), -jnp.inf) + + +@pytest.mark.parametrize( + "low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)] +) +def test_uniform_log_prob_boundaries(low, high): + """Test that boundary values are handled correctly.""" + d = dist.Uniform(low, high) + expected_log_prob = -jnp.log(high - low) + + # Value at lower bound (included): should have finite log prob + assert_allclose(d.log_prob(low), expected_log_prob) + + # Value just above lower bound: should have finite log prob + assert_allclose(d.log_prob(low + 1e-10), expected_log_prob) + + # Value at upper bound (excluded): should be -inf + assert_allclose(d.log_prob(high), -jnp.inf) + + # Value just below upper bound: should have finite log prob + assert_allclose(d.log_prob(high - 1e-10), expected_log_prob) + + # Value inside support: should have finite log prob + mid = (low + high) / 2.0 + assert_allclose(d.log_prob(mid), expected_log_prob) + + # Value below lower bound: should be -inf + assert_allclose(d.log_prob(low - 1.0), -jnp.inf) + + # Value above upper bound: should be -inf + assert_allclose(d.log_prob(high + 1.0), -jnp.inf) + + +@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)]) +def test_uniform_log_prob_broadcasting(batch_shape): + """Test broadcasting with different batch shapes.""" + if batch_shape == (): + low = 0.0 + high = 1.0 + else: + low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape) + high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape) + + d = dist.Uniform(low, high) + + # Test with scalar value + value = 0.5 + log_probs = d.log_prob(value) + assert log_probs.shape == batch_shape + + # Test with batched value + if batch_shape: + value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape( + batch_shape + ) + log_probs_batched = d.log_prob(value_batched) + assert log_probs_batched.shape == batch_shape + + # Check that values outside support return -inf + # Values < low should be -inf + below_low = low - 0.1 + assert_allclose(d.log_prob(below_low), -jnp.inf) + + # Values >= high should be -inf + at_high = high + assert_allclose(d.log_prob(at_high), -jnp.inf) + + +@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)]) +def test_uniform_log_prob_value_broadcasting(value_shape): + """Test broadcasting when value has different shapes.""" + d = dist.Uniform(0.0, 1.0) + + if value_shape == (): + values = 0.5 + else: + values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape) + + log_probs = d.log_prob(values) + assert log_probs.shape == value_shape + + # Check that values inside support have finite log prob + inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1) + if value_shape: + inside_values = inside_values.reshape(value_shape) + log_probs_inside = d.log_prob(inside_values) + assert jnp.all(jnp.isfinite(log_probs_inside)) + + # Check that values outside support have -inf + outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1) + if value_shape: + outside_values = outside_values.reshape(value_shape) + log_probs_outside = d.log_prob(outside_values) + # Values in [0, 1) should be finite, others should be -inf + mask_inside = (outside_values >= 0.0) & (outside_values < 1.0) + assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True)) + assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True))