From 013d898fba12723e9391298da0fb990136778b43 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 27 Oct 2025 13:21:40 +0100 Subject: [PATCH 1/6] fix --- numpyro/distributions/continuous.py | 13 ++++++- test/test_distributions.py | 55 +++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 686c6ebde..647fb1fcd 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -216,7 +216,18 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)) + # Handle edge cases where concentration1=1 and value=0, or concentration0=1 and value=1 + # These cases would result in nan due to log(0) * 0 in the Dirichlet computation + log_prob = jnp.where( + (value == 0.0) & (self.concentration1 == 1.0), + jnp.log(self.concentration0), + jnp.where( + (value == 1.0) & (self.concentration0 == 1.0), + jnp.log(self.concentration1), + self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)), + ), + ) + return log_prob @property def mean(self) -> ArrayLike: diff --git a/test/test_distributions.py b/test/test_distributions.py index 9942671b2..531913bbd 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4152,3 +4152,58 @@ def test_censored_sample_shape(): ) samples = censored_dist.sample(rng_key, sample_shape) assert samples.shape == expected_shape + + +def test_beta_logprob_edge_case_concentration1_one(): + """Test Beta(1, β) at x=0 should give finite log probability.""" + beta_dist = dist.Beta(1.0, 8.0) + log_prob_at_zero = beta_dist.log_prob(0.0) + + assert not jnp.isnan(log_prob_at_zero), "Beta(1,8).log_prob(0) should not be NaN" + assert jnp.isfinite(log_prob_at_zero), "Beta(1,8).log_prob(0) should be finite" + + +def test_beta_logprob_edge_case_concentration0_one(): + """Test Beta(α, 1) at x=1 should give finite log probability.""" + beta_dist2 = dist.Beta(8.0, 1.0) + log_prob_at_one = beta_dist2.log_prob(1.0) + + assert not jnp.isnan(log_prob_at_one), "Beta(8,1).log_prob(1) should not be NaN" + assert jnp.isfinite(log_prob_at_one), "Beta(8,1).log_prob(1) should be finite" + + +def test_beta_logprob_edge_case_consistency_small_values(): + """Test that edge case values are consistent with small deviation values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + # At boundary + log_prob_at_zero = beta_dist.log_prob(0.0) + log_prob_at_one = beta_dist2.log_prob(1.0) + + # Very close to boundary + small_value = 1e-10 + log_prob_small = beta_dist.log_prob(small_value) + log_prob_close_to_one = beta_dist2.log_prob(1.0 - small_value) + + # Edge case values should be close to small deviation values + assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5 + assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5 + + +def test_beta_logprob_edge_case_non_boundary_values(): + """Test that Beta with concentration=1 still works for non-boundary values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + assert jnp.isfinite(beta_dist.log_prob(0.5)) + assert jnp.isfinite(beta_dist2.log_prob(0.5)) + + +def test_beta_logprob_boundary_non_edge_cases(): + """Test that non-edge cases (concentration > 1) still give -inf at boundaries.""" + beta_dist3 = dist.Beta(2.0, 8.0) + beta_dist4 = dist.Beta(8.0, 2.0) + + assert jnp.isneginf(beta_dist3.log_prob(0.0)) + assert jnp.isneginf(beta_dist4.log_prob(1.0)) From 13fc288bda45e42840924b8026fb2b5fe5aa01fd Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 1 Nov 2025 19:44:59 +0100 Subject: [PATCH 2/6] add custom gradient --- numpyro/distributions/continuous.py | 85 +++++++++++++++++++++++++---- test/test_distributions.py | 70 ++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 12 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 647fb1fcd..bfbed5fe8 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -179,6 +179,76 @@ def icdf(self, value: ArrayLike) -> ArrayLike: ) +@jax.custom_jvp +def _beta_log_prob(value, concentration1, concentration0): + """ + Compute Beta log probability with custom gradients to handle edge cases. + + When concentration1=1 and value=0, or concentration0=1 and value=1, + the standard formula involves log(0) * 0 which should be 0, but has + undefined gradients. We use custom_jvp to define proper gradients. + """ + return ( + xlogy(concentration1 - 1.0, value) + + xlogy(concentration0 - 1.0, 1.0 - value) + - betaln(concentration1, concentration0) + ) + + +@_beta_log_prob.defjvp +def _beta_log_prob_jvp(primals, tangents): + """ + Define custom JVP (Jacobian-vector product) for Beta log_prob. + + For Beta(α, β), the derivatives are: + - d/dx log_prob = (α - 1) / x - (β - 1) / (1 - x) + - d/dα log_prob = log(x) - ψ(α) + ψ(α + β) + - d/dβ log_prob = log(1 - x) - ψ(β) + ψ(α + β) + + where ψ is the digamma function. Edge cases (α=1, x=0) or (β=1, x=1) are handled + by setting gradients to finite values using safe operations. + """ + value, concentration1, concentration0 = primals + value_dot, concentration1_dot, concentration0_dot = tangents + + primal_out = _beta_log_prob(value, concentration1, concentration0) + + # Gradient w.r.t. value - use safe division and set to 0 at edge cases + safe_value = jnp.where(value == 0.0, 1.0, value) + safe_one_minus = jnp.where(value == 1.0, 1.0, 1.0 - value) + grad_value = (concentration1 - 1.0) / safe_value - ( + concentration0 - 1.0 + ) / safe_one_minus + grad_value = jnp.where( + ((value == 0.0) & (concentration1 == 1.0)) + | ((value == 1.0) & (concentration0 == 1.0)), + 0.0, + grad_value, + ) + + # Gradients w.r.t. concentration parameters - use safe log (0 instead of -inf) + digamma_sum = digamma(concentration1 + concentration0) + grad_concentration1 = ( + jnp.where(value == 0.0, 0.0, jnp.log(value)) + - digamma(concentration1) + + digamma_sum + ) + grad_concentration0 = ( + jnp.where(value == 1.0, 0.0, jnp.log(1.0 - value)) + - digamma(concentration0) + + digamma_sum + ) + + # Chain rule + tangent_out = ( + grad_value * value_dot + + grad_concentration1 * concentration1_dot + + grad_concentration0 * concentration0_dot + ) + + return primal_out, tangent_out + + class Beta(Distribution): arg_constraints = { "concentration1": constraints.positive, @@ -216,18 +286,9 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - # Handle edge cases where concentration1=1 and value=0, or concentration0=1 and value=1 - # These cases would result in nan due to log(0) * 0 in the Dirichlet computation - log_prob = jnp.where( - (value == 0.0) & (self.concentration1 == 1.0), - jnp.log(self.concentration0), - jnp.where( - (value == 1.0) & (self.concentration0 == 1.0), - jnp.log(self.concentration1), - self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)), - ), - ) - return log_prob + # Compute Beta log_prob directly using the formula with custom gradients + # to handle edge cases where concentration=1 and value is at boundary + return _beta_log_prob(value, self.concentration1, self.concentration0) @property def mean(self) -> ArrayLike: diff --git a/test/test_distributions.py b/test/test_distributions.py index 531913bbd..c70b124b1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4207,3 +4207,73 @@ def test_beta_logprob_boundary_non_edge_cases(): assert jnp.isneginf(beta_dist3.log_prob(0.0)) assert jnp.isneginf(beta_dist4.log_prob(1.0)) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value,grad_param,grad_value", + argvalues=[ + (1.0, 8.0, 0.0, "value", 0.0), + (8.0, 1.0, 1.0, "value", 1.0), + (1.0, 8.0, 0.0, "concentration1", 1.0), + (1.0, 8.0, 0.0, "concentration0", 8.0), + (8.0, 1.0, 1.0, "concentration1", 8.0), + (8.0, 1.0, 1.0, "concentration0", 1.0), + ], + ids=[ + "Beta(1,8) at x=0", + "Beta(8,1) at x=1", + "Beta(1,8) at concentration1=1", + "Beta(1,8) at concentration0=8", + "Beta(8,1) at concentration1=8", + "Beta(8,1) at concentration0=1", + ], +) +def test_beta_gradient_edge_cases_single_param( + concentration1, concentration0, value, grad_param, grad_value +): + """Test that gradients w.r.t. individual parameters are finite at edge cases.""" + if grad_param == "value": + + def log_prob_fn(x): + return dist.Beta(concentration1, concentration0).log_prob(x) + + grad = jax.grad(log_prob_fn)(value) + elif grad_param == "concentration1": + + def log_prob_fn(c1): + return dist.Beta(c1, concentration0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + else: # concentration0 + + def log_prob_fn(c0): + return dist.Beta(concentration1, c0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + + assert jnp.isfinite(grad), ( + f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) " + f"at x={value} should be finite" + ) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value): + """Test that all gradients are finite when computed simultaneously at edge cases.""" + + def log_prob_fn(params): + c1, c0, v = params + return dist.Beta(c1, c0).log_prob(v) + + grads = jax.grad(log_prob_fn)(jnp.array([concentration1, concentration0, value])) + assert jnp.all(jnp.isfinite(grads)), ( + f"All gradients for Beta({concentration1},{concentration0}) at x={value} " + f"should be finite" + ) From 19f8be70669b36ea0444bedd766507374c547044 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 1 Nov 2025 19:49:16 +0100 Subject: [PATCH 3/6] simplify tests --- test/test_distributions.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index c70b124b1..c153d34f9 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4154,22 +4154,25 @@ def test_censored_sample_shape(): assert samples.shape == expected_shape -def test_beta_logprob_edge_case_concentration1_one(): - """Test Beta(1, β) at x=0 should give finite log probability.""" - beta_dist = dist.Beta(1.0, 8.0) - log_prob_at_zero = beta_dist.log_prob(0.0) - - assert not jnp.isnan(log_prob_at_zero), "Beta(1,8).log_prob(0) should not be NaN" - assert jnp.isfinite(log_prob_at_zero), "Beta(1,8).log_prob(0) should be finite" - - -def test_beta_logprob_edge_case_concentration0_one(): - """Test Beta(α, 1) at x=1 should give finite log probability.""" - beta_dist2 = dist.Beta(8.0, 1.0) - log_prob_at_one = beta_dist2.log_prob(1.0) +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_logprob_edge_cases(concentration1, concentration0, value): + """Test Beta distribution with concentration=1 gives finite log probability at boundary.""" + beta_dist = dist.Beta(concentration1, concentration0) + log_prob = beta_dist.log_prob(value) - assert not jnp.isnan(log_prob_at_one), "Beta(8,1).log_prob(1) should not be NaN" - assert jnp.isfinite(log_prob_at_one), "Beta(8,1).log_prob(1) should be finite" + assert not jnp.isnan(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN" + ) + assert jnp.isfinite(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite" + ) def test_beta_logprob_edge_case_consistency_small_values(): From 8e402b21e5c1a8ccd9f9683a7cfbe0715c4d2ca3 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 1 Nov 2025 20:17:57 +0100 Subject: [PATCH 4/6] another approach --- numpyro/distributions/continuous.py | 66 +++++++++++++++-------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index bfbed5fe8..1a3285147 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -197,54 +197,58 @@ def _beta_log_prob(value, concentration1, concentration0): @_beta_log_prob.defjvp def _beta_log_prob_jvp(primals, tangents): - """ - Define custom JVP (Jacobian-vector product) for Beta log_prob. - - For Beta(α, β), the derivatives are: - - d/dx log_prob = (α - 1) / x - (β - 1) / (1 - x) - - d/dα log_prob = log(x) - ψ(α) + ψ(α + β) - - d/dβ log_prob = log(1 - x) - ψ(β) + ψ(α + β) - - where ψ is the digamma function. Edge cases (α=1, x=0) or (β=1, x=1) are handled - by setting gradients to finite values using safe operations. - """ + """Custom JVP for Beta log_prob handling edge cases at boundaries.""" value, concentration1, concentration0 = primals value_dot, concentration1_dot, concentration0_dot = tangents - primal_out = _beta_log_prob(value, concentration1, concentration0) - # Gradient w.r.t. value - use safe division and set to 0 at edge cases - safe_value = jnp.where(value == 0.0, 1.0, value) + # Gradient w.r.t. value - safe division, zero at edge cases + safe_val = jnp.where(value == 0.0, 1.0, value) safe_one_minus = jnp.where(value == 1.0, 1.0, 1.0 - value) - grad_value = (concentration1 - 1.0) / safe_value - ( + grad_val = (concentration1 - 1.0) / safe_val - ( concentration0 - 1.0 ) / safe_one_minus - grad_value = jnp.where( + grad_val = jnp.where( ((value == 0.0) & (concentration1 == 1.0)) | ((value == 1.0) & (concentration0 == 1.0)), 0.0, - grad_value, + grad_val, ) - # Gradients w.r.t. concentration parameters - use safe log (0 instead of -inf) - digamma_sum = digamma(concentration1 + concentration0) - grad_concentration1 = ( - jnp.where(value == 0.0, 0.0, jnp.log(value)) - - digamma(concentration1) - + digamma_sum + # Gradients w.r.t. concentrations - safe log (0 instead of -inf) + dsum = digamma(concentration1 + concentration0) + grad_c1 = ( + jnp.where(value == 0.0, 0.0, jnp.log(value)) - digamma(concentration1) + dsum ) - grad_concentration0 = ( + grad_c0 = ( jnp.where(value == 1.0, 0.0, jnp.log(1.0 - value)) - digamma(concentration0) - + digamma_sum + + dsum ) - # Chain rule - tangent_out = ( - grad_value * value_dot - + grad_concentration1 * concentration1_dot - + grad_concentration0 * concentration0_dot - ) + # Build tangent output - handle Zero tangents properly + from jax.interpreters import ad + + def is_tangent_active(tangent): + """Check if tangent is active (not Zero or float0).""" + if isinstance(tangent, ad.Zero): + return False + # Check for float0 dtype (float0 has itemsize 0) + if ( + hasattr(tangent, "dtype") + and hasattr(tangent.dtype, "itemsize") + and tangent.dtype.itemsize == 0 + ): + return False + return True + + tangent_out = 0.0 + if is_tangent_active(value_dot): + tangent_out = tangent_out + grad_val * value_dot + if is_tangent_active(concentration1_dot): + tangent_out = tangent_out + grad_c1 * concentration1_dot + if is_tangent_active(concentration0_dot): + tangent_out = tangent_out + grad_c0 * concentration0_dot return primal_out, tangent_out From c965ed23bdcf3928a7264ae48d9e0a36f73249b2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sat, 1 Nov 2025 20:33:44 +0100 Subject: [PATCH 5/6] clean up merge --- test/test_distributions.py | 196 ++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 98 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index ac3e3a4c4..ace1839b0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4390,6 +4390,104 @@ def test_censored_sample_shape(): assert samples.shape == expected_shape +@pytest.mark.parametrize( + "left_censored, right_censored, lower, upper, censoring_type", + [ + # left censored examples + (1.0, 0.0, 0.001, 1.0, "left"), + (1.0, 0.0, 0.001, 0.001, "left"), + # right censored examples + (0.0, 1.0, 0.001, 1.0, "right"), + (0.0, 1.0, 0.001, 0.001, "right"), + # interval censored example + (0.0, 0.0, 0.001, 1.0, "interval"), + # # doubly censored example + (1.0, 1.0, 0.001, 1.0, "double"), + # exact example + (0.0, 0.0, 0.001, 0.001, "exact"), + ], +) +def test_interval_censored_masks( + left_censored, right_censored, lower, upper, censoring_type +): + base_dist = dist.HalfNormal() + censored_dist = dist.IntervalCensoredDistribution( + base_dist, + left_censored, + right_censored, + validate_args=True, + ) + value = jnp.array([[lower, upper]]) + m_left, m_right, m_interval, m_doubly, m_exact = censored_dist._get_censoring_masks( + value + ) + # assert that always exactly one mask is true + assert m_left + m_right + m_interval + m_doubly + m_exact == 1 + + if censoring_type == "left": + assert m_left + elif censoring_type == "right": + assert m_right + elif censoring_type == "interval": + assert m_interval + elif censoring_type == "double": + assert m_doubly + elif censoring_type == "exact": + assert m_exact + + +@pytest.mark.parametrize( + "left_censored, right_censored, lower, upper, should_raise", + [ + # left censored examples + (1.0, 0.0, 0.001, 1.0, False), + (1.0, 0.0, 0.001, -1.0, True), + (1.0, 0.0, -0.001, 1.0, False), + (1.0, 0.0, -jnp.inf, 1.0, False), + (1.0, 0.0, jnp.nan, 1.0, False), + # right censored examples + (0.0, 1.0, 0.001, 1.0, False), + (0.0, 1.0, 0.001, -1.0, False), + (0.0, 1.0, -1.0, 0.001, True), + (0.0, 1.0, 1.0, -jnp.inf, False), + (0.0, 1.0, 1.0, jnp.nan, False), + # interval, doubly, exact examples: both bounds valid + # interval censored examples + (0.0, 0.0, 0.001, 1.0, False), + (0.0, 0.0, -0.001, 1.0, True), + (0.0, 0.0, 0.001, -1.0, True), + # doubly censored examples + (1.0, 1.0, 0.001, 1.0, False), + (1.0, 1.0, -0.001, 1.0, True), + (1.0, 1.0, 0.001, -1.0, True), + # exact examples + (0.0, 0.0, 0.001, 0.001, False), + (0.0, 0.0, -0.001, -0.001, True), + # interval and doubly censored, upper should be >= lower + (0.0, 0.0, 0.001, 0.002, False), + (1.0, 1.0, 0.001, 0.002, False), + (0.0, 0.0, 0.002, 0.001, True), + (1.0, 1.0, 0.002, 0.001, True), + ], +) +def test_interval_censored_validate_sample( + left_censored, right_censored, lower, upper, should_raise +): + base_dist = dist.HalfNormal() + censored_dist = dist.IntervalCensoredDistribution( + base_dist, + left_censored, + right_censored, + validate_args=True, + ) + value = jnp.array([[lower, upper]]) + if should_raise: + with pytest.raises(UserWarning): + censored_dist.log_prob(value) + else: + censored_dist.log_prob(value) # Should not raise + + @pytest.mark.parametrize( argnames="concentration1,concentration0,value", argvalues=[ @@ -4516,101 +4614,3 @@ def log_prob_fn(params): f"All gradients for Beta({concentration1},{concentration0}) at x={value} " f"should be finite" ) - -@pytest.mark.parametrize( - "left_censored, right_censored, lower, upper, censoring_type", - [ - # left censored examples - (1.0, 0.0, 0.001, 1.0, "left"), - (1.0, 0.0, 0.001, 0.001, "left"), - # right censored examples - (0.0, 1.0, 0.001, 1.0, "right"), - (0.0, 1.0, 0.001, 0.001, "right"), - # interval censored example - (0.0, 0.0, 0.001, 1.0, "interval"), - # # doubly censored example - (1.0, 1.0, 0.001, 1.0, "double"), - # exact example - (0.0, 0.0, 0.001, 0.001, "exact"), - ], -) -def test_interval_censored_masks( - left_censored, right_censored, lower, upper, censoring_type -): - base_dist = dist.HalfNormal() - censored_dist = dist.IntervalCensoredDistribution( - base_dist, - left_censored, - right_censored, - validate_args=True, - ) - value = jnp.array([[lower, upper]]) - m_left, m_right, m_interval, m_doubly, m_exact = censored_dist._get_censoring_masks( - value - ) - # assert that always exactly one mask is true - assert m_left + m_right + m_interval + m_doubly + m_exact == 1 - - if censoring_type == "left": - assert m_left - elif censoring_type == "right": - assert m_right - elif censoring_type == "interval": - assert m_interval - elif censoring_type == "double": - assert m_doubly - elif censoring_type == "exact": - assert m_exact - - -@pytest.mark.parametrize( - "left_censored, right_censored, lower, upper, should_raise", - [ - # left censored examples - (1.0, 0.0, 0.001, 1.0, False), - (1.0, 0.0, 0.001, -1.0, True), - (1.0, 0.0, -0.001, 1.0, False), - (1.0, 0.0, -jnp.inf, 1.0, False), - (1.0, 0.0, jnp.nan, 1.0, False), - # right censored examples - (0.0, 1.0, 0.001, 1.0, False), - (0.0, 1.0, 0.001, -1.0, False), - (0.0, 1.0, -1.0, 0.001, True), - (0.0, 1.0, 1.0, -jnp.inf, False), - (0.0, 1.0, 1.0, jnp.nan, False), - # interval, doubly, exact examples: both bounds valid - # interval censored examples - (0.0, 0.0, 0.001, 1.0, False), - (0.0, 0.0, -0.001, 1.0, True), - (0.0, 0.0, 0.001, -1.0, True), - # doubly censored examples - (1.0, 1.0, 0.001, 1.0, False), - (1.0, 1.0, -0.001, 1.0, True), - (1.0, 1.0, 0.001, -1.0, True), - # exact examples - (0.0, 0.0, 0.001, 0.001, False), - (0.0, 0.0, -0.001, -0.001, True), - # interval and doubly censored, upper should be >= lower - (0.0, 0.0, 0.001, 0.002, False), - (1.0, 1.0, 0.001, 0.002, False), - (0.0, 0.0, 0.002, 0.001, True), - (1.0, 1.0, 0.002, 0.001, True), - ], -) -def test_interval_censored_validate_sample( - left_censored, right_censored, lower, upper, should_raise -): - base_dist = dist.HalfNormal() - censored_dist = dist.IntervalCensoredDistribution( - base_dist, - left_censored, - right_censored, - validate_args=True, - ) - value = jnp.array([[lower, upper]]) - if should_raise: - with pytest.raises(UserWarning): - censored_dist.log_prob(value) - else: - censored_dist.log_prob(value) # Should not raise - From 1a9eae1113e186f147292e7060ccbdf6846da41e Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 17 Nov 2025 12:03:42 +0100 Subject: [PATCH 6/6] simplyfy with double where trick --- numpyro/distributions/continuous.py | 112 +++++++++------------------- 1 file changed, 35 insertions(+), 77 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 1a3285147..e2b3fb0f9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -179,80 +179,6 @@ def icdf(self, value: ArrayLike) -> ArrayLike: ) -@jax.custom_jvp -def _beta_log_prob(value, concentration1, concentration0): - """ - Compute Beta log probability with custom gradients to handle edge cases. - - When concentration1=1 and value=0, or concentration0=1 and value=1, - the standard formula involves log(0) * 0 which should be 0, but has - undefined gradients. We use custom_jvp to define proper gradients. - """ - return ( - xlogy(concentration1 - 1.0, value) - + xlogy(concentration0 - 1.0, 1.0 - value) - - betaln(concentration1, concentration0) - ) - - -@_beta_log_prob.defjvp -def _beta_log_prob_jvp(primals, tangents): - """Custom JVP for Beta log_prob handling edge cases at boundaries.""" - value, concentration1, concentration0 = primals - value_dot, concentration1_dot, concentration0_dot = tangents - primal_out = _beta_log_prob(value, concentration1, concentration0) - - # Gradient w.r.t. value - safe division, zero at edge cases - safe_val = jnp.where(value == 0.0, 1.0, value) - safe_one_minus = jnp.where(value == 1.0, 1.0, 1.0 - value) - grad_val = (concentration1 - 1.0) / safe_val - ( - concentration0 - 1.0 - ) / safe_one_minus - grad_val = jnp.where( - ((value == 0.0) & (concentration1 == 1.0)) - | ((value == 1.0) & (concentration0 == 1.0)), - 0.0, - grad_val, - ) - - # Gradients w.r.t. concentrations - safe log (0 instead of -inf) - dsum = digamma(concentration1 + concentration0) - grad_c1 = ( - jnp.where(value == 0.0, 0.0, jnp.log(value)) - digamma(concentration1) + dsum - ) - grad_c0 = ( - jnp.where(value == 1.0, 0.0, jnp.log(1.0 - value)) - - digamma(concentration0) - + dsum - ) - - # Build tangent output - handle Zero tangents properly - from jax.interpreters import ad - - def is_tangent_active(tangent): - """Check if tangent is active (not Zero or float0).""" - if isinstance(tangent, ad.Zero): - return False - # Check for float0 dtype (float0 has itemsize 0) - if ( - hasattr(tangent, "dtype") - and hasattr(tangent.dtype, "itemsize") - and tangent.dtype.itemsize == 0 - ): - return False - return True - - tangent_out = 0.0 - if is_tangent_active(value_dot): - tangent_out = tangent_out + grad_val * value_dot - if is_tangent_active(concentration1_dot): - tangent_out = tangent_out + grad_c1 * concentration1_dot - if is_tangent_active(concentration0_dot): - tangent_out = tangent_out + grad_c0 * concentration0_dot - - return primal_out, tangent_out - - class Beta(Distribution): arg_constraints = { "concentration1": constraints.positive, @@ -290,9 +216,41 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - # Compute Beta log_prob directly using the formula with custom gradients - # to handle edge cases where concentration=1 and value is at boundary - return _beta_log_prob(value, self.concentration1, self.concentration0) + # Use double-where trick to avoid NaN gradients at boundary conditions + # when concentration parameters equal 1 (following TF Probability approach). + # Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf + # + # The key insight is to mask extreme values BEFORE computation, so gradients + # flow through the safe path. The forward pass automatically gets the right + # answer because xlogy(0, 0) = 0. + + # Step 1: Identify boundary values (0 or 1) + is_boundary = (value == 0.0) | (value == 1.0) + + # Step 2: Inner where - mask boundary values to safe canonical value (0.5) + # This ensures log(0) never appears in the gradient computation path + safe_value = jnp.where(is_boundary, 0.5, value) + + # Step 3: Compute log_prob with safe values (gradients flow through here) + safe_log_prob = ( + xlogy(self.concentration1 - 1.0, safe_value) + + xlogy(self.concentration0 - 1.0, 1.0 - safe_value) + - betaln(self.concentration1, self.concentration0) + ) + + # Step 4: Compute correct forward-pass value at boundaries + # Use stop_gradient to prevent gradients from flowing through this branch + # xlogy(0, 0) = 0 gives the correct value when concentration=1 at boundaries + boundary_log_prob = jax.lax.stop_gradient( + xlogy(self.concentration1 - 1.0, value) + + xlogy(self.concentration0 - 1.0, 1.0 - value) + - betaln(self.concentration1, self.concentration0) + ) + + # Step 5: Outer where - select boundary value at boundaries, safe value elsewhere + # Forward pass: uses boundary_log_prob at boundaries (correct value) + # Gradients: come from safe_log_prob (finite, since safe_value avoids log(0)) + return jnp.where(is_boundary, boundary_log_prob, safe_log_prob) @property def mean(self) -> ArrayLike: