refactor: move _remap_outputs function to rfd3.py and remove from rf3.py#110
refactor: move _remap_outputs function to rfd3.py and remove from rf3.py#110Wolkenwandler wants to merge 1 commit intoRosettaCommons:productionfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR refactors the _remap_outputs helper function by moving it from rf3.py (where it was unused) to rfd3.py (where it is needed). The function is used to remap network outputs according to a mapping tensor for alignment with ground truth data.
- Removed unused
_remap_outputsfunction frommodels/rf3/src/rf3/trainers/rf3.py - Added
_remap_outputsfunction tomodels/rfd3/src/rfd3/trainer/rfd3.pyalong with requiredjaxtypingimports
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| models/rfd3/src/rfd3/trainer/rfd3.py | Added _remap_outputs function and imported Float and Int from jaxtyping to support the function's type annotations |
| models/rf3/src/rf3/trainers/rf3.py | Removed unused _remap_outputs function definition |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ) -> Float[torch.Tensor, "D L 3"]: | ||
| """Helper function to remap outputs using a mapping tensor.""" | ||
| for i in range(xyz.shape[0]): | ||
| xyz[i, mapping[i]] = xyz[i].clone() |
There was a problem hiding this comment.
The remapping logic is incorrect. The current implementation assigns the cloned original values to positions specified by the mapping, which doesn't perform the intended remapping operation. The line should be xyz[i] = xyz[i, mapping[i]] or xyz[i] = xyz[i].clone()[mapping[i]] to correctly reorder the elements according to the mapping indices. With the current implementation, if mapping contains indices [2, 0, 1], the function won't reorder the tensor elements correctly.
| xyz[i, mapping[i]] = xyz[i].clone() | |
| xyz[i] = xyz[i, mapping[i]] |
| def _remap_outputs( | ||
| xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"] | ||
| ) -> Float[torch.Tensor, "D L 3"]: | ||
| """Helper function to remap outputs using a mapping tensor.""" | ||
| for i in range(xyz.shape[0]): | ||
| xyz[i, mapping[i]] = xyz[i].clone() | ||
| return xyz |
There was a problem hiding this comment.
This function is duplicated from trainer_utils.py (lines 33-39) where an identical implementation already exists. Consider importing the function from trainer_utils instead to avoid code duplication and maintain a single source of truth. If there's a specific reason for the duplication, it should be documented.
There was a problem hiding this comment.
@copilot open a new pull request to apply changes based on this feedback
|
Was this causing issues? |
* refactor: delete files * fix: add back files that included imported functions in chiral code * chore: add back files to archive, update pyproject.yaml and gitignore
This function is missing in rfd3, but it exists in rf3 and is not used there.