
Conversation

@romerojosh
Collaborator

As a follow-up to #99, this PR makes additional device handling improvements in the TorchFort backend. Specifically, it adds:

  1. Usage of CUDAGuard objects within the supervised and RL functions to properly set and restore the current CUDA device to the expected model device. This does not fix any current issue in the implementation, but it better prepares the code for direct CUDA runtime call usage (e.g., CUDA graph capture/replay).
  2. Checks on the user-supplied CUDA stream to ensure it resides on the same device as the model.
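For context, the two items above can be sketched roughly as follows. This is a hypothetical illustration, not TorchFort's actual code: the function name `train_step` and its signature are invented, and the device-mismatch check is only indicated by a comment since the runtime API used for it is not specified in this PR.

```cpp
// Hypothetical sketch of the pattern described above (not TorchFort source).
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

void train_step(torch::nn::Module& model, torch::Device model_device,
                cudaStream_t user_stream) {
  // Item 1: RAII guard sets the current CUDA device to the model's device
  // on construction and restores the previous device on destruction, so
  // direct CUDA runtime calls (e.g., graph capture) target the right device.
  c10::cuda::CUDAGuard device_guard(model_device);

  // Item 2: wrap the external user stream for the model's device. The wrap
  // itself does not verify the stream's device, so a mismatch check (e.g.,
  // via the driver API cuStreamGetCtx + cuCtxGetDevice) would go here and
  // raise an error if the stream is not on model_device.
  auto stream =
      c10::cuda::getStreamFromExternal(user_stream, model_device.index());
  c10::cuda::CUDAStreamGuard stream_guard(stream);

  // ... forward/backward/optimizer work issued on `stream`, with the
  // current device guaranteed to match the model for the guard's scope ...
}
```

The RAII guards make the device/stream state exception-safe: whatever path exits the function, the caller's previous device and stream are restored.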

@romerojosh
Collaborator Author

/build_and_test

@github-actions

github-actions bot commented Dec 8, 2025

🚀 Build workflow triggered! View run

@github-actions

github-actions bot commented Dec 8, 2025

✅ Build workflow passed! View run

@romerojosh romerojosh requested a review from azrael417 December 8, 2025 19:01
Collaborator

@azrael417 azrael417 left a comment


LGTM. What are the remaining issues we need to take care of?

@romerojosh romerojosh merged commit 9c82a64 into master Jan 5, 2026
4 checks passed
@romerojosh romerojosh deleted the device_handling_improvements branch January 6, 2026 17:55


3 participants