Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from .cart_prod.cart_prod_stacked_distribution import CartProdStackedDistribution
from .abstract_se3_lin_vel_distribution import AbstractSE3LinVelDistribution
from .hypersphere_subset.abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution
from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution
import numpy as np

class SE3LinVelCartProdStackedDistribution(CartProdStackedDistribution, AbstractSE3LinVelDistribution):
def __init__(self, dists):
assert len(dists) == 2, "There must be exactly 2 distributions in dists"
assert dists[0].dim == 4, "The first distribution must have 4 dimensions"
assert isinstance(dists[0], AbstractHyperhemisphericalDistribution), "The first distribution must be an instance of AbstractHyperhemisphericalDistribution"
assert dists[1].dim == 6, "The second distribution must have 6 dimensions"
assert isinstance(dists[1], AbstractLinearDistribution), "The second distribution must be an instance of AbstractLinearDistribution"

super().__init__(dists)
self.boundD = dists[0].dim
self.linD = dists[1].dim
self.periodicManifoldType = "hyperhemisphere"

def marginalize_linear(self):
return self.dists[0]

def marginalize_periodic(self):
return self.dists[1]

def get_manifold_size(self):
return np.inf

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest
import numpy as np
from pyrecest.distributions.cart_prod.se3_lin_vel_cart_prod_stacked_distribution import SE3LinVelCartProdStackedDistribution
from pyrecest.distributions import HyperhemisphericalUniformDistribution, GaussianDistribution, HyperhemisphericalWatsonDistribution


class TestSE3LinVelCartProdStackedDistribution(unittest.TestCase):

def test_constructor(self):
SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 3, 4, 5, 6]), np.diag([3, 2, 1, 4, 3, 4]))])

def test_sampling(self):
cpd = SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5]))])
samples = cpd.sample(100)
self.assertEqual(samples.shape, (10, 100))

def test_pdf(self):
cpd = SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5]))])
self.assertEqual(cpd.pdf(np.random.randn(10, 100)).shape, (1, 100))

pdf_values = cpd.pdf(np.ones((10, 100)))
self.assertTrue(np.allclose(np.diff(pdf_values), np.zeros(99)))

def test_mode(self):
watson = HyperhemisphericalWatsonDistribution(np.array([2, 1, 3, 1]) / np.linalg.norm(np.array([2, 1, 3, 1])), 2)
gaussian = GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5]))
cpd = SE3LinVelCartProdStackedDistribution([watson, gaussian])
self.assertTrue(np.allclose(cpd.mode(), np.hstack([watson.mode(), gaussian.mode()])))


if __name__ == "__main__":
unittest.main()
Loading