Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions pufferlib/ocean/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
# Road features size (lanes + boundaries)
self.obs_slots_lane_kept = env.obs_slots_lane_kept
self.obs_slots_boundary_kept = env.obs_slots_boundary_kept
self.obs_slots_lane_n = env.obs_slots_lane_n
self.obs_slots_boundary_n = env.obs_slots_boundary_n
self.road_features_count = env.road_features
# Traffic control size
self.obs_slots_traffic_controls_n = env.obs_slots_traffic_controls_n
Expand Down Expand Up @@ -116,9 +118,18 @@ def __init__(

def forward(self, observations, ego_dim):
# Extract and slice observations from the flat buffer

if self.training:
obs_slots_lane_kept = self.obs_slots_lane_kept
obs_slots_boundary_kept = self.obs_slots_boundary_kept
else:
# During evaluation, enforce zero dropout (also in pufferlib/ocean/benchmark/manager.py)
obs_slots_lane_kept = self.obs_slots_lane_n
obs_slots_boundary_kept = self.obs_slots_boundary_n

partner_dim = self.obs_slots_partners_n * self.partner_features_count
lane_dim = self.obs_slots_lane_kept * self.road_features_count
boundary_dim = self.obs_slots_boundary_kept * self.road_features_count
lane_dim = obs_slots_lane_kept * self.road_features_count
boundary_dim = obs_slots_boundary_kept * self.road_features_count
traffic_control_dim = self.obs_slots_traffic_controls_n * self.traffic_control_features_count

slide_idx = ego_dim
Expand All @@ -144,12 +155,12 @@ def forward(self, observations, ego_dim):
feature_list = [ego_features]

# Encode Lanes and Boundaries separately
if self.obs_slots_lane_kept > 0:
lane_objects = lane_observations.view(-1, self.obs_slots_lane_kept, self.road_features_count)
if obs_slots_lane_kept > 0:
lane_objects = lane_observations.view(-1, obs_slots_lane_kept, self.road_features_count)
lane_features = self.lane_encoder(lane_objects).max(dim=1).values
feature_list.append(lane_features)
if self.obs_slots_boundary_kept > 0:
boundary_objects = boundary_observations.view(-1, self.obs_slots_boundary_kept, self.road_features_count)
if obs_slots_boundary_kept > 0:
boundary_objects = boundary_observations.view(-1, obs_slots_boundary_kept, self.road_features_count)
boundary_features = self.boundary_encoder(boundary_objects).max(dim=1).values
feature_list.append(boundary_features)

Expand Down Expand Up @@ -192,9 +203,15 @@ def forward(self, observations, ego_dim):
return self.backbone(concat_features)

def pool_slot_counts(self, observations, ego_dim):
if self.training:
obs_slots_lane_kept = self.obs_slots_lane_kept
obs_slots_boundary_kept = self.obs_slots_boundary_kept
else:
obs_slots_lane_kept = self.obs_slots_lane_n
obs_slots_boundary_kept = self.obs_slots_boundary_n
partner_dim = self.obs_slots_partners_n * self.partner_features_count
lane_dim = self.obs_slots_lane_kept * self.road_features_count
boundary_dim = self.obs_slots_boundary_kept * self.road_features_count
lane_dim = obs_slots_lane_kept * self.road_features_count
boundary_dim = obs_slots_boundary_kept * self.road_features_count
traffic_control_dim = self.obs_slots_traffic_controls_n * self.traffic_control_features_count

slide_idx = ego_dim + self.conditioning_dim
Expand All @@ -207,18 +224,18 @@ def pool_slot_counts(self, observations, ego_dim):
traffic_control_observations = observations[:, slide_idx : slide_idx + traffic_control_dim]

counts = {}
if self.obs_slots_lane_kept > 0:
lane_objects = lane_observations.view(-1, self.obs_slots_lane_kept, self.road_features_count)
if obs_slots_lane_kept > 0:
lane_objects = lane_observations.view(-1, obs_slots_lane_kept, self.road_features_count)
lane_winners = self.lane_encoder(lane_objects).max(dim=1).indices
lane_counts = torch.zeros(
observations.shape[0], self.obs_slots_lane_kept, device=observations.device, dtype=torch.int64
observations.shape[0], obs_slots_lane_kept, device=observations.device, dtype=torch.int64
)
counts["pool_lane"] = lane_counts.scatter_add(1, lane_winners, torch.ones_like(lane_winners))
if self.obs_slots_boundary_kept > 0:
boundary_objects = boundary_observations.view(-1, self.obs_slots_boundary_kept, self.road_features_count)
if obs_slots_boundary_kept > 0:
boundary_objects = boundary_observations.view(-1, obs_slots_boundary_kept, self.road_features_count)
boundary_winners = self.boundary_encoder(boundary_objects).max(dim=1).indices
boundary_counts = torch.zeros(
observations.shape[0], self.obs_slots_boundary_kept, device=observations.device, dtype=torch.int64
observations.shape[0], obs_slots_boundary_kept, device=observations.device, dtype=torch.int64
)
counts["pool_boundary"] = boundary_counts.scatter_add(
1, boundary_winners, torch.ones_like(boundary_winners)
Expand Down
Loading