From 6e0898a907be01b483348f1cb78a531779485b9f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Thu, 25 Sep 2025 18:00:31 +0000 Subject: [PATCH 01/63] Update entrypoint for jaii Test maxdiffusion workload on gpu image --- .github/workflows/UnitTests.yml | 99 ++++++++++++++++-------- maxdiffusion_jax_ai_image_tpu.Dockerfile | 2 +- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2c588b43..f8544c47 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -22,43 +22,74 @@ on: push: branches: [ "main" ] workflow_dispatch: - schedule: - # Run the job every 12 hours - - cron: '0 */12 * * *' jobs: - build: - strategy: - fail-fast: false - matrix: - tpu-type: ["v5p-8"] - name: "TPU test (${{ matrix.tpu-type }})" - runs-on: ["self-hosted","${{ matrix.tpu-type }}"] + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + maxdiffusion_workload: + name: "Run MaxDiffusion Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.12 - uses: actions/setup-python@v5 - with: - python-version: '3.12' - - name: Install dependencies - run: | - pip install -e . - pip uninstall jax jaxlib libtpu-nightly libtpu -y - bash setup.sh MODE=stable - export PATH=$PATH:$HOME/.local/bin - pip install ruff - pip install isort - pip install pytest - - name: Analysing the code with ruff - run: | - ruff check . - - name: version check - run: | - python --version - pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - - name: PyTest - run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Run MaxDiffusion Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \ + TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \ + JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \ + pip install . 
&& \ + python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ + pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ + revision=refs/pr/95 \ + activations_dtype=bfloat16 \ + weights_dtype=bfloat16 \ + dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \ + resolution=1024 \ + per_device_batch_size=1 \ + jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \ + max_train_steps=20 \ + attention=flash \ + enable_profiler=True \ + run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \ + output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }} + +# jobs: +# build: +# strategy: +# fail-fast: false +# matrix: +# tpu-type: ["v5p-8"] +# name: "TPU test (${{ matrix.tpu-type }})" +# runs-on: ["self-hosted","${{ matrix.tpu-type }}"] +# steps: +# - uses: actions/checkout@v4 +# - name: Set up Python 3.12 +# uses: actions/setup-python@v5 +# with: +# python-version: '3.12' +# - name: Install dependencies +# run: | +# pip install -e . +# pip uninstall jax jaxlib libtpu-nightly libtpu -y +# bash setup.sh MODE=stable +# export PATH=$PATH:$HOME/.local/bin +# pip install ruff +# pip install isort +# pip install pytest +# - name: Analysing the code with ruff +# run: | +# ruff check . +# - name: version check +# run: | +# python --version +# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets +# - name: PyTest +# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py +# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile index cab50fee..301f9b88 100644 --- a/maxdiffusion_jax_ai_image_tpu.Dockerfile +++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile @@ -19,4 +19,4 @@ COPY . . 
RUN pip install -r /deps/requirements_with_jax_ai_image.txt # Run the script available in JAX-AI-Image base image to generate the manifest file -RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file +RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file From c50ab1b036752a4980aa80d05df8d60bb3a70af4 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 15:52:17 +0000 Subject: [PATCH 02/63] remove self-hosted tag --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index f8544c47..a216fd35 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -28,7 +28,7 @@ jobs: maxdiffusion_workload: name: "Run MaxDiffusion Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"] + runs-on: ["linux-x86-a2-48-a100-4gpu"] container: image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: From 1f22cf53f0795c8f113316f90c94b3bdf744859c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 15:53:26 +0000 Subject: [PATCH 03/63] wrong tag name --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a216fd35..bd5cd29e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -28,7 +28,7 @@ jobs: maxdiffusion_workload: name: "Run MaxDiffusion Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a2-48-a100-4gpu"] + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: From 2deaa7e9fdc0e0e83fed085acb8bb1d04d6261ad Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:09:46 +0000 Subject: [PATCH 04/63] Update with command for gpu --- .github/workflows/UnitTests.yml | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index bd5cd29e..d8695c33 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,24 +38,20 @@ jobs: - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \ - TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \ - JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \ - pip install . && \ - python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ - pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ - revision=refs/pr/95 \ + NVTE_FUSED_ATTN=1 pip install . 
&& \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + hardware=gpu \ + train_new_unet=true \ + train_text_encoder=false \ + cache_latents_text_encoder_outputs=true \ + per_device_batch_size=1 \ + attention=cudnn_flash_te \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ - dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \ - resolution=1024 \ - per_device_batch_size=1 \ - jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \ - max_train_steps=20 \ - attention=flash \ + max_train_steps=200 \ enable_profiler=True \ - run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \ - output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }} + run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }} # jobs: # build: From d47a6ebeff3c70aa10fbe07eafea030b075e9da6 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:28:34 +0000 Subject: [PATCH 05/63] remove pip install . --- .github/workflows/UnitTests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index d8695c33..20fad808 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,8 +38,7 @@ jobs: - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 pip install . && \ - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ hardware=gpu \ train_new_unet=true \ train_text_encoder=false \ From a6579ecd68ff5ade6b8be0f9ac73f1da8e196344 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:30:01 +0000 Subject: [PATCH 06/63] Check environment --- .github/workflows/UnitTests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 20fad808..5704c986 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -35,10 +35,15 @@ jobs: - name: Checkout Repository uses: actions/checkout@v4 + - name: Print dependencies + run: | + pip freeze + - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + NVTE_FUSED_ATTN=1 pip install . 
&& \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ hardware=gpu \ train_new_unet=true \ train_text_encoder=false \ From cc5d322f2f8ebdc3f9008f8ceab0681eb0adff0f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:41:14 +0000 Subject: [PATCH 07/63] Update to right image --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 5704c986..95e261d3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -30,7 +30,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 26a03c0bf1cbcbbf6100136031de4783bfd3b41b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:47:49 +0000 Subject: [PATCH 08/63] point to allowed bucket --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 95e261d3..7d36ca48 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -55,7 +55,7 @@ jobs: max_train_steps=200 \ enable_profiler=True \ run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }} + output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 9236f6fcfd6ec2911253dead691daf239f88a7fe Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:25:38 +0000 Subject: [PATCH 09/63] Install entire TE --- .github/workflows/UnitTests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 7d36ca48..86326037 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,6 +37,8 @@ jobs: - name: Print dependencies run: | + pip uninstall -y transformer-engine + pip install transformer-engine pip freeze - name: Run MaxDiffusion Training From 4e0dd7f1df2a1ed8f3b3e87c4bb1c875e8235be9 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:30:24 +0000 Subject: [PATCH 10/63] uninstall all TE deps --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 86326037..31a17048 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,7 +37,7 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine + pip uninstall -y transformer-engine transformer-engine-jax pip install transformer-engine pip freeze From 08883164629b931409a5d6d9fd8f2e4c5b1ee306 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:34:52 +0000 Subject: [PATCH 11/63] install TE jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 31a17048..57475c91 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y 
transformer-engine transformer-engine-jax - pip install transformer-engine + pip install transformer-engine transformer-engine-jax pip freeze - name: Run MaxDiffusion Training From 09c5d6a04212970db02dd7eafde6d6c1f0bb91a4 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:35:39 +0000 Subject: [PATCH 12/63] fix pip instal typo --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 57475c91..bb3c0a22 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install transformer-engine transformer-engine-jax + pip install transformer-engine transformer_engine[jax] pip freeze - name: Run MaxDiffusion Training From 846e1768e7d512861553b78c59c1c1cac484161b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:43:00 +0000 Subject: [PATCH 13/63] Install TE for pytorch and jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index bb3c0a22..4c721dec 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install transformer-engine transformer_engine[jax] + pip install -U transformer-engine[pytorch,jax] pip freeze - name: Run MaxDiffusion Training From 496f67cfb2d06772a2cbe7aa398836306ecbf27a Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:02:16 +0000 Subject: [PATCH 14/63] Comment out tflop calc --- .github/workflows/UnitTests.yml | 4 ++-- src/maxdiffusion/trainers/base_stable_diffusion_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 4c721dec..4eb3f862 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,8 +37,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[pytorch,jax] + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[pytorch,jax] pip freeze - name: Run MaxDiffusion Training diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index a9f17adc..e889d816 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -161,8 +161,8 @@ def start_training(self): params["scheduler"] = noise_scheduler_state # Calculate tflops - per_device_tflops = self.calculate_tflops(pipeline, params) - self.per_device_tflops = per_device_tflops + # per_device_tflops = self.calculate_tflops(pipeline, params) + # self.per_device_tflops = per_device_tflops # Load dataset data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) From 37524bba47363e300110a9da4f036de1df378e4f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:09:40 +0000 Subject: [PATCH 15/63] Test with dot_product attention --- .github/workflows/UnitTests.yml | 2 +- src/maxdiffusion/trainers/base_stable_diffusion_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) 
diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 4eb3f862..2f16eb7f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -51,7 +51,7 @@ jobs: train_text_encoder=false \ cache_latents_text_encoder_outputs=true \ per_device_batch_size=1 \ - attention=cudnn_flash_te \ + attention=dot_product \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ max_train_steps=200 \ diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index e889d816..a9f17adc 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -161,8 +161,8 @@ def start_training(self): params["scheduler"] = noise_scheduler_state # Calculate tflops - # per_device_tflops = self.calculate_tflops(pipeline, params) - # self.per_device_tflops = per_device_tflops + per_device_tflops = self.calculate_tflops(pipeline, params) + self.per_device_tflops = per_device_tflops # Load dataset data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) From 64f30f36bbc427ffab039a9ef4dfa56429d0f121 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:45:42 +0000 Subject: [PATCH 16/63] Test if maxtext has same gpu issue --- .github/workflows/UnitTests.yml | 78 ++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2f16eb7f..2980841e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,40 +24,66 @@ on: workflow_dispatch: jobs: - # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD - maxdiffusion_workload: - name: "Run MaxDiffusion Workload" + maxtext_workload: + name: "Run MaxText Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest steps: - - name: Checkout Repository + - name: Checkout MaxText Repo uses: actions/checkout@v4 + with: + repository: AI-Hypercomputer/maxtext + path: maxtext - - name: Print dependencies - run: | - # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[pytorch,jax] - pip freeze - - - name: Run MaxDiffusion Training + - name: Run MaxText Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 pip install . && \ - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - hardware=gpu \ - train_new_unet=true \ - train_text_encoder=false \ - cache_latents_text_encoder_outputs=true \ - per_device_batch_size=1 \ - attention=dot_product \ - activations_dtype=bfloat16 \ - weights_dtype=bfloat16 \ - max_train_steps=200 \ - enable_profiler=True \ - run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + cd maxtext && \ + pip install -e . 
--no-dependencies \ + XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \ + python3 -m MaxText.train MaxText/configs/base.yml \ + steps=2 \ + enable_checkpointing=false \ + attention=dot_product \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + + # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + # maxdiffusion_workload: + # name: "Run MaxDiffusion Workload" + # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + # container: + # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + # steps: + # - name: Checkout Repository + # uses: actions/checkout@v4 + + # - name: Print dependencies + # run: | + # # pip uninstall -y transformer-engine transformer-engine-jax + # # pip install -U transformer-engine[pytorch,jax] + # pip freeze + + # - name: Run MaxDiffusion Training + # run: | + # # This command is adapted from your DAG for a single-slice configuration. + # NVTE_FUSED_ATTN=1 pip install . && \ + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + # hardware=gpu \ + # train_new_unet=true \ + # train_text_encoder=false \ + # cache_latents_text_encoder_outputs=true \ + # per_device_batch_size=1 \ + # attention=dot_product \ + # activations_dtype=bfloat16 \ + # weights_dtype=bfloat16 \ + # max_train_steps=200 \ + # enable_profiler=True \ + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 53d1d22711f9e5a9bde3695d660aa079021c156b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:48:27 +0000 Subject: [PATCH 17/63] change to pip install maxtext package --- .github/workflows/UnitTests.yml | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2980841e..b60c251b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,14 +41,17 @@ jobs: run: | # This command is adapted from your DAG for a single-slice configuration. cd maxtext && \ - pip install -e . --no-dependencies \ - XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \ + pip install . 
+ + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + export TF_FORCE_GPU_ALLOW_GROWTH=true + python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=dot_product \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + steps=2 \ + enable_checkpointing=false \ + attention=dot_product \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD # maxdiffusion_workload: From fdcec4c018efd261a5c1866ded85c76f2fdb1c48 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:52:28 +0000 Subject: [PATCH 18/63] Install with no dependencies --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index b60c251b..c1033069 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: run: | # This command is adapted from your DAG for a single-slice configuration. cd maxtext && \ - pip install . + pip install . --no-dependencies export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 export TF_FORCE_GPU_ALLOW_GROWTH=true From cd9daaf19d7c46c246a09ec6cc16d5cec3312cf1 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:59:16 +0000 Subject: [PATCH 19/63] Use custom maxtext branch --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c1033069..686704b7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -36,6 +36,7 @@ jobs: with: repository: AI-Hypercomputer/maxtext path: maxtext + ref: rbierneni-test-gpu-run - name: Run MaxText Training run: | From e0615dcf61746e0d010b224f0570e10035955bd9 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 19:04:45 +0000 Subject: [PATCH 20/63] use synthetic data --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 686704b7..61e7c637 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -51,6 +51,7 @@ jobs: steps=2 \ enable_checkpointing=false \ attention=dot_product \ + dataset_type=synthetic \ run_name=rbierneni-test-maxtext-gpu \ base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} From 323654a7d18329696923cf90ec55a6e87663c218 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 19:10:37 +0000 Subject: [PATCH 21/63] Try with TE flash --- .github/workflows/UnitTests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 61e7c637..49c22a75 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,6 +37,10 @@ jobs: repository: AI-Hypercomputer/maxtext path: maxtext ref: rbierneni-test-gpu-run + + - name: Print dependencies + run: | + pip freeze - name: Run MaxText Training run: | @@ -50,7 +54,7 @@ jobs: python3 -m MaxText.train MaxText/configs/base.yml \ steps=2 \ enable_checkpointing=false \ - attention=dot_product \ + attention=cudnn_flash_te \ dataset_type=synthetic \ run_name=rbierneni-test-maxtext-gpu \ base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} From 
2313d541dd1705beec1b22b794b721dd95f489bd Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:07:21 +0000 Subject: [PATCH 22/63] Try with older TE --- .github/workflows/UnitTests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 49c22a75..ebf85225 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,6 +40,8 @@ jobs: - name: Print dependencies run: | + pip uninstall -y transformer-engine transformer-engine-jax + pip install -U transformer-engine[jax]=0.2.4 pip freeze - name: Run MaxText Training From 2eafa9329ec0d8e23ae878644fd2b7745f0cbdda Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:15:01 +0000 Subject: [PATCH 23/63] Typo in TE install --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index ebf85225..cb283a7e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]=0.2.4 + pip install -U transformer-engine[jax]==2.4.0 pip freeze - name: Run MaxText Training From a1da601045d60dd274df2c468f28f9e8dc60080e Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:16:11 +0000 Subject: [PATCH 24/63] Use TE 2.5.0 --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index cb283a7e..dfd71fa5 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.4.0 + pip install -U transformer-engine[jax]==2.5.0 pip freeze - name: Run MaxText Training From dfe8fc5bff63df3224712f2471a4d1f517ccabf3 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:23:40 +0000 Subject: [PATCH 25/63] Use TE 2.6.0 --- .github/workflows/UnitTests.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index dfd71fa5..2209d61c 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.5.0 + pip install -U transformer-engine[jax]==2.6.0 pip freeze - name: Run MaxText Training @@ -81,20 +81,20 @@ jobs: # - name: Run MaxDiffusion Training # run: | # # This command is adapted from your DAG for a single-slice configuration. - # NVTE_FUSED_ATTN=1 pip install . && \ - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - # hardware=gpu \ - # train_new_unet=true \ - # train_text_encoder=false \ - # cache_latents_text_encoder_outputs=true \ - # per_device_batch_size=1 \ - # attention=dot_product \ - # activations_dtype=bfloat16 \ - # weights_dtype=bfloat16 \ - # max_train_steps=200 \ - # enable_profiler=True \ - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + # NVTE_FUSED_ATTN=1 pip install . 
&& \ + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + # hardware=gpu \ + # train_new_unet=true \ + # train_text_encoder=false \ + # cache_latents_text_encoder_outputs=true \ + # per_device_batch_size=1 \ + # attention=dot_product \ + # activations_dtype=bfloat16 \ + # weights_dtype=bfloat16 \ + # max_train_steps=200 \ + # enable_profiler=True \ + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 4243d537ce26820f5671c70e84fe00765035bb8c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:32:46 +0000 Subject: [PATCH 26/63] Test with tensorflow-cpu --- .github/workflows/UnitTests.yml | 124 ++++++++++++++++---------------- 1 file changed, 64 insertions(+), 60 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2209d61c..99a07e2f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,77 +24,81 @@ on: workflow_dispatch: jobs: - maxtext_workload: - name: "Run MaxText Workload" - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - steps: - - name: Checkout MaxText Repo - uses: actions/checkout@v4 - with: - repository: AI-Hypercomputer/maxtext - path: maxtext - ref: rbierneni-test-gpu-run - - - name: Print dependencies - run: | - pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.6.0 - pip freeze - - - name: Run MaxText Training - run: | - # This command is adapted from your DAG for a single-slice configuration. - cd maxtext && \ - pip install . --no-dependencies - - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - export TF_FORCE_GPU_ALLOW_GROWTH=true - - python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=cudnn_flash_te \ - dataset_type=synthetic \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} - - # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD - # maxdiffusion_workload: - # name: "Run MaxDiffusion Workload" + # maxtext_workload: + # name: "Run MaxText Workload" # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] # container: - # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest # steps: - # - name: Checkout Repository + # - name: Checkout MaxText Repo # uses: actions/checkout@v4 - + # with: + # repository: AI-Hypercomputer/maxtext + # path: maxtext + # ref: rbierneni-test-gpu-run + # - name: Print dependencies # run: | - # # pip uninstall -y transformer-engine transformer-engine-jax - # # pip install -U transformer-engine[pytorch,jax] + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu # pip freeze - # - name: Run MaxDiffusion Training + # - name: Run MaxText Training # run: | # # This command is adapted from your DAG for a single-slice configuration. - # NVTE_FUSED_ATTN=1 pip install . 
&& \ - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - # hardware=gpu \ - # train_new_unet=true \ - # train_text_encoder=false \ - # cache_latents_text_encoder_outputs=true \ - # per_device_batch_size=1 \ - # attention=dot_product \ - # activations_dtype=bfloat16 \ - # weights_dtype=bfloat16 \ - # max_train_steps=200 \ - # enable_profiler=True \ - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + # cd maxtext && \ + # pip install . --no-dependencies + + # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + # export TF_FORCE_GPU_ALLOW_GROWTH=true + + # python3 -m MaxText.train MaxText/configs/base.yml \ + # steps=2 \ + # enable_checkpointing=false \ + # attention=cudnn_flash_te \ + # dataset_type=synthetic \ + # run_name=rbierneni-test-maxtext-gpu \ + # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + maxdiffusion_workload: + name: "Run MaxDiffusion Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Print dependencies + run: | + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[pytorch,jax] + pip uninstall -y tensorflow + pip install tensorflow-cpu + pip freeze + + - name: Run MaxDiffusion Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + NVTE_FUSED_ATTN=1 pip install . && \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + hardware=gpu \ + train_new_unet=true \ + train_text_encoder=false \ + cache_latents_text_encoder_outputs=true \ + per_device_batch_size=1 \ + attention=dot_product \ + activations_dtype=bfloat16 \ + weights_dtype=bfloat16 \ + max_train_steps=200 \ + enable_profiler=True \ + run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From d37e0dd775ed74d39c8b66f253df967ce0cea64b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 00:22:37 +0000 Subject: [PATCH 27/63] Test with new cuda13 images and TE 2.6.0 --- .github/workflows/UnitTests.yml | 78 ++++++++++++++++----------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 99a07e2f..b5c0f0c3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,44 +24,44 @@ on: workflow_dispatch: jobs: - # maxtext_workload: - # name: "Run MaxText Workload" - # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - # container: - # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - # steps: - # - name: Checkout MaxText Repo - # uses: actions/checkout@v4 - # with: - # repository: AI-Hypercomputer/maxtext - # path: maxtext - # ref: rbierneni-test-gpu-run + maxtext_workload: + name: "Run MaxText Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest 
+ steps: + - name: Checkout MaxText Repo + uses: actions/checkout@v4 + with: + repository: AI-Hypercomputer/maxtext + path: maxtext + ref: rbierneni-test-gpu-run - # - name: Print dependencies - # run: | - # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[jax]==2.6.0 - # pip uninstall -y tensorflow - # pip install tensorflow-cpu - # pip freeze + - name: Print dependencies + run: | + pip uninstall -y transformer-engine transformer-engine-jax + pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu + pip freeze - # - name: Run MaxText Training - # run: | - # # This command is adapted from your DAG for a single-slice configuration. - # cd maxtext && \ - # pip install . --no-dependencies + - name: Run MaxText Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + cd maxtext && \ + pip install . --no-dependencies - # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - # export TF_FORCE_GPU_ALLOW_GROWTH=true + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + export TF_FORCE_GPU_ALLOW_GROWTH=true - # python3 -m MaxText.train MaxText/configs/base.yml \ - # steps=2 \ - # enable_checkpointing=false \ - # attention=cudnn_flash_te \ - # dataset_type=synthetic \ - # run_name=rbierneni-test-maxtext-gpu \ - # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + python3 -m MaxText.train MaxText/configs/base.yml \ + steps=2 \ + enable_checkpointing=false \ + attention=cudnn_flash_te \ + dataset_type=synthetic \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD maxdiffusion_workload: @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -77,9 +77,9 @@ jobs: - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[pytorch,jax] - pip uninstall -y tensorflow - pip install tensorflow-cpu + pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu pip freeze - name: Run MaxDiffusion Training @@ -92,7 +92,7 @@ jobs: train_text_encoder=false \ cache_latents_text_encoder_outputs=true \ per_device_batch_size=1 \ - attention=dot_product \ + attention=cudnn_flash_te \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ max_train_steps=200 \ From c247bee21d3e22756afe03f332cb9d733e5f551e Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:19:55 +0000 Subject: [PATCH 28/63] Update to right image --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index b5c0f0c3..37df849b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu + image: 
gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev3_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 67141daed96c6ecb0a6424434a7af24d9d8ef2e3 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:29:23 +0000 Subject: [PATCH 29/63] uninstall cuda12 TE as well --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 37df849b..59e4a723 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,7 +40,7 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y tensorflow # pip install tensorflow-cpu @@ -76,7 +76,7 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y tensorflow # pip install tensorflow-cpu From 4ac11bb6bd82915f402d8ded35702a3a08c88fa8 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:49:52 +0000 Subject: [PATCH 30/63] Test with te-cu13 package --- .github/workflows/UnitTests.yml | 70 +++++++++++++++++---------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 59e4a723..a43497ba 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,44 +24,44 @@ on: workflow_dispatch: jobs: - maxtext_workload: - name: "Run MaxText Workload" - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - steps: - - name: Checkout MaxText Repo - uses: actions/checkout@v4 - with: - repository: AI-Hypercomputer/maxtext - path: maxtext - ref: rbierneni-test-gpu-run + # maxtext_workload: + # name: "Run MaxText Workload" + # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + # container: + # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest + # steps: + # - name: Checkout MaxText Repo + # uses: actions/checkout@v4 + # with: + # repository: AI-Hypercomputer/maxtext + # path: maxtext + # ref: rbierneni-test-gpu-run - - name: Print dependencies - run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 - # pip uninstall -y tensorflow - # pip install tensorflow-cpu - pip freeze + # - name: Print dependencies + # run: | + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install -U transformer-engine[jax]==2.6.0 + # # pip uninstall -y tensorflow + # # pip install tensorflow-cpu + # pip freeze - - name: Run MaxText Training - run: | - # This command is adapted from your DAG for a single-slice configuration. - cd maxtext && \ - pip install . --no-dependencies + # - name: Run MaxText Training + # run: | + # # This command is adapted from your DAG for a single-slice configuration. + # cd maxtext && \ + # pip install . 
--no-dependencies - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - export TF_FORCE_GPU_ALLOW_GROWTH=true + # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + # export TF_FORCE_GPU_ALLOW_GROWTH=true - python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=cudnn_flash_te \ - dataset_type=synthetic \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + # python3 -m MaxText.train MaxText/configs/base.yml \ + # steps=2 \ + # enable_checkpointing=false \ + # attention=cudnn_flash_te \ + # dataset_type=synthetic \ + # run_name=rbierneni-test-maxtext-gpu \ + # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD maxdiffusion_workload: @@ -78,6 +78,8 @@ jobs: run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 + pip uninstall -y transformer-engine-cu12 + pip install transformer-engine-cu13 # pip uninstall -y tensorflow # pip install tensorflow-cpu pip freeze From 994d8dcb482aefd0a59b0b629fe2b9a3df3a18f0 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:01:37 +0000 Subject: [PATCH 31/63] try with base TE --- .github/workflows/UnitTests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a43497ba..50977447 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -77,9 +77,9 @@ jobs: - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 - pip uninstall -y transformer-engine-cu12 - pip install transformer-engine-cu13 + pip install -U transformer-engine==2.6.0 + # pip uninstall -y transformer-engine-cu12 + # pip install transformer-engine-cu13 # pip uninstall -y tensorflow # pip install tensorflow-cpu pip freeze From 6700b5364ae3488759561c3f5365b45f2d355ee7 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:02:36 +0000 Subject: [PATCH 32/63] uninstall existing TE --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 50977447..fb32c9a2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -76,7 +76,7 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From 23d8ca20426b5930d937a2b0009c391accc19148 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:18:32 +0000 Subject: [PATCH 33/63] test te 2.6.0 jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index fb32c9a2..97c5de55 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -77,7 +77,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine==2.6.0 + pip install -U transformer-engine[jax]==2.6.0 # pip 
uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 # pip uninstall -y tensorflow From 0885a52ac7865490f7b02c3a3723d0d51998b001 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:01:59 +0000 Subject: [PATCH 34/63] test te 2.6.0 jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 97c5de55..ce5d4be2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev3_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev4_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 22cfa352df04e0d2c28d8581a46e3da013db8d9c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:12:21 +0000 Subject: [PATCH 35/63] remove steps for uninstalling TE --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index ce5d4be2..69fa4f16 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -76,8 +76,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 # pip uninstall -y tensorflow From ab478a1f258cc49c45fc3399f17d1b4fd8b22cc1 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:29:17 +0000 Subject: [PATCH 36/63] check host gpu driver --- .github/workflows/UnitTests.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 69fa4f16..14ddf4db 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -74,6 +74,19 @@ jobs: - name: Checkout Repository uses: actions/checkout@v4 + - name: Check Host CUDA and GPU Environment + run: | + echo "--- Checking NVIDIA driver and supported CUDA version ---" + nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected." 
+ + echo "" + echo "--- Checking for default CUDA toolkit installation ---" + ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/" + + echo "" + echo "--- Checking dynamic linker library path ---" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}" + - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 From 7ff2c235db5a165c2960aab08e162a838f3ec280 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 07:32:16 +0000 Subject: [PATCH 37/63] try with cuda 12 on TE cu12 --- .github/workflows/UnitTests.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 14ddf4db..6e92be77 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev4_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda12_tecu12_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -89,7 +89,8 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip install transformer_engine[jax]==2.6.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ef2041e86993fa0bd03801d9011cb7a368308779 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 07:56:48 +0000 Subject: [PATCH 38/63] try with last tested TE version --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6e92be77..d83bbf01 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -90,7 +90,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install transformer_engine[jax]==2.6.0 + pip install transformer_engine[jax]==2.4.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ec29e4e055c5de2f13bd96a95800e0450e30e662 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 08:38:18 +0000 Subject: [PATCH 39/63] try with no-cache build image --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index d83bbf01..a53b892a 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -89,8 +89,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install transformer_engine[jax]==2.4.0 + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install transformer_engine[jax]==2.4.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ecfacad5bcb02c84ac07c16af22db0feaf39b328 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 08:46:39 +0000 Subject: [PATCH 
40/63] update image tag to prevent gke caching --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a53b892a..22f4d801 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda12_tecu12_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0 steps: - name: Checkout Repository uses: actions/checkout@v4 From af23583580ec639e1163fccd41258bb972603ff2 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 11:12:42 -0700 Subject: [PATCH 41/63] [ DONT MERGE ] Testing Image --- .github/workflows/UnitTests.yml | 145 ++++++-------------------------- 1 file changed, 24 insertions(+), 121 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 22f4d801..6add1b70 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -15,7 +15,9 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Unit Test +# This workflow will run a small FLUX training workload on a GPU runner. + +name: FLUX Workload Training on GPU on: pull_request: @@ -24,135 +26,36 @@ on: workflow_dispatch: jobs: - # maxtext_workload: - # name: "Run MaxText Workload" - # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - # container: - # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - # steps: - # - name: Checkout MaxText Repo - # uses: actions/checkout@v4 - # with: - # repository: AI-Hypercomputer/maxtext - # path: maxtext - # ref: rbierneni-test-gpu-run - - # - name: Print dependencies - # run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - # pip install -U transformer-engine[jax]==2.6.0 - # # pip uninstall -y tensorflow - # # pip install tensorflow-cpu - # pip freeze - - # - name: Run MaxText Training - # run: | - # # This command is adapted from your DAG for a single-slice configuration. - # cd maxtext && \ - # pip install . 
--no-dependencies - - # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - # export TF_FORCE_GPU_ALLOW_GROWTH=true - - # python3 -m MaxText.train MaxText/configs/base.yml \ - # steps=2 \ - # enable_checkpointing=false \ - # attention=cudnn_flash_te \ - # dataset_type=synthetic \ - # run_name=rbierneni-test-maxtext-gpu \ - # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} - - # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD - maxdiffusion_workload: - name: "Run MaxDiffusion Workload" - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + flux_training_workload: + name: "Run FLUX Training Workload" + # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0 + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest + steps: - name: Checkout Repository uses: actions/checkout@v4 - - name: Check Host CUDA and GPU Environment + - name: Install Dependencies run: | - echo "--- Checking NVIDIA driver and supported CUDA version ---" - nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected." - - echo "" - echo "--- Checking for default CUDA toolkit installation ---" - ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/" - - echo "" - echo "--- Checking dynamic linker library path ---" - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}" - - - name: Print dependencies + pip install -r requirements.txt + pip install --upgrade torch torchvision + # Install the maxdiffusion package to make it available for execution + pip install . + + - name: List Installed Libraries run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - # pip install transformer_engine[jax]==2.4.0 - # pip install -U transformer-engine[jax]==2.6.0 - # pip uninstall -y transformer-engine-cu12 - # pip install transformer-engine-cu13 - # pip uninstall -y tensorflow - # pip install tensorflow-cpu + echo "--- Installed Python packages ---" pip freeze - - name: Run MaxDiffusion Training + - name: Run FLUX Training + env: + NVTE_FRAMEWORK: jax run: | - # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 pip install . && \ - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \ + run_name="flux-ci-test-${{ github.run_id }}" \ + output_dir="/tmp/flux-output/" \ + max_train_steps=5 \ hardware=gpu \ - train_new_unet=true \ - train_text_encoder=false \ - cache_latents_text_encoder_outputs=true \ - per_device_batch_size=1 \ - attention=cudnn_flash_te \ - activations_dtype=bfloat16 \ - weights_dtype=bfloat16 \ - max_train_steps=200 \ - enable_profiler=True \ - run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} - -# jobs: -# build: -# strategy: -# fail-fast: false -# matrix: -# tpu-type: ["v5p-8"] -# name: "TPU test (${{ matrix.tpu-type }})" -# runs-on: ["self-hosted","${{ matrix.tpu-type }}"] -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python 3.12 -# uses: actions/setup-python@v5 -# with: -# python-version: '3.12' -# - name: Install dependencies -# run: | -# pip install -e . 
-# pip uninstall jax jaxlib libtpu-nightly libtpu -y -# bash setup.sh MODE=stable -# export PATH=$PATH:$HOME/.local/bin -# pip install ruff -# pip install isort -# pip install pytest -# - name: Analysing the code with ruff -# run: | -# ruff check . -# - name: version check -# run: | -# python --version -# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets -# - name: PyTest -# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x -# add_pull_ready: -# if: github.ref != 'refs/heads/main' -# permissions: -# checks: read -# pull-requests: write -# needs: build -# uses: ./.github/workflows/AddLabel.yml + attention="cudnn_flash_te" \ No newline at end of file From bf38ae492de81100973f6250e48b77d70bdbd62f Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 11:24:36 -0700 Subject: [PATCH 42/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6add1b70..f6041cd5 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -31,7 +31,7 @@ jobs: # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 steps: - name: Checkout Repository @@ -58,4 +58,4 @@ jobs: output_dir="/tmp/flux-output/" \ max_train_steps=5 \ hardware=gpu \ - attention="cudnn_flash_te" \ No newline at end of file + attention="cudnn_flash_te" From db7a3c643b607c00c0ff0fe0cc32cffd75bc4a3d Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 11:40:32 -0700 Subject: [PATCH 43/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index f6041cd5..e5d7c9f7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -58,4 +58,5 @@ jobs: output_dir="/tmp/flux-output/" \ max_train_steps=5 \ hardware=gpu \ - attention="cudnn_flash_te" + attention="cudnn_flash_te" \ + enable_checkpointing=False \ From 9bdf6062ff76e62fc1e9d4a86d28a71f9ac998d4 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 11:58:41 -0700 Subject: [PATCH 44/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index e5d7c9f7..8a13f7e3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -59,4 +59,3 @@ jobs: max_train_steps=5 \ hardware=gpu \ attention="cudnn_flash_te" \ - enable_checkpointing=False \ From a506777e8cec5734614ba45a8bd86678020d7ac0 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 12:02:17 -0700 Subject: [PATCH 45/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 8a13f7e3..6e8b6eb4 100644 --- a/.github/workflows/UnitTests.yml +++ 
b/.github/workflows/UnitTests.yml @@ -49,6 +49,11 @@ jobs: echo "--- Installed Python packages ---" pip freeze + - name: Hugging Face Login + uses: huggingface/login@v1 + with: + token: ${{ secrets.HUGGINGFACE_TOKEN }} + - name: Run FLUX Training env: NVTE_FRAMEWORK: jax From d8cdbd5c812d1fd742c3ecff12394bf9694bdfae Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 12:05:38 -0700 Subject: [PATCH 46/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6e8b6eb4..e6677441 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,9 +50,7 @@ jobs: pip freeze - name: Hugging Face Login - uses: huggingface/login@v1 - with: - token: ${{ secrets.HUGGINGFACE_TOKEN }} + run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} - name: Run FLUX Training env: From 70beed7c83f9f31bd00b6cd3ef632ddc4bb33de8 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 12:24:06 -0700 Subject: [PATCH 47/63] Changing the return type --- src/maxdiffusion/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 5df5f334..9e32661c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -395,7 +395,7 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m check_rep=False, ) def wrap_flash_attention(query, key, value): - return jax.vmap(dpa_layer)(query, key, value, mask=None) + return dpa_layer(query, key, value, mask=None) out = wrap_flash_attention(query, key, value) return _reshape_data_from_cudnn_flash(out) From 03d4bae70aef0a23672568e36856f0d0d2992f8a Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 12:49:08 -0700 Subject: [PATCH 48/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index e6677441..e0b13bbf 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -15,9 +15,9 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -# This workflow will run a small FLUX training workload on a GPU runner. +# This workflow will run a small SDXL training workload on a GPU runner. 
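PATCH 47 above replaces `jax.vmap(dpa_layer)(query, key, value, mask=None)` with a direct `dpa_layer(query, key, value, mask=None)` call inside `wrap_flash_attention`. The sketch below illustrates what that wrapper changes, using a toy attention function as a stand-in for Transformer Engine's `dpa_layer` (shapes and semantics here are assumptions, not TE's actual implementation): `jax.vmap` maps a function over the leading axis, so the wrapped callee sees per-example rank-3 slices, while a direct call sees the full rank-4 batched arrays.

import jax
import jax.numpy as jnp


def toy_attention(q, k, v):
    # Stand-in for dpa_layer: works on "...qhd" layouts of any leading rank.
    print("inside attention, q.shape =", q.shape)
    scores = jnp.einsum("...qhd,...khd->...hqk", q, k) / (q.shape[-1] ** 0.5)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...hqk,...khd->...qhd", weights, v)


q = k = v = jnp.ones((2, 16, 4, 8))  # [batch, seq, heads, head_dim]

# Direct call: the callee sees the full batched arrays, shape (2, 16, 4, 8).
out_direct = toy_attention(q, k, v)

# jax.vmap maps over the batch axis, so the callee sees (16, 4, 8) slices.
out_vmapped = jax.vmap(toy_attention)(q, k, v)

For this toy function the two results agree, because it treats the batch axis independently either way; the point of the patch is that the real `dpa_layer` apparently expects the already-batched layout, so the extra `vmap` changed what it received.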
-name: FLUX Workload Training on GPU +name: SDXL Workload Training on GPU on: pull_request: @@ -26,8 +26,8 @@ on: workflow_dispatch: jobs: - flux_training_workload: - name: "Run FLUX Training Workload" + sdxl_training_workload: + name: "Run SDXL Training Workload" # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: @@ -48,17 +48,22 @@ jobs: run: | echo "--- Installed Python packages ---" pip freeze - + - name: Hugging Face Login run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} - - name: Run FLUX Training + - name: Run SDXL Training env: NVTE_FRAMEWORK: jax run: | - python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \ - run_name="flux-ci-test-${{ github.run_id }}" \ - output_dir="/tmp/flux-output/" \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + run_name="sdxl-ci-test-${{ github.run_id }}" \ + output_dir="/tmp/sdxl-output/" \ max_train_steps=5 \ hardware=gpu \ attention="cudnn_flash_te" \ + resolution=512 \ + per_device_batch_size=1 \ + train_new_unet=true \ + train_text_encoder=false \ + cache_latents_text_encoder_outputs=true From 0b2fdc28be7957d69e19440ab5edd1b9e874fd3b Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:17:17 -0700 Subject: [PATCH 49/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index e0b13bbf..92faa4ac 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,6 +40,10 @@ jobs: - name: Install Dependencies run: | pip install -r requirements.txt + # Uninstall the full tensorflow package to prevent GPU conflicts + pip uninstall -y tensorflow + # Install the CPU-only version of tensorflow + pip install tensorflow-cpu pip install --upgrade torch torchvision # Install the maxdiffusion package to make it available for execution pip install . From a93fbff042c623295e0dda2710a1565f3f09da07 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:43:09 -0700 Subject: [PATCH 50/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 92faa4ac..263c2c50 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -59,6 +59,7 @@ jobs: - name: Run SDXL Training env: NVTE_FRAMEWORK: jax + TF_FORCE_GPU_ALLOW_GROWTH: "true" run: | python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ run_name="sdxl-ci-test-${{ github.run_id }}" \ From ad1292ba830188d9efb8794171ed08a4e9ed5116 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:47:40 -0700 Subject: [PATCH 51/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 263c2c50..e8266e07 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -17,6 +17,8 @@ # This workflow will run a small SDXL training workload on a GPU runner. +# This workflow will run a small SDXL training workload on a GPU runner. 
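PATCHES 49 and 50 above uninstall the full tensorflow wheel, install tensorflow-cpu, and set TF_FORCE_GPU_ALLOW_GROWTH so the data-loading stack does not compete with JAX for GPU memory. The following is a hedged pre-flight check, not a step taken from the workflow, that could be run in the same container to confirm that split; `jax.devices()` and `tf.config.list_physical_devices` are standard APIs, everything else is illustrative.

import os

# Harmless even with tensorflow-cpu; mirrors the env var set in the workflow.
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")

import jax
import tensorflow as tf

print("JAX devices:", jax.devices())                        # expect GPU devices here
print("TF GPUs:", tf.config.list_physical_devices("GPU"))   # expect [] with tensorflow-cpu

assert tf.config.list_physical_devices("GPU") == [], (
    "TensorFlow can still see the GPU; the tensorflow-cpu swap did not take effect."
)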
+ name: SDXL Workload Training on GPU on: @@ -31,21 +33,39 @@ jobs: # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest steps: + - name: Create and Activate Swap File + run: | + echo "--- Verifying free space before changes ---" + free -h + echo "---" + echo "Creating and activating a 64GB swap file..." + # Deactivate any existing swap to be safe + sudo swapoff -a + # Allocate a 64GB file + sudo fallocate -l 64G /swapfile + # Set the correct permissions + sudo chmod 600 /swapfile + # Format the file as swap + sudo mkswap /swapfile + # Activate the swap file + sudo swapon /swapfile + echo "--- Swap file is now active ---" + sudo swapon --show + echo "--- Verifying free space after changes ---" + free -h + - name: Checkout Repository uses: actions/checkout@v4 - name: Install Dependencies run: | - pip install -r requirements.txt - # Uninstall the full tensorflow package to prevent GPU conflicts pip uninstall -y tensorflow - # Install the CPU-only version of tensorflow pip install tensorflow-cpu + pip install -r requirements.txt pip install --upgrade torch torchvision - # Install the maxdiffusion package to make it available for execution pip install . - name: List Installed Libraries From 80711cccb1c3f5613900cf5dc6592d056ac29729 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:48:24 -0700 Subject: [PATCH 52/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index e8266e07..bb0844a7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -33,7 +33,7 @@ jobs: # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 steps: - name: Create and Activate Swap File From 02da45aeae4470eeb00209510e42c7645b076466 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:52:26 -0700 Subject: [PATCH 53/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index bb0844a7..1f31119f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -16,7 +16,6 @@ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python # This workflow will run a small SDXL training workload on a GPU runner. - # This workflow will run a small SDXL training workload on a GPU runner. 
name: SDXL Workload Training on GPU @@ -33,29 +32,15 @@ jobs: # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest steps: - - name: Create and Activate Swap File + - name: Verify Environment run: | - echo "--- Verifying free space before changes ---" - free -h - echo "---" - echo "Creating and activating a 64GB swap file..." - # Deactivate any existing swap to be safe - sudo swapoff -a - # Allocate a 64GB file - sudo fallocate -l 64G /swapfile - # Set the correct permissions - sudo chmod 600 /swapfile - # Format the file as swap - sudo mkswap /swapfile - # Activate the swap file - sudo swapon /swapfile - echo "--- Swap file is now active ---" - sudo swapon --show - echo "--- Verifying free space after changes ---" + echo "--- Verifying free space ---" free -h + echo "--- Verifying shared memory size ---" + df -h /dev/shm - name: Checkout Repository uses: actions/checkout@v4 From 4aa60d159b3663c836de2ad522bc79d94c35c45c Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 13:55:02 -0700 Subject: [PATCH 54/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 1f31119f..b29342f2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -32,7 +32,7 @@ jobs: # IMPORTANT: Replace with the label for your specific GPU runner if different runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 steps: - name: Verify Environment From 18b8e2b866a592e34e3f8e3ce135b2baa7524919 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 14:27:43 -0700 Subject: [PATCH 55/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index b29342f2..0072658d 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -47,9 +47,9 @@ jobs: - name: Install Dependencies run: | + pip install -r requirements.txt pip uninstall -y tensorflow pip install tensorflow-cpu - pip install -r requirements.txt pip install --upgrade torch torchvision pip install . From 9172655ba63a569e8c0757b7818dc3457c768556 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 14:38:39 -0700 Subject: [PATCH 56/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 125 ++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 32 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 0072658d..a11da37d 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -16,9 +16,11 @@ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python # This workflow will run a small SDXL training workload on a GPU runner. -# This workflow will run a small SDXL training workload on a GPU runner. 
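PATCHES 51 through 53 above first add a 64GB swap file and then settle for a lighter "Verify Environment" step (`free -h`, `df -h /dev/shm`), which suggests host memory and shared-memory headroom were being ruled out as failure causes. Below is a rough Python equivalent of that verification using only the standard library; it is a sketch for Linux runners, and it only reports the same quantities as the bash commands rather than imposing new requirements.

import os


def read_meminfo():
    """Parse /proc/meminfo into a dict of kB values (Linux only)."""
    info = {}
    with open("/proc/meminfo") as f:
        for line in f:
            key, value = line.split(":", 1)
            info[key] = int(value.strip().split()[0])  # values are reported in kB
    return info


mem = read_meminfo()
print(f"MemAvailable: {mem['MemAvailable'] / 1024**2:.1f} GiB")
print(f"SwapTotal:    {mem['SwapTotal'] / 1024**2:.1f} GiB")

# /dev/shm size matters for multi-process data loaders inside the container.
shm = os.statvfs("/dev/shm")
print(f"/dev/shm:     {shm.f_frsize * shm.f_blocks / 1024**3:.1f} GiB")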
-name: SDXL Workload Training on GPU +# This workflow will run a small MaxText training workload on a GPU runner +# by checking out the MaxText repo inside the MaxDiffusion environment. + +name: MaxText Workload on MaxDiffusion Runner on: pull_request: @@ -27,53 +29,112 @@ on: workflow_dispatch: jobs: - sdxl_training_workload: - name: "Run SDXL Training Workload" - # IMPORTANT: Replace with the label for your specific GPU runner if different + maxtext_training_workload: + name: "Run MaxText Training Workload" runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 + # Using the MaxDiffusion container as requested + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest steps: - - name: Verify Environment - run: | - echo "--- Verifying free space ---" - free -h - echo "--- Verifying shared memory size ---" - df -h /dev/shm - - - name: Checkout Repository + - name: Checkout MaxText Repository uses: actions/checkout@v4 + with: + repository: 'AI-Hypercomputer/maxtext' + ref: 'main' + path: 'maxtext' # Clone it into a 'maxtext' subdirectory - name: Install Dependencies + working-directory: ./maxtext # Run all subsequent commands inside the new directory run: | - pip install -r requirements.txt + # Uninstall full tensorflow to prevent GPU conflicts with JAX pip uninstall -y tensorflow + # Install the CPU-only version for data loading pip install tensorflow-cpu - pip install --upgrade torch torchvision + # Install MaxText's dependencies + pip install -r requirements.txt + # Install the MaxText package itself pip install . - + - name: List Installed Libraries + working-directory: ./maxtext run: | echo "--- Installed Python packages ---" pip freeze - - - name: Hugging Face Login - run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} - - name: Run SDXL Training + - name: Run MaxText Training + working-directory: ./maxtext env: + # Set the correct framework for Transformer Engine NVTE_FRAMEWORK: jax + # Prevent TensorFlow from grabbing all GPU memory TF_FORCE_GPU_ALLOW_GROWTH: "true" run: | - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - run_name="sdxl-ci-test-${{ github.run_id }}" \ - output_dir="/tmp/sdxl-output/" \ - max_train_steps=5 \ - hardware=gpu \ - attention="cudnn_flash_te" \ - resolution=512 \ - per_device_batch_size=1 \ - train_new_unet=true \ - train_text_encoder=false \ - cache_latents_text_encoder_outputs=true + # Run the main training script with a base configuration + python MaxText/train.py MaxText/configs/base.yml \ + run_name="maxtext-ci-test-${{ github.run_id }}" \ + steps=5 \ + enable_checkpointing=false \ + attention='cudnn_flash_te' \ + dataset_type='synthetic' + + +# name: SDXL Workload Training on GPU + +# on: +# pull_request: +# push: +# branches: [ "main" ] +# workflow_dispatch: + +# jobs: +# sdxl_training_workload: +# name: "Run SDXL Training Workload" +# # IMPORTANT: Replace with the label for your specific GPU runner if different +# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] +# container: +# image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 + +# steps: +# - name: Verify Environment +# run: | +# echo "--- Verifying free space ---" +# free -h +# echo "--- Verifying shared memory size ---" +# df -h /dev/shm + +# - name: Checkout Repository +# uses: actions/checkout@v4 + +# - name: Install Dependencies +# run: | +# pip 
install -r requirements.txt +# pip uninstall -y tensorflow +# pip install tensorflow-cpu +# pip install --upgrade torch torchvision +# pip install . + +# - name: List Installed Libraries +# run: | +# echo "--- Installed Python packages ---" +# pip freeze + +# - name: Hugging Face Login +# run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} + +# - name: Run SDXL Training +# env: +# NVTE_FRAMEWORK: jax +# TF_FORCE_GPU_ALLOW_GROWTH: "true" +# run: | +# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ +# run_name="sdxl-ci-test-${{ github.run_id }}" \ +# output_dir="/tmp/sdxl-output/" \ +# max_train_steps=5 \ +# hardware=gpu \ +# attention="cudnn_flash_te" \ +# resolution=512 \ +# per_device_batch_size=1 \ +# train_new_unet=true \ +# train_text_encoder=false \ +# cache_latents_text_encoder_outputs=true From 6de4ff22c7156367a76889d8517f6f78cd73161a Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 14:40:13 -0700 Subject: [PATCH 57/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a11da37d..6ac74548 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -34,7 +34,7 @@ jobs: runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: # Using the MaxDiffusion container as requested - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 steps: - name: Checkout MaxText Repository From aaf1dc919ccf885ba8fcf6c00033df2cd846fa1e Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 15:16:29 -0700 Subject: [PATCH 58/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 42 +++++---------------------------- 1 file changed, 6 insertions(+), 36 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6ac74548..e5790e85 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -15,12 +15,11 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -# This workflow will run a small SDXL training workload on a GPU runner. # This workflow will run a small MaxText training workload on a GPU runner -# by checking out the MaxText repo inside the MaxDiffusion environment. +# using a custom Docker image with all dependencies pre-installed. 
-name: MaxText Workload on MaxDiffusion Runner +name: MaxText Custom Image Workload on: pull_request: @@ -33,52 +32,23 @@ jobs: name: "Run MaxText Training Workload" runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - # Using the MaxDiffusion container as requested - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 + # Use your newly built custom image + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest steps: - - name: Checkout MaxText Repository - uses: actions/checkout@v4 - with: - repository: 'AI-Hypercomputer/maxtext' - ref: 'main' - path: 'maxtext' # Clone it into a 'maxtext' subdirectory - - - name: Install Dependencies - working-directory: ./maxtext # Run all subsequent commands inside the new directory - run: | - # Uninstall full tensorflow to prevent GPU conflicts with JAX - pip uninstall -y tensorflow - # Install the CPU-only version for data loading - pip install tensorflow-cpu - # Install MaxText's dependencies - pip install -r requirements.txt - # Install the MaxText package itself - pip install . - - - name: List Installed Libraries - working-directory: ./maxtext - run: | - echo "--- Installed Python packages ---" - pip freeze - - name: Run MaxText Training - working-directory: ./maxtext env: - # Set the correct framework for Transformer Engine NVTE_FRAMEWORK: jax - # Prevent TensorFlow from grabbing all GPU memory TF_FORCE_GPU_ALLOW_GROWTH: "true" run: | - # Run the main training script with a base configuration + # The working directory is /deps, so this path is correct. python MaxText/train.py MaxText/configs/base.yml \ run_name="maxtext-ci-test-${{ github.run_id }}" \ steps=5 \ enable_checkpointing=false \ - attention='cudnn_flash_te' \ + attention='cudnet_flash_te' \ dataset_type='synthetic' - # name: SDXL Workload Training on GPU # on: From 2e1acd44ca453210c61d2357736775576fecf232 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 15:19:38 -0700 Subject: [PATCH 59/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index e5790e85..db06e9de 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -19,6 +19,9 @@ # This workflow will run a small MaxText training workload on a GPU runner # using a custom Docker image with all dependencies pre-installed. +# This workflow will run a small MaxText training workload on a GPU runner +# using a custom Docker image with all code and dependencies pre-installed. + name: MaxText Custom Image Workload on: @@ -32,21 +35,23 @@ jobs: name: "Run MaxText Training Workload" runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - # Use your newly built custom image + # Use your custom image which contains the source code and dependencies. image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest steps: - name: Run MaxText Training + # The Docker image's working directory is /deps, but the code is in /deps/src. + working-directory: /deps/src env: NVTE_FRAMEWORK: jax TF_FORCE_GPU_ALLOW_GROWTH: "true" run: | - # The working directory is /deps, so this path is correct. + # Run the main training script from the /deps/src directory. 
python MaxText/train.py MaxText/configs/base.yml \ run_name="maxtext-ci-test-${{ github.run_id }}" \ steps=5 \ enable_checkpointing=false \ - attention='cudnet_flash_te' \ + attention='cudnn_flash_te' \ dataset_type='synthetic' # name: SDXL Workload Training on GPU From 525b9ef23d832c47d42fff1ab7704370b882cf80 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 15:56:04 -0700 Subject: [PATCH 60/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index db06e9de..6d12a893 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -22,7 +22,9 @@ # This workflow will run a small MaxText training workload on a GPU runner # using a custom Docker image with all code and dependencies pre-installed. -name: MaxText Custom Image Workload +# This workflow runs MaxText training with a pinned version of Transformer Engine. + +name: MaxText Custom Image with Pinned TE on: pull_request: @@ -35,18 +37,22 @@ jobs: name: "Run MaxText Training Workload" runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - # Use your custom image which contains the source code and dependencies. image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest steps: + - name: Pin Transformer Engine Version + run: | + echo "Uninstalling any existing Transformer Engine..." + pip uninstall -y transformer-engine + echo "Installing Transformer Engine version 2.6.0..." + pip install transformer-engine[jax]==2.6.0 + - name: Run MaxText Training - # The Docker image's working directory is /deps, but the code is in /deps/src. working-directory: /deps/src env: NVTE_FRAMEWORK: jax TF_FORCE_GPU_ALLOW_GROWTH: "true" run: | - # Run the main training script from the /deps/src directory. python MaxText/train.py MaxText/configs/base.yml \ run_name="maxtext-ci-test-${{ github.run_id }}" \ steps=5 \ From a918ce860065149962fefe06992a489b490fd924 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 16:09:58 -0700 Subject: [PATCH 61/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6d12a893..2097b05b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,13 +40,6 @@ jobs: image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest steps: - - name: Pin Transformer Engine Version - run: | - echo "Uninstalling any existing Transformer Engine..." - pip uninstall -y transformer-engine - echo "Installing Transformer Engine version 2.6.0..." 
- pip install transformer-engine[jax]==2.6.0 - - name: Run MaxText Training working-directory: /deps/src env: From 5c3a4040d295a02ae90b102a7a2c4175e2f4f360 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 18:36:40 -0700 Subject: [PATCH 62/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2097b05b..a5e93547 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,6 +40,10 @@ jobs: image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest steps: + - name: List Installed Libraries + run: | + echo "--- Installed Python packages ---" + pip freeze - name: Run MaxText Training working-directory: /deps/src env: From e9809b6c10b3e3d454a29c923dfabdc321952d44 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Tue, 30 Sep 2025 19:36:27 -0700 Subject: [PATCH 63/63] Update UnitTests.yml --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a5e93547..30de3722 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -49,6 +49,7 @@ jobs: env: NVTE_FRAMEWORK: jax TF_FORCE_GPU_ALLOW_GROWTH: "true" + NVTE_FUSED_ATTN: 1 run: | python MaxText/train.py MaxText/configs/base.yml \ run_name="maxtext-ci-test-${{ github.run_id }}" \
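The last three patches (61 to 63) drop the explicit Transformer Engine reinstall, keep a `pip freeze` step for visibility, and add `NVTE_FUSED_ATTN: 1` alongside `NVTE_FRAMEWORK: jax` on the training step. Those NVTE_* variables are read from the environment by Transformer Engine, so they must be present before the training process starts, which is what the step-level `env:` block guarantees. The sketch below mirrors that setup in-process as a sanity check; the distribution names passed to `importlib.metadata` are assumptions, hence the try/except.

import importlib.metadata as md
import os

# Mirrors the step-level env: block in the workflow (assumed semantics).
os.environ.setdefault("NVTE_FRAMEWORK", "jax")
os.environ.setdefault("NVTE_FUSED_ATTN", "1")
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")

for var in ("NVTE_FRAMEWORK", "NVTE_FUSED_ATTN", "TF_FORCE_GPU_ALLOW_GROWTH"):
    print(f"{var}={os.environ.get(var)}")

# Lightweight stand-in for the "pip freeze" step: report the versions that matter.
for dist in ("jax", "jaxlib", "transformer_engine"):
    try:
        print(f"{dist}=={md.version(dist)}")
    except md.PackageNotFoundError:
        print(f"{dist}: not installed under this distribution name")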