feat(api): add TPU utility to compute numNodes for multi-slice TPU#3532
feat(api): add TPU utility to compute numNodes for multi-slice TPU#3532richabanker wants to merge 1 commit into
Conversation
Signed-off-by: Richa Banker <richabanker@google.com>
|
🎉 Welcome to the Kubeflow Trainer! 🎉 Thanks for opening your first PR! We're happy to have you as part of our community 🚀 Here's what happens next:
Join the community:
Feel free to ask questions in the comments if you need any help or clarification! |
|
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here. DetailsNeeds approval from an approver in each of these files:Approvers can indicate their approval by writing |
There was a problem hiding this comment.
Pull request overview
Adds a small TPU helper to the Python API package to compute the total host count (numNodes) for multi-slice TPU topologies, intended to simplify configuring multi-slice TPU TrainJobs (Issue #3407).
Changes:
- Added
get_num_nodes(num_slices, topology, chips_per_host=4)to compute total VM hosts across slices. - Added unit tests covering common 2D/3D TPU topologies and invalid inputs.
- Re-exported
get_num_nodesfromkubeflow_trainer_apipackage__init__.py.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
api/python_api/kubeflow_trainer_api/tpu.py |
Introduces TPU topology parsing + host count computation utility. |
api/python_api/kubeflow_trainer_api/tpu_test.py |
Adds unit tests for the new TPU utility function. |
api/python_api/kubeflow_trainer_api/__init__.py |
Exposes get_num_nodes as part of the package public surface. |
| if not topology: | ||
| raise ValueError("TPU topology must be specified.") | ||
|
|
||
| # Parse the topology dimensions (e.g. "2x2" or "2x2x2") | ||
| try: |
| dims = [int(d) for d in topology.lower().split("x")] | ||
| except ValueError: | ||
| raise ValueError( | ||
| f"Invalid topology format: '{topology}'. Must be formatted as 'AxB' or 'AxBxC' (e.g. '2x2', '2x2x2')." | ||
| ) |
| import unittest | ||
| from kubeflow_trainer_api.tpu import get_num_nodes | ||
|
|
||
| class TestTPUUtils(unittest.TestCase): |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| def get_num_nodes(num_slices: int, topology: str, chips_per_host: int = 4) -> int: |
There was a problem hiding this comment.
Thanks for this @richabanker!
I think, we should contribute this utility function to Kubeflow SDK: https://github.com/kubeflow/sdk/blob/main/kubeflow/trainer/backends/kubernetes/utils.py
Since kubeflow_trainer_api only exposes Python models for Trainer CRDs.
There was a problem hiding this comment.
Ah my bad, opened kubeflow/sdk#498 against the SDK repo.
Will close this one out. Thanks!
|
I agree with Andrey, this function should be in Kubeflow SDK. |
|
/close in favor of kubeflow/sdk#498 |
|
@richabanker: Closed this PR. DetailsIn response to this:
Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
What this PR does / why we need it:
Introduce a TPU utility function
get_num_nodesin the Python API package to calculate the total number of VM hosts (numNodes) for multi-slice TPU configurationsWhich issue(s) this PR fixes (optional, in
Fixes #<issue number>, #<issue number>, ...format, will close the issue(s) when PR gets merged):Issue ##3407
Checklist: