Skip to content

Commit ebd836e

Browse files
Add power_spectral_density to RatQuad covariance kernel for HSGP support (#7973)
* Add power_spectral_density to RatQuad covariance kernel for HSGP support * Added derivation info to docstring * Add test for covariance eigenvalue equivalence * Update pymc/gp/cov.py Co-authored-by: Jesse Grabowski <[email protected]> * Better eigenvalue varname Co-authored-by: Jesse Grabowski <[email protected]> * Replace eigenvalue decomp with Rayleigh quotients * Remove test_psd from ratquad --------- Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 87f80f9 commit ebd836e

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

pymc/gp/cov.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,49 @@ def full_from_distance(self, dist: TensorLike, squared: bool = False) -> TensorV
614614
-1.0 * self.alpha,
615615
)
616616

617+
def power_spectral_density(self, omega: TensorLike) -> TensorVariable:
618+
r"""
619+
Power spectral density for the Rational Quadratic kernel.
620+
621+
.. math::
622+
S(\boldsymbol\omega) = \frac{2 (2\pi\alpha)^{D/2} \prod_{i=1}^D \ell_i}{\Gamma(\alpha)}
623+
\left(\frac{z}{2}\right)^{\nu}
624+
K_{\nu}(z)
625+
where :math:`z = \sqrt{2\alpha} \sqrt{\sum \ell_i^2 \omega_i^2}` and :math:`\nu = \alpha - D/2`.
626+
627+
Derivation
628+
----------
629+
The Rational Quadratic kernel can be expressed as a scale mixture of Squared Exponential kernels:
630+
631+
.. math::
632+
k_{RQ}(r) = \int_0^\infty k_{SE}(r; \lambda) p(\lambda) d\lambda
633+
634+
where :math:`k_{SE}(r; \lambda) = \exp\left(-\frac{\lambda r^2}{2}\right)` and the mixing distribution
635+
on the precision parameter :math:`\lambda` is :math:`\lambda \sim \text{Gamma}(\alpha, \beta)`
636+
with rate parameter :math:`\beta = \alpha \ell^2`.
637+
638+
By the linearity of the Fourier transform, the PSD of the Rational Quadratic kernel is the expectation
639+
of the PSD of the Squared Exponential kernel with respect to the mixing distribution:
640+
641+
.. math::
642+
S_{RQ}(\omega) = \int_0^\infty S_{SE}(\omega; \lambda) p(\lambda) d\lambda
643+
644+
Substituting the known PSD of the Squared Exponential kernel and evaluating the integral yields
645+
the expression involving the modified Bessel function of the second kind, :math:`K_{\nu}(z)`.
646+
"""
647+
ls = pt.ones(self.n_dims) * self.ls
648+
alpha = self.alpha
649+
D = self.n_dims
650+
nu = alpha - D / 2.0
651+
652+
z = pt.sqrt(2 * alpha) * pt.sqrt(pt.dot(pt.square(omega), pt.square(ls)))
653+
coeff = 2.0 * pt.power(2.0 * np.pi * alpha, D / 2.0) * pt.prod(ls) / pt.gamma(alpha)
654+
655+
# Handle singularity at z=0
656+
term_z = pt.switch(pt.eq(z, 0), pt.gamma(nu) / 2.0, pt.power(z / 2.0, nu) * pt.kv(nu, z))
657+
658+
return coeff * term_z
659+
617660

618661
class Matern52(Stationary):
619662
r"""

tests/gp/test_cov.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,35 @@ def test_1d(self):
533533
Kd = cov(X, diag=True).eval()
534534
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
535535

536+
def test_psd_eigenvalues(self):
537+
"""Test PSD implementation using Rayleigh quotients."""
538+
alpha = 1.5
539+
ls = 0.1
540+
N = 1000
541+
L = 10.0
542+
dx = L / N
543+
X = np.linspace(0, L, N)[:, None]
544+
545+
with pm.Model():
546+
cov = pm.gp.cov.RatQuad(1, alpha=alpha, ls=ls)
547+
548+
K = cov(X).eval()
549+
550+
freqs = np.fft.fftfreq(N, d=dx)
551+
omegas = 2 * np.pi * freqs
552+
553+
j = np.arange(N)
554+
modes = np.exp(2j * np.pi * np.outer(np.arange(N), j) / N)
555+
numerator = np.diag(modes @ K @ modes.conj().T).real
556+
rayleigh_quotient = numerator / N
557+
558+
psd = cov.power_spectral_density(omegas[:, None]).eval()
559+
psd_scaled = psd.flatten() / dx
560+
561+
# Trim boundaries where numerical error concentrates
562+
trim = N // 10
563+
npt.assert_allclose(psd_scaled[trim:-trim], rayleigh_quotient[trim:-trim], atol=1e-2)
564+
536565

537566
class TestExponential:
538567
def test_1d(self):

0 commit comments

Comments
 (0)