Skip to content

Commit afeebab

Browse files
shantanuparab-tr authored and lukeschmitt-tr committed
Optimize Inference Loop by Moving Observation Capture Inside Condition Block (#3)
optimize get_observation call
1 parent c204078 commit afeebab

1 file changed

Lines changed: 23 additions & 23 deletions

File tree

examples/trossen_ai/main.py

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -152,32 +152,32 @@ def run_episode(self, task_prompt: str = "look down"):
152152

153153
while self.is_running and self.episode_step < self.max_steps:
154154
start_loop_time = time.perf_counter()
155-
observation_dict = self.robot.get_observation()
156-
157-
# Extract joint positions from observation
158-
joint_pos_keys = [k for k in observation_dict.keys() if k.endswith('.pos')]
159-
joint_positions = np.array([observation_dict[k] for k in joint_pos_keys])
160-
161-
162-
# Transform and resize images from all cameras
163-
cameras = list(self.robot._cameras_ft.keys())
164-
for cam in cameras:
165-
image_hwc = observation_dict[cam]
166-
#convert BGR to RGB
167-
image_resized = cv2.resize(image_hwc, (224, 224))
168-
image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
169-
image_chw = np.transpose(image_rgb, (2, 0, 1))
170-
observation_dict[cam] = image_chw
171-
172-
# Create observation for policy to follow the ALOHA format
173-
observation = {
174-
"state": joint_positions,
175-
"images": {cam: observation_dict[cam] for cam in cameras},
176-
"prompt": task_prompt
177-
}
178155

179156
# Request new action chunk after consuming the previous one
180157
if self.current_action_chunk is None or self.action_chunk_idx >= self.rate_of_inference:
158+
observation_dict = self.robot.get_observation()
159+
160+
# Extract joint positions from observation
161+
joint_pos_keys = [k for k in observation_dict.keys() if k.endswith('.pos')]
162+
joint_positions = np.array([observation_dict[k] for k in joint_pos_keys])
163+
164+
# Transform and resize images from all cameras
165+
cameras = list(self.robot._cameras_ft.keys())
166+
for cam in cameras:
167+
image_hwc = observation_dict[cam]
168+
#convert BGR to RGB
169+
image_resized = cv2.resize(image_hwc, (224, 224))
170+
image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
171+
image_chw = np.transpose(image_rgb, (2, 0, 1))
172+
observation_dict[cam] = image_chw
173+
174+
# Create observation for policy to follow the ALOHA format
175+
observation = {
176+
"state": joint_positions,
177+
"images": {cam: observation_dict[cam] for cam in cameras},
178+
"prompt": task_prompt
179+
}
180+
181181
logger.info(f"Step {self.episode_step}: Requesting new action chunk")
182182
response = self.policy_client.infer(observation)
183183
self.current_action_chunk = response["actions"]

0 commit comments

Comments (0)