WIP ROS transport #1034
Changes from all commits: 8910756, 27848b8, a9c9b08, a02ca85, 8bd70fe, 3881e4b, dc1f5d3
File 1 (pubsub transports):

@@ -28,6 +28,7 @@
 from dimos.core.stream import In, Out, Stream, Transport
 from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory
 from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic
+from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE, DimosROS, ROSTopic
 from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory

 if TYPE_CHECKING:
@@ -125,6 +126,54 @@ def stop(self) -> None:
         self.lcm.stop()

+class ROSTransport(PubSubTransport[T]):
+    """Transport that publishes/subscribes via ROS 2.
+
+    Automatically converts between dimos_lcm messages and ROS messages.
+
+    Usage:
+        from geometry_msgs.msg import Vector3 as ROSVector3
+
+        ROSTransport("/my_topic", ROSVector3)
+    """
+
+    _started: bool = False
+    _ros: DimosROS | None = None
+
+    def __init__(self, topic: str, ros_type: type, **kwargs) -> None:  # type: ignore[no-untyped-def]
+        if not ROS_AVAILABLE:
+            raise ImportError("ROS not available. Install rclpy to use ROSTransport.")
+        super().__init__(ROSTopic(topic, ros_type))
+        self._kwargs = kwargs
+
+    def _ensure_ros(self) -> DimosROS:
+        if self._ros is None:
+            self._ros = DimosROS(**self._kwargs)
+        return self._ros
+
+    def start(self) -> None:
+        if not self._started:
+            self._ensure_ros().start()
+            self._started = True
+
+    def stop(self) -> None:
+        if self._ros is not None:
+            self._ros.stop()
+
+    def __reduce__(self):  # type: ignore[no-untyped-def]
+        return (ROSTransport, (self.topic.topic, self.topic.ros_type))
+
+    def broadcast(self, _, msg) -> None:  # type: ignore[no-untyped-def]
+        """Publish a dimos_lcm message to ROS (auto-converts to ROS message)."""
+        self.start()
+        self._ensure_ros().publish(self.topic, msg)
+
+    def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None:  # type: ignore[assignment, override]
+        """Subscribe to ROS topic (auto-converts ROS messages to dimos_lcm)."""
+        self.start()
+        return self._ensure_ros().subscribe(self.topic, lambda msg, topic: callback(msg))  # type: ignore[return-value]
Review comment on lines +129 to +175:

logic: Each `ROSTransport` instance creates its own `DimosROS` node. Consider using a singleton pattern or a shared ROS node instance across all `ROSTransport` instances.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
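A minimal sketch of the suggested shared-node approach, assuming `DimosROS` can safely be shared by several transports in one process; the helper name and module-level globals below are illustrative, not part of the PR:

```python
import threading

_shared_ros = None  # process-wide DimosROS node, created lazily
_shared_ros_lock = threading.Lock()


def _ensure_shared_ros(**kwargs):
    """Return a single shared DimosROS node, creating and starting it on first use."""
    global _shared_ros
    with _shared_ros_lock:
        if _shared_ros is None:
            _shared_ros = DimosROS(**kwargs)
            _shared_ros.start()
        return _shared_ros
```

One caveat with this sketch: per-transport `**kwargs` would only take effect for whichever transport creates the node first, so a shared node probably needs its configuration decided in one place.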
 class pSHMTransport(PubSubTransport[T]):
     _started: bool = False
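For context, a short usage sketch of the new transport based on the docstring above; the `my_dimos_msg` variable and the callback body are placeholders, and rclpy plus geometry_msgs must be installed:

```python
from geometry_msgs.msg import Vector3 as ROSVector3

transport = ROSTransport("/my_topic", ROSVector3)

# broadcast() lazily starts the underlying DimosROS node, converts the dimos_lcm
# message to the ROS type, and publishes it on /my_topic.
transport.broadcast(None, my_dimos_msg)  # my_dimos_msg: placeholder dimos_lcm message

# subscribe() delivers incoming ROS messages to the callback, converted back to
# dimos_lcm types.
transport.subscribe(lambda msg: print("received", msg))

transport.stop()
```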
File 2 (new pubsub throughput benchmark test, 159 lines added):

@@ -0,0 +1,159 @@
#!/usr/bin/env python3

# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time

import pytest

from dimos.protocol.pubsub.benchmark.testdata import testdata
from dimos.protocol.pubsub.benchmark.type import BenchmarkResult, BenchmarkResults
# Message sizes for throughput benchmarking (from 64 B up to 10 MB)
MSG_SIZES = [
    64,
    256,
    1024,
    4096,
    16384,
    65536,
    262144,
    524288,
    1048576,
    1048576 * 2,
    1048576 * 5,
    1048576 * 10,
]

# Benchmark duration in seconds
BENCH_DURATION = 1.0

# Max messages to send per test (prevents overwhelming slower transports)
MAX_MESSAGES = 5000

# Max time (seconds) to wait for in-flight messages after publishing stops
RECEIVE_TIMEOUT = 1.0
def size_id(size: int) -> str:
    """Convert byte size to human-readable string for test IDs."""
    if size >= 1048576:
        return f"{size // 1048576}MB"
    if size >= 1024:
        return f"{size // 1024}KB"
    return f"{size}B"


def pubsub_id(testcase) -> str:
    """Extract pubsub implementation name from context manager function name."""
    name = testcase.pubsub_context.__name__
    # Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory"
    prefix = name.replace("_pubsub_channel", "").replace("_", " ")
    return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")
@pytest.fixture(scope="module")
def benchmark_results():
    """Module-scoped fixture to collect benchmark results."""
    results = BenchmarkResults()
    yield results
    results.print_summary()
    results.print_heatmap()
    results.print_bandwidth_heatmap()
    results.print_latency_heatmap()
@pytest.mark.benchmark
@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES])
@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata])
def test_throughput(pubsub_context, msggen, msg_size, benchmark_results):
    """Measure throughput for publishing and receiving messages over a fixed duration."""
    with pubsub_context() as pubsub:
        topic, msg = msggen(msg_size)
        received_count = 0
        target_count = [0]  # Use list to allow modification after publish loop
        lock = threading.Lock()
        all_received = threading.Event()

        def callback(message, _topic):
            nonlocal received_count
            with lock:
                received_count += 1
                if target_count[0] > 0 and received_count >= target_count[0]:
                    all_received.set()

        # Subscribe
        pubsub.subscribe(topic, callback)

        # Warmup: give DDS/ROS time to establish connection
        time.sleep(0.1)

        # Set target so callback can signal when all received
        target_count[0] = MAX_MESSAGES

        # Publish messages until time limit, max messages, or all received
        msgs_sent = 0
        start = time.perf_counter()
        end_time = start + BENCH_DURATION

        while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES:
            pubsub.publish(topic, msg)
            msgs_sent += 1
            # Check if all already received (fast transports)
            if all_received.is_set():
                break

        publish_end = time.perf_counter()
        target_count[0] = msgs_sent  # Update to actual sent count

        # Check if already done, otherwise wait up to RECEIVE_TIMEOUT
        with lock:
            if received_count >= msgs_sent:
                all_received.set()

        if not all_received.is_set():
            all_received.wait(timeout=RECEIVE_TIMEOUT)
        latency_end = time.perf_counter()

        with lock:
            final_received = received_count

        # Latency: how long we waited after publishing for messages to arrive
        # 0 = all arrived during publishing, 1000ms = hit timeout (loss occurred)
        latency = latency_end - publish_end

        # Record result (duration is publish time only for throughput calculation)
        transport_name = pubsub_id(type("TC", (), {"pubsub_context": pubsub_context})())
        result = BenchmarkResult(
            transport=transport_name,
            duration=publish_end - start,
            msgs_sent=msgs_sent,
            msgs_received=final_received,
            msg_size_bytes=msg_size,
            receive_time=latency,
        )
        benchmark_results.add(result)

        # Warn if significant message loss (but don't fail - benchmark records the data)
        loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0
        if loss_pct > 10:
            import warnings

            warnings.warn(
                f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss "
                f"({final_received}/{msgs_sent})",
                stacklevel=2,
            )
Review comment:

logic: Creating a temporary instance with `type()()` to call `_get_ros_type` assumes the LCM type has a no-argument constructor. This will fail for types that require constructor arguments. Consider modifying `_get_ros_type` to accept a type instead of an instance, or handle the case where instantiation fails.
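A hedged sketch of the suggested fix; the real `_get_ros_type` in rospubsub.py is not part of this diff, so the lookup below (`ROS_TYPE_REGISTRY`) is a stand-in for whatever mapping the existing implementation uses:

```python
def _get_ros_type(lcm_type_or_instance):
    """Resolve the ROS message type for an LCM type, accepting a type or an instance."""
    lcm_type = (
        lcm_type_or_instance
        if isinstance(lcm_type_or_instance, type)
        else type(lcm_type_or_instance)
    )
    # Key the lookup on the type object itself rather than constructing a throwaway
    # instance with type()(), so LCM types without a no-argument constructor still work.
    return ROS_TYPE_REGISTRY[lcm_type]  # ROS_TYPE_REGISTRY: hypothetical mapping
```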