32 changes: 25 additions & 7 deletions dimos/core/blueprints.py
@@ -30,7 +30,8 @@
from dimos.core.module import Module
from dimos.core.module_coordinator import ModuleCoordinator
from dimos.core.stream import In, Out
from dimos.core.transport import LCMTransport, pLCMTransport
from dimos.core.transport import LCMTransport, ROSTransport, pLCMTransport
from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE
from dimos.utils.generic import short_id
from dimos.utils.logging_config import setup_logger

@@ -124,14 +125,28 @@ def _check_ambiguity(
f"{modules_str}. Please use a concrete class name instead."
)

def _get_transport_for(self, name: str, type: type) -> Any:
def _get_transport_for(self, name: str, type: type, backend: str = "lcm") -> Any:
transport = self.transport_map.get((name, type), None)
if transport:
return transport

use_pickled = getattr(type, "lcm_encode", None) is None
topic = f"/{name}" if self._is_name_unique(name) else f"/{short_id()}"
transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, type)

if backend == "ros":
if not ROS_AVAILABLE:
raise ImportError(
"ROS transport requested but rclpy not available. "
"Install ROS 2 or set default_transport='lcm'."
)
# For ROS, we need a ROS message type. The type here is an LCM type,
# so ROSTransport will handle the conversion internally.
from dimos.protocol.pubsub.rospubsub import _get_ros_type

ros_type = _get_ros_type(type()) # Get ROS type from an instance
Review comment on lines +144 to +145:
logic: Creating a temporary instance with type() to call _get_ros_type assumes the LCM type has a no-argument constructor. This will fail for types that require constructor arguments.

Consider modifying _get_ros_type to accept a type instead of an instance, or handle the case where instantiation fails.
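
A minimal sketch of the second suggestion (handling instantiation failure); _get_ros_type is assumed to keep its current instance-based signature, and this helper is hypothetical, not part of the PR:

# Hypothetical helper sketching the reviewer's fallback suggestion. It assumes
# _get_ros_type keeps its current instance-based signature and simply surfaces
# a clearer error when the LCM type lacks a no-argument constructor.
def _resolve_ros_type(lcm_type: type) -> type:
    from dimos.protocol.pubsub.rospubsub import _get_ros_type

    try:
        instance = lcm_type()  # only works for no-argument constructors
    except TypeError as exc:
        raise TypeError(
            f"Cannot derive a ROS type for {lcm_type.__name__}: it has no "
            "no-argument constructor; register an explicit mapping instead."
        ) from exc
    return _get_ros_type(instance)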

transport = ROSTransport(topic, ros_type)
else:
use_pickled = getattr(type, "lcm_encode", None) is None
transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, type)

return transport

@@ -209,7 +224,9 @@ def _deploy_all_modules(
kwargs["global_config"] = global_config
module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs)

def _connect_transports(self, module_coordinator: ModuleCoordinator) -> None:
def _connect_transports(
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
) -> None:
# Gather all the In/Out connections with remapping applied.
connections = defaultdict(list)
# Track original name -> remapped name for each module
@@ -225,8 +242,9 @@ def _connect_transports(self, module_coordinator: ModuleCoordinator) -> None:
connections[remapped_name, conn.type].append((blueprint.module, conn.name))

# Connect all In/Out connections by remapped name and type.
backend = global_config.default_transport
for remapped_name, type in connections.keys():
transport = self._get_transport_for(remapped_name, type)
transport = self._get_transport_for(remapped_name, type, backend)
for module, original_name in connections[(remapped_name, type)]:
instance = module_coordinator.get_instance(module)
instance.set_transport(original_name, transport) # type: ignore[union-attr]
@@ -386,7 +404,7 @@ def build(
module_coordinator.start()

self._deploy_all_modules(module_coordinator, global_config)
self._connect_transports(module_coordinator)
self._connect_transports(module_coordinator, global_config)
self._connect_rpc_methods(module_coordinator)

module_coordinator.start_all_modules()
2 changes: 2 additions & 0 deletions dimos/core/global_config.py
@@ -21,6 +21,7 @@
from dimos.mapping.occupancy.path_map import NavigationStrategy

ViewerBackend: TypeAlias = Literal["rerun-web", "rerun-native", "foxglove"]
TransportBackend: TypeAlias = Literal["lcm", "ros"]


def _get_all_numbers(s: str) -> list[float]:
@@ -48,6 +49,7 @@ class GlobalConfig(BaseSettings):
robot_rotation_diameter: float = 0.6
planner_strategy: NavigationStrategy = "simple"
planner_robot_speed: float | None = None
default_transport: TransportBackend = "lcm"

model_config = SettingsConfigDict(
env_file=".env",
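
For context, a hypothetical way to opt into the new backend; this assumes the usual pydantic-settings field/env-variable mapping implied by the SettingsConfigDict above (no env prefix) and is not taken from the diff itself:

# Hypothetical usage sketch: selecting the ROS backend explicitly.
# Assumes GlobalConfig follows standard pydantic-settings behavior, so the
# field could also be set via DEFAULT_TRANSPORT=ros in the environment or .env.
from dimos.core.global_config import GlobalConfig

config = GlobalConfig(default_transport="ros")
assert config.default_transport == "ros"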
49 changes: 49 additions & 0 deletions dimos/core/transport.py
@@ -28,6 +28,7 @@
from dimos.core.stream import In, Out, Stream, Transport
from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory
from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic
from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE, DimosROS, ROSTopic
from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory

if TYPE_CHECKING:
@@ -125,6 +126,54 @@ def stop(self) -> None:
self.lcm.stop()


class ROSTransport(PubSubTransport[T]):
"""Transport that publishes/subscribes via ROS 2.

Automatically converts between dimos_lcm messages and ROS messages.

Usage:
from geometry_msgs.msg import Vector3 as ROSVector3

ROSTransport("/my_topic", ROSVector3)
"""

_started: bool = False
_ros: DimosROS | None = None

def __init__(self, topic: str, ros_type: type, **kwargs) -> None: # type: ignore[no-untyped-def]
if not ROS_AVAILABLE:
raise ImportError("ROS not available. Install rclpy to use ROSTransport.")
super().__init__(ROSTopic(topic, ros_type))
self._kwargs = kwargs

def _ensure_ros(self) -> DimosROS:
if self._ros is None:
self._ros = DimosROS(**self._kwargs)
return self._ros

def start(self) -> None:
if not self._started:
self._ensure_ros().start()
self._started = True

def stop(self) -> None:
if self._ros is not None:
self._ros.stop()

def __reduce__(self): # type: ignore[no-untyped-def]
return (ROSTransport, (self.topic.topic, self.topic.ros_type))

def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
"""Publish a dimos_lcm message to ROS (auto-converts to ROS message)."""
self.start()
self._ensure_ros().publish(self.topic, msg)

def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override]
"""Subscribe to ROS topic (auto-converts ROS messages to dimos_lcm)."""
self.start()
return self._ensure_ros().subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value]

Review comment on lines +129 to +175:
logic: Each ROSTransport instance creates its own DimosROS node. When multiple topics use ROS transport, this creates multiple ROS nodes, which is resource-intensive and can cause node name conflicts.

Consider using a singleton pattern or shared ROS node instance across all ROSTransport instances, similar to how LCM shares a single LCM instance.
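
A minimal sketch of that idea, assuming DimosROS (from dimos.protocol.pubsub.rospubsub) can serve multiple topics from one node; lifecycle details such as reference counting and shutdown are omitted, and this is not the PR's implementation:

# Hypothetical module-level cache sharing one DimosROS node across transports,
# analogous to how the LCM transports share a single LCM instance.
import threading

_shared_ros = None
_shared_ros_lock = threading.Lock()


def get_shared_ros(**kwargs):
    global _shared_ros
    with _shared_ros_lock:
        if _shared_ros is None:
            from dimos.protocol.pubsub.rospubsub import DimosROS

            _shared_ros = DimosROS(**kwargs)
            _shared_ros.start()
        return _shared_ros

ROSTransport._ensure_ros() could then delegate to this helper instead of constructing its own DimosROS.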



class pSHMTransport(PubSubTransport[T]):
_started: bool = False

159 changes: 159 additions & 0 deletions dimos/protocol/pubsub/benchmark/test_benchmark.py
@@ -0,0 +1,159 @@
#!/usr/bin/env python3

# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time

import pytest

from dimos.protocol.pubsub.benchmark.testdata import testdata
from dimos.protocol.pubsub.benchmark.type import BenchmarkResult, BenchmarkResults

# Message sizes for throughput benchmarking (from 64 B up to 10 MB)
MSG_SIZES = [
64,
256,
1024,
4096,
16384,
65536,
262144,
524288,
1048576,
1048576 * 2,
1048576 * 5,
1048576 * 10,
]

# Benchmark duration in seconds
BENCH_DURATION = 1.0

# Max messages to send per test (prevents overwhelming slower transports)
MAX_MESSAGES = 5000

# Max time to wait for in-flight messages after publishing stops
RECEIVE_TIMEOUT = 1.0


def size_id(size: int) -> str:
"""Convert byte size to human-readable string for test IDs."""
if size >= 1048576:
return f"{size // 1048576}MB"
if size >= 1024:
return f"{size // 1024}KB"
return f"{size}B"


def pubsub_id(testcase) -> str:
"""Extract pubsub implementation name from context manager function name."""
name = testcase.pubsub_context.__name__
# Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory"
prefix = name.replace("_pubsub_channel", "").replace("_", " ")
return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")


@pytest.fixture(scope="module")
def benchmark_results():
"""Module-scoped fixture to collect benchmark results."""
results = BenchmarkResults()
yield results
results.print_summary()
results.print_heatmap()
results.print_bandwidth_heatmap()
results.print_latency_heatmap()


@pytest.mark.benchmark
@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES])
@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata])
def test_throughput(pubsub_context, msggen, msg_size, benchmark_results):
"""Measure throughput for publishing and receiving messages over a fixed duration."""
with pubsub_context() as pubsub:
topic, msg = msggen(msg_size)
received_count = 0
target_count = [0] # Use list to allow modification after publish loop
lock = threading.Lock()
all_received = threading.Event()

def callback(message, _topic):
nonlocal received_count
with lock:
received_count += 1
if target_count[0] > 0 and received_count >= target_count[0]:
all_received.set()

# Subscribe
pubsub.subscribe(topic, callback)

# Warmup: give DDS/ROS time to establish connection
time.sleep(0.1)

# Set target so callback can signal when all received
target_count[0] = MAX_MESSAGES

# Publish messages until time limit, max messages, or all received
msgs_sent = 0
start = time.perf_counter()
end_time = start + BENCH_DURATION

while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES:
pubsub.publish(topic, msg)
msgs_sent += 1
# Check if all already received (fast transports)
if all_received.is_set():
break

publish_end = time.perf_counter()
target_count[0] = msgs_sent # Update to actual sent count

# Check if already done, otherwise wait up to RECEIVE_TIMEOUT
with lock:
if received_count >= msgs_sent:
all_received.set()

if not all_received.is_set():
all_received.wait(timeout=RECEIVE_TIMEOUT)
latency_end = time.perf_counter()

with lock:
final_received = received_count

# Latency: how long we waited after publishing for messages to arrive
# 0 = all arrived during publishing, 1000ms = hit timeout (loss occurred)
latency = latency_end - publish_end

# Record result (duration is publish time only for throughput calculation)
transport_name = pubsub_id(type("TC", (), {"pubsub_context": pubsub_context})())
result = BenchmarkResult(
transport=transport_name,
duration=publish_end - start,
msgs_sent=msgs_sent,
msgs_received=final_received,
msg_size_bytes=msg_size,
receive_time=latency,
)
benchmark_results.add(result)

# Warn if significant message loss (but don't fail - benchmark records the data)
loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0
if loss_pct > 10:
import warnings

warnings.warn(
f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss "
f"({final_received}/{msgs_sent})",
stacklevel=2,
)
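
As a rough guide to reading the recorded results, per-transport throughput and bandwidth can be derived from the fields captured above; the arithmetic below is illustrative only, since the actual reporting lives in the BenchmarkResults print helpers, which are not part of this file:

# Illustrative derivation of throughput and bandwidth from the recorded fields;
# not the PR's reporting code.
def derived_metrics(duration: float, msgs_received: int, msg_size_bytes: int) -> tuple[float, float]:
    msgs_per_sec = msgs_received / duration if duration > 0 else 0.0
    mb_per_sec = msgs_per_sec * msg_size_bytes / 1048576
    return msgs_per_sec, mb_per_sec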