Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2683,8 +2683,11 @@ def sample(

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
log_p = -jnp.log(self.high - self.low)
is_in_support = (value >= self.low) & (value < self.high)
shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
return -jnp.broadcast_to(jnp.log(self.high - self.low), shape)
log_p = jnp.broadcast_to(log_p, shape)
return jnp.where(is_in_support, log_p, -jnp.inf)

def cdf(self, value: ArrayLike) -> ArrayLike:
cdf = (value - self.low) / (self.high - self.low)
Expand Down
103 changes: 103 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,3 +4486,106 @@ def test_interval_censored_validate_sample(
censored_dist.log_prob(value)
else:
censored_dist.log_prob(value) # Should not raise


def test_uniform_log_prob_outside_support():
d = dist.Uniform(0, 1)
assert_allclose(d.log_prob(-0.5), -jnp.inf)
assert_allclose(d.log_prob(1.5), -jnp.inf)


@pytest.mark.parametrize(
"low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)]
)
def test_uniform_log_prob_boundaries(low, high):
"""Test that boundary values are handled correctly."""
d = dist.Uniform(low, high)
expected_log_prob = -jnp.log(high - low)

# Value at lower bound (included): should have finite log prob
assert_allclose(d.log_prob(low), expected_log_prob)

# Value just above lower bound: should have finite log prob
assert_allclose(d.log_prob(low + 1e-10), expected_log_prob)

# Value at upper bound (excluded): should be -inf
assert_allclose(d.log_prob(high), -jnp.inf)

# Value just below upper bound: should have finite log prob
assert_allclose(d.log_prob(high - 1e-10), expected_log_prob)

# Value inside support: should have finite log prob
mid = (low + high) / 2.0
assert_allclose(d.log_prob(mid), expected_log_prob)

# Value below lower bound: should be -inf
assert_allclose(d.log_prob(low - 1.0), -jnp.inf)

# Value above upper bound: should be -inf
assert_allclose(d.log_prob(high + 1.0), -jnp.inf)


@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)])
def test_uniform_log_prob_broadcasting(batch_shape):
"""Test broadcasting with different batch shapes."""
if batch_shape == ():
low = 0.0
high = 1.0
else:
low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape)
high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape)

d = dist.Uniform(low, high)

# Test with scalar value
value = 0.5
log_probs = d.log_prob(value)
assert log_probs.shape == batch_shape

# Test with batched value
if batch_shape:
value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape(
batch_shape
)
log_probs_batched = d.log_prob(value_batched)
assert log_probs_batched.shape == batch_shape

# Check that values outside support return -inf
# Values < low should be -inf
below_low = low - 0.1
assert_allclose(d.log_prob(below_low), -jnp.inf)

# Values >= high should be -inf
at_high = high
assert_allclose(d.log_prob(at_high), -jnp.inf)


@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)])
def test_uniform_log_prob_value_broadcasting(value_shape):
"""Test broadcasting when value has different shapes."""
d = dist.Uniform(0.0, 1.0)

if value_shape == ():
values = 0.5
else:
values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape)

log_probs = d.log_prob(values)
assert log_probs.shape == value_shape

# Check that values inside support have finite log prob
inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1)
if value_shape:
inside_values = inside_values.reshape(value_shape)
log_probs_inside = d.log_prob(inside_values)
assert jnp.all(jnp.isfinite(log_probs_inside))

# Check that values outside support have -inf
outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1)
if value_shape:
outside_values = outside_values.reshape(value_shape)
log_probs_outside = d.log_prob(outside_values)
# Values in [0, 1) should be finite, others should be -inf
mask_inside = (outside_values >= 0.0) & (outside_values < 1.0)
assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True))
assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True))