From c75ebe6f7ca8edf2df86c507e99f81d437eeff63 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Sun, 17 Sep 2023 11:52:34 +0200 Subject: [PATCH] Added SE3LinVelCartProdStackedDistribution --- ..._lin_vel_cart_prod_stacked_distribution.py | 28 ++++++++++++++++ ..._lin_vel_cart_prod_stacked_distribution.py | 32 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 pyrecest/distributions/cart_prod/se3_lin_vel_cart_prod_stacked_distribution.py create mode 100644 pyrecest/tests/distributions/test_se3_lin_vel_cart_prod_stacked_distribution.py diff --git a/pyrecest/distributions/cart_prod/se3_lin_vel_cart_prod_stacked_distribution.py b/pyrecest/distributions/cart_prod/se3_lin_vel_cart_prod_stacked_distribution.py new file mode 100644 index 000000000..36ae4fb5a --- /dev/null +++ b/pyrecest/distributions/cart_prod/se3_lin_vel_cart_prod_stacked_distribution.py @@ -0,0 +1,28 @@ +from .cart_prod.cart_prod_stacked_distribution import CartProdStackedDistribution +from .abstract_se3_lin_vel_distribution import AbstractSE3LinVelDistribution +from .hypersphere_subset.abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution +from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution +import numpy as np + +class SE3LinVelCartProdStackedDistribution(CartProdStackedDistribution, AbstractSE3LinVelDistribution): + def __init__(self, dists): + assert len(dists) == 2, "There must be exactly 2 distributions in dists" + assert dists[0].dim == 4, "The first distribution must have 4 dimensions" + assert isinstance(dists[0], AbstractHyperhemisphericalDistribution), "The first distribution must be an instance of AbstractHyperhemisphericalDistribution" + assert dists[1].dim == 6, "The second distribution must have 6 dimensions" + assert isinstance(dists[1], AbstractLinearDistribution), "The second distribution must be an instance of AbstractLinearDistribution" + + super().__init__(dists) + self.boundD = dists[0].dim + self.linD = dists[1].dim + self.periodicManifoldType = "hyperhemisphere" + + def marginalize_linear(self): + return self.dists[0] + + def marginalize_periodic(self): + return self.dists[1] + + def get_manifold_size(self): + return np.inf + diff --git a/pyrecest/tests/distributions/test_se3_lin_vel_cart_prod_stacked_distribution.py b/pyrecest/tests/distributions/test_se3_lin_vel_cart_prod_stacked_distribution.py new file mode 100644 index 000000000..8aea4b9b3 --- /dev/null +++ b/pyrecest/tests/distributions/test_se3_lin_vel_cart_prod_stacked_distribution.py @@ -0,0 +1,32 @@ +import unittest +import numpy as np +from pyrecest.distributions.cart_prod.se3_lin_vel_cart_prod_stacked_distribution import SE3LinVelCartProdStackedDistribution +from pyrecest.distributions import HyperhemisphericalUniformDistribution, GaussianDistribution, HyperhemisphericalWatsonDistribution + + +class TestSE3LinVelCartProdStackedDistribution(unittest.TestCase): + + def test_constructor(self): + SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 3, 4, 5, 6]), np.diag([3, 2, 1, 4, 3, 4]))]) + + def test_sampling(self): + cpd = SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5]))]) + samples = cpd.sample(100) + self.assertEqual(samples.shape, (10, 100)) + + def test_pdf(self): + cpd = SE3LinVelCartProdStackedDistribution([HyperhemisphericalUniformDistribution(4), GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5]))]) + self.assertEqual(cpd.pdf(np.random.randn(10, 100)).shape, (1, 100)) + + pdf_values = cpd.pdf(np.ones((10, 100))) + self.assertTrue(np.allclose(np.diff(pdf_values), np.zeros(99))) + + def test_mode(self): + watson = HyperhemisphericalWatsonDistribution(np.array([2, 1, 3, 1]) / np.linalg.norm(np.array([2, 1, 3, 1])), 2) + gaussian = GaussianDistribution(np.array([1, 2, 0, -2, -1, 3]), np.diag([3, 2, 3, 3, 4, 5])) + cpd = SE3LinVelCartProdStackedDistribution([watson, gaussian]) + self.assertTrue(np.allclose(cpd.mode(), np.hstack([watson.mode(), gaussian.mode()]))) + + +if __name__ == "__main__": + unittest.main()