2 changes: 1 addition & 1 deletion .pylintrc
@@ -63,7 +63,7 @@ ignore-patterns=^\.#
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=pyrecest.backend
ignored-modules=pyrecest.backend, jax

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
3 changes: 3 additions & 0 deletions pyrecest/_backend/__init__.py
@@ -246,6 +246,9 @@ def get_backend_name():
"randint",
"seed",
"uniform",
# For PyRecEst
"get_state",
"set_state",
],
"fft": [ # For PyRecEst
"rfft",
15 changes: 8 additions & 7 deletions pyrecest/_backend/jax/random.py
@@ -25,12 +25,13 @@ def create_random_state(seed = 0):
def global_random_state():
return backend.jax_global_random_state


def set_global_random_state(state):
backend.jax_global_random_state = state

get_state = global_random_state
set_state = set_global_random_state

def get_state(**kwargs):
def _get_state(**kwargs):
has_state = 'state' in kwargs
state = kwargs.pop('state', backend.jax_global_random_state)
return state, has_state, kwargs
@@ -51,7 +52,7 @@ def _rand(state, size, *args, **kwargs):

def rand(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _rand(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)

@@ -66,7 +67,7 @@ def _randint(state, size, *args, **kwargs):

def randint(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _randint(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)

@@ -78,7 +79,7 @@ def _normal(state, size, *args, **kwargs):

def normal(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)

# Check and remove 'mean' and 'cov' from kwargs
mean = kwargs.pop('mean', None)
@@ -102,7 +103,7 @@ def _choice(state, a, n, *args, **kwargs):


def choice(a, n, *args, **kwargs):
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _choice(state, a, n, *args, **kwargs)
return set_state_return(has_state, state, res)

@@ -114,7 +115,7 @@ def _multivariate_normal(state, size, *args, **kwargs):

def multivariate_normal(size, *args, **kwargs):
size = size if hasattr(size, "__iter__") else (size,)
state, has_state, kwargs = get_state(**kwargs)
state, has_state, kwargs = _get_state(**kwargs)
state, res = _multivariate_normal(state, size, *args, **kwargs)
return set_state_return(has_state, state, res)

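The aliases above expose the global JAX PRNG key through the shared random API. A minimal usage sketch, assuming the JAX backend is active (variable names are illustrative):

import jax
from pyrecest.backend import random

key = random.get_state()             # current global PRNGKey
key, subkey = jax.random.split(key)  # split before using it
random.set_state(key)                # advance the global state for later calls
noise = jax.random.normal(subkey, shape=(3,))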
1 change: 1 addition & 0 deletions pyrecest/_backend/numpy/random.py
@@ -3,5 +3,6 @@
import numpy as _np
from numpy.random import default_rng as _default_rng
from numpy.random import randint, seed, multinomial
from numpy.random import set_state, get_state # For PyRecEst

from .._shared_numpy.random import choice, multivariate_normal, normal, rand, uniform
2 changes: 2 additions & 0 deletions pyrecest/_backend/pytorch/random.py
@@ -2,6 +2,8 @@

import torch as _torch
from torch import rand, randint
from torch import get_rng_state as get_state # For PyRecEst
from torch import set_rng_state as set_state # For PyRecEst
from torch.distributions.multivariate_normal import (
MultivariateNormal as _MultivariateNormal,
)
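A minimal reproducibility sketch, assuming the PyTorch backend is selected; here get_state/set_state are the torch.get_rng_state/torch.set_rng_state aliases introduced above:

import torch
from pyrecest.backend import random

state = random.get_state()      # snapshot of the global torch RNG state
a = random.rand((2,))
random.set_state(state)         # restore the snapshot
b = random.rand((2,))
assert torch.equal(a, b)        # the same draws are reproduced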
157 changes: 143 additions & 14 deletions pyrecest/distributions/abstract_manifold_specific_distribution.py
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Union
import inspect

import pyrecest.backend

# pylint: disable=no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import empty, int32, int64, log, random, squeeze


@@ -64,13 +65,14 @@ def set_mode(self, _):
"""
raise NotImplementedError("set_mode is not implemented for this distribution")

# Need to use Union instead of | to support torch.dtype
def sample(self, n: Union[int, int32, int64]):
"""Obtain n samples from the distribution."""
return self.sample_metropolis_hastings(n)

# jscpd:ignore-start
# pylint: disable=too-many-positional-arguments
# pylint: disable=too-many-positional-arguments,too-many-locals
def sample_metropolis_hastings(
self,
n: Union[int, int32, int64],
@@ -81,30 +83,48 @@
):
# jscpd:ignore-end
"""Metropolis Hastings sampling algorithm."""
assert (
pyrecest.backend.__backend_name__ != "jax"
), "Not supported on this backend"
if pyrecest.backend.__backend_name__ == "jax":
# Obtain a key from the global JAX random state *outside* of lax.scan
import jax as _jax

key = random.get_state()
key, key_for_mh = _jax.random.split(key)
# Optionally update global state for future calls
random.set_state(key)

if proposal is None or start_point is None:
raise NotImplementedError(
"Default proposals and starting points should be set in inheriting classes."
)
_assert_proposal_supports_key(proposal)

samples, _ = sample_metropolis_hastings_jax(
key=key_for_mh,
log_pdf=self.ln_pdf,
proposal=proposal, # must be (key, x) -> x_prop for JAX
start_point=start_point,
n=int(n),
burn_in=int(burn_in),
skipping=int(skipping),
)
# The advanced key returned by the sampler is discarded here; storing it (e.g. via random.set_state) would allow continuing the chain.
return squeeze(samples)

# Non-JAX backends: original NumPy/PyTorch implementation
if proposal is None or start_point is None:
raise NotImplementedError(
"Default proposals and starting points should be set in inheriting classes."
)

total_samples = burn_in + n * skipping
s = empty(
(
total_samples,
self.input_dim,
),
)
s = empty((total_samples, self.input_dim))
x = start_point
i = 0
pdfx = self.pdf(x)

while i < total_samples:
x_new = proposal(x)
assert (
x_new.shape == x.shape
), "Proposal must return a vector of same shape as input"
assert x_new.shape == x.shape, "Proposal must return a vector of same shape as input"
pdfx_new = self.pdf(x_new)
a = pdfx_new / pdfx
if a.item() > 1 or a.item() > random.rand(1):
@@ -115,3 +135,112 @@

relevant_samples = s[burn_in::skipping, :]
return squeeze(relevant_samples)

# pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments
def sample_metropolis_hastings_jax(
key,
log_pdf, # function: x -> log p(x)
proposal, # function: (key, x) -> x_prop
start_point,
n: int,
burn_in: int = 10,
skipping: int = 5,
):
"""
Metropolis-Hastings sampler in JAX.

key: jax.random.PRNGKey
log_pdf: callable x -> log p(x)
proposal: callable (key, x) -> x_proposed
start_point: initial state (array)
n: number of samples to return (after burn-in and thinning)
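burn_in: number of initial samples to discard
skipping: keep every `skipping`-th sample after burn-in (thinning)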
"""
import jax.numpy as _jnp
from jax import lax as _lax
from jax import random as _random


start_point = _jnp.asarray(start_point)
total_steps = burn_in + n * skipping

def one_step(carry, _):
key, x, log_px = carry
key, key_prop, key_u = _random.split(key, 3)

# Propose new state
x_prop = proposal(key_prop, x)
log_px_prop = log_pdf(x_prop)

# log_alpha = log p(x_prop) - log p(x)
log_alpha = log_px_prop - log_px

# Draw u ~ Uniform(0, 1)
u = _random.uniform(key_u, shape=())
log_u = _jnp.log(u)

# Accept if log u < min(0, log_alpha)
# (equivalent to u < exp(min(0, log_alpha)))
log_alpha_capped = _jnp.minimum(0.0, log_alpha)
accept = log_u < log_alpha_capped # scalar bool

# Branch without Python if
x_new = _jnp.where(accept, x_prop, x)
log_px_new = _jnp.where(accept, log_px_prop, log_px)

return (key, x_new, log_px_new), x_new

init_carry = (key, start_point, log_pdf(start_point))
(key_out, _, _), chain = _lax.scan(
one_step,
init_carry,
xs=None,
length=total_steps,
)

samples = chain[burn_in::skipping]
return samples, key_out


def _assert_proposal_supports_key(proposal: Callable):
"""
Check that `proposal` can be called as proposal(key, x).

Raises a TypeError with a helpful message if this is not the case.
"""
# Unwrap jitted / partial / decorated functions if possible
func = proposal
while hasattr(func, "__wrapped__"):
func = func.__wrapped__

try:
sig = inspect.signature(func)
except (TypeError, ValueError):
# Can't introspect (e.g. builtins); fall back to a generic error
raise TypeError(
"For the JAX backend, `proposal` must accept (key, x) as arguments, "
"but its signature could not be inspected."
) from None

params = list(sig.parameters.values())

# Count positional(-or-keyword) parameters
num_positional = sum(
p.kind in (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD)
for p in params
)
has_var_positional = any(
p.kind == inspect.Parameter.VAR_POSITIONAL
for p in params
)

if has_var_positional or num_positional >= 2:
# Looks compatible with (key, x)
return

raise TypeError(
"For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n"
f"Got signature: {sig}\n"
"Hint: change your proposal from `def proposal(x): ...` to\n"
"`def proposal(key, x): ...` and use `jax.random` with the passed key."
)
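A minimal usage sketch of sample_metropolis_hastings_jax with a symmetric Gaussian random-walk proposal; the target density and step size below are illustrative only:

import jax
import jax.numpy as jnp

def log_pdf(x):
    return -0.5 * jnp.sum(x ** 2)   # standard normal target, up to a constant

def proposal(key, x):
    return x + 0.5 * jax.random.normal(key, shape=x.shape)

samples, key_out = sample_metropolis_hastings_jax(
    key=jax.random.PRNGKey(0),
    log_pdf=log_pdf,
    proposal=proposal,
    start_point=jnp.zeros(2),
    n=100,
    burn_in=10,
    skipping=5,
)
# samples.shape == (100, 2); key_out can seed a continuation of the chain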
@@ -19,6 +19,11 @@
random,
vstack,
zeros,
sqrt,
cos,
sin,
linalg,
stack,
)
from scipy.optimize import minimize

@@ -60,9 +65,44 @@ def sample_metropolis_hastings(
from .hyperhemispherical_uniform_distribution import (
HyperhemisphericalUniformDistribution,
)
if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"):
def proposal(_):
return HyperhemisphericalUniformDistribution(self.dim).sample(1)
else:
# JAX backend: proposal(key, x) -> x_prop
import jax as _jax
import jax.numpy as _jnp

def proposal(key, _):
"""JAX independence proposal: uniform on upper hemisphere."""
if self.dim == 2:
# Explicit S² sampling
key, key_phi = _jax.random.split(key)
key, key_sz = _jax.random.split(key)

phi = 2.0 * _jnp.pi * _jax.random.uniform(key_phi, shape=(1,))
sz = 2.0 * _jax.random.uniform(key_sz, shape=(1,)) - 1.0
r = _jnp.sqrt(1.0 - sz**2)

# Shape (1, 3)
s = _jnp.stack(
[r * _jnp.cos(phi), r * _jnp.sin(phi), sz],
axis=1,
)
else:
# General S^d: sample N(0, I) in R^{d+1} and normalize
key, subkey = _jax.random.split(key)
samples_unnorm = _jax.random.normal(subkey, shape=(1, self.dim + 1))
norms = _jnp.linalg.norm(samples_unnorm, axis=1, keepdims=True)
s = samples_unnorm / norms

# Project to upper hemisphere: last coordinate >= 0
# s shape: (1, dim+1); last coord is s[..., -1:]
sign = _jnp.where(s[..., -1:] < 0.0, -1.0, 1.0)
s = sign * s

return s

def proposal(_):
return HyperhemisphericalUniformDistribution(self.dim).sample(1)

if start_point is None:
start_point = HyperhemisphericalUniformDistribution(self.dim).sample(1)
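For reference, a standalone sketch of the general-dimension branch of the JAX proposal above: draw a Gaussian vector in R^{d+1}, normalize it to land uniformly on S^d, and flip the sign whenever the last coordinate is negative so the point lies on the upper hemisphere (the name hemisphere_proposal is illustrative):

import jax
import jax.numpy as jnp

def hemisphere_proposal(key, _x, dim=2):
    # Independence proposal: the current state `_x` is deliberately ignored.
    v = jax.random.normal(key, shape=(1, dim + 1))
    s = v / jnp.linalg.norm(v, axis=1, keepdims=True)   # uniform on S^d
    sign = jnp.where(s[..., -1:] < 0.0, -1.0, 1.0)
    return sign * s                                     # last coordinate >= 0

sample = hemisphere_proposal(jax.random.PRNGKey(0), None)
# sample.shape == (1, 3) and sample[0, -1] >= 0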