Draft

63 commits
6e0898a  Update entrypoint for jaii (Rohan-Bierneni, Sep 25, 2025)
c50ab1b  remove self-hosted tag (Rohan-Bierneni, Sep 26, 2025)
1f22cf5  wrong tag name (Rohan-Bierneni, Sep 26, 2025)
2deaa7e  Update with command for gpu (Rohan-Bierneni, Sep 26, 2025)
d47a6eb  remove pip install . (Rohan-Bierneni, Sep 26, 2025)
a6579ec  Check environment (Rohan-Bierneni, Sep 26, 2025)
cc5d322  Update to right image (Rohan-Bierneni, Sep 26, 2025)
26a03c0  point to allowed bucket (Rohan-Bierneni, Sep 26, 2025)
9236f6f  Install entire TE (Rohan-Bierneni, Sep 26, 2025)
4e0dd7f  uninstall all TE deps (Rohan-Bierneni, Sep 26, 2025)
0888316  install TE jax (Rohan-Bierneni, Sep 26, 2025)
09c5d6a  fix pip instal typo (Rohan-Bierneni, Sep 26, 2025)
846e176  Install TE for pytorch and jax (Rohan-Bierneni, Sep 26, 2025)
496f67c  Comment out tflop calc (Rohan-Bierneni, Sep 26, 2025)
37524bb  Test with dot_product attention (Rohan-Bierneni, Sep 26, 2025)
64f30f3  Test if maxtext has same gpu issue (Rohan-Bierneni, Sep 26, 2025)
53d1d22  change to pip install maxtext package (Rohan-Bierneni, Sep 26, 2025)
fdcec4c  Install with no dependencies (Rohan-Bierneni, Sep 26, 2025)
cd9daaf  Use custom maxtext branch (Rohan-Bierneni, Sep 26, 2025)
e0615dc  use synthetic data (Rohan-Bierneni, Sep 26, 2025)
323654a  Try with TE flash (Rohan-Bierneni, Sep 26, 2025)
2313d54  Try with older TE (Rohan-Bierneni, Sep 26, 2025)
2eafa93  Typo in TE install (Rohan-Bierneni, Sep 26, 2025)
a1da601  Use TE 2.5.0 (Rohan-Bierneni, Sep 26, 2025)
dfe8fc5  Use TE 2.6.0 (Rohan-Bierneni, Sep 26, 2025)
4243d53  Test with tensorflow-cpu (Rohan-Bierneni, Sep 26, 2025)
d37e0dd  Test with new cuda13 images and TE 2.6.0 (Rohan-Bierneni, Sep 30, 2025)
c247bee  Update to right image (Rohan-Bierneni, Sep 30, 2025)
67141da  uninstall cuda12 TE as well (Rohan-Bierneni, Sep 30, 2025)
4ac11bb  Test with te-cu13 package (Rohan-Bierneni, Sep 30, 2025)
994d8dc  try with base TE (Rohan-Bierneni, Sep 30, 2025)
6700b53  uninstall existing TE (Rohan-Bierneni, Sep 30, 2025)
23d8ca2  test te 2.6.0 jax (Rohan-Bierneni, Sep 30, 2025)
0885a52  test te 2.6.0 jax (Rohan-Bierneni, Sep 30, 2025)
22cfa35  remove steps for uninstalling TE (Rohan-Bierneni, Sep 30, 2025)
ab478a1  check host gpu driver (Rohan-Bierneni, Sep 30, 2025)
7ff2c23  try with cuda 12 on TE cu12 (Rohan-Bierneni, Sep 30, 2025)
ef2041e  try with last tested TE version (Rohan-Bierneni, Sep 30, 2025)
ec29e4e  try with no-cache build image (Rohan-Bierneni, Sep 30, 2025)
ecfacad  update image tag to prevent gke caching (Rohan-Bierneni, Sep 30, 2025)
af23583  [ DONT MERGE ] Testing Image (parambole, Sep 30, 2025)
bf38ae4  Update UnitTests.yml (parambole, Sep 30, 2025)
db7a3c6  Update UnitTests.yml (parambole, Sep 30, 2025)
9bdf606  Update UnitTests.yml (parambole, Sep 30, 2025)
a506777  Update UnitTests.yml (parambole, Sep 30, 2025)
d8cdbd5  Update UnitTests.yml (parambole, Sep 30, 2025)
70beed7  Changing the return type (parambole, Sep 30, 2025)
03d4bae  Update UnitTests.yml (parambole, Sep 30, 2025)
0b2fdc2  Update UnitTests.yml (parambole, Sep 30, 2025)
a93fbff  Update UnitTests.yml (parambole, Sep 30, 2025)
ad1292b  Update UnitTests.yml (parambole, Sep 30, 2025)
80711cc  Update UnitTests.yml (parambole, Sep 30, 2025)
02da45a  Update UnitTests.yml (parambole, Sep 30, 2025)
4aa60d1  Update UnitTests.yml (parambole, Sep 30, 2025)
18b8e2b  Update UnitTests.yml (parambole, Sep 30, 2025)
9172655  Update UnitTests.yml (parambole, Sep 30, 2025)
6de4ff2  Update UnitTests.yml (parambole, Sep 30, 2025)
aaf1dc9  Update UnitTests.yml (parambole, Sep 30, 2025)
2e1acd4  Update UnitTests.yml (parambole, Sep 30, 2025)
525b9ef  Update UnitTests.yml (parambole, Sep 30, 2025)
a918ce8  Update UnitTests.yml (parambole, Sep 30, 2025)
5c3a404  Update UnitTests.yml (parambole, Oct 1, 2025)
e9809b6  Update UnitTests.yml (parambole, Oct 1, 2025)
135 changes: 93 additions & 42 deletions .github/workflows/UnitTests.yml
@@ -15,54 +15,105 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Unit Test

# This workflow will run a small MaxText training workload on a GPU runner
# using a custom Docker image with all dependencies pre-installed.

# This workflow will run a small MaxText training workload on a GPU runner
# using a custom Docker image with all code and dependencies pre-installed.

# This workflow runs MaxText training with a pinned version of Transformer Engine.

name: MaxText Custom Image with Pinned TE

on:
pull_request:
push:
branches: [ "main" ]
workflow_dispatch:
schedule:
# Run the job every 12 hours
- cron: '0 */12 * * *'

jobs:
build:
strategy:
fail-fast: false
matrix:
tpu-type: ["v5p-8"]
name: "TPU test (${{ matrix.tpu-type }})"
runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
maxtext_training_workload:
name: "Run MaxText Training Workload"
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
container:
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dependencies
run: |
pip install -e .
pip uninstall jax jaxlib libtpu-nightly libtpu -y
bash setup.sh MODE=stable
export PATH=$PATH:$HOME/.local/bin
pip install ruff
pip install isort
pip install pytest
- name: Analysing the code with ruff
run: |
ruff check .
- name: version check
run: |
python --version
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
# checks: read
# pull-requests: write
# needs: build
# uses: ./.github/workflows/AddLabel.yml
- name: List Installed Libraries
run: |
echo "--- Installed Python packages ---"
pip freeze
- name: Run MaxText Training
working-directory: /deps/src
env:
NVTE_FRAMEWORK: jax
TF_FORCE_GPU_ALLOW_GROWTH: "true"
NVTE_FUSED_ATTN: 1
run: |
python MaxText/train.py MaxText/configs/base.yml \
run_name="maxtext-ci-test-${{ github.run_id }}" \
steps=5 \
enable_checkpointing=false \
attention='cudnn_flash_te' \
dataset_type='synthetic'

# name: SDXL Workload Training on GPU

# on:
# pull_request:
# push:
# branches: [ "main" ]
# workflow_dispatch:

# jobs:
# sdxl_training_workload:
# name: "Run SDXL Training Workload"
# # IMPORTANT: Replace with the label for your specific GPU runner if different
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
# container:
# image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1

# steps:
# - name: Verify Environment
# run: |
# echo "--- Verifying free space ---"
# free -h
# echo "--- Verifying shared memory size ---"
# df -h /dev/shm

# - name: Checkout Repository
# uses: actions/checkout@v4

# - name: Install Dependencies
# run: |
# pip install -r requirements.txt
# pip uninstall -y tensorflow
# pip install tensorflow-cpu
# pip install --upgrade torch torchvision
# pip install .

# - name: List Installed Libraries
# run: |
# echo "--- Installed Python packages ---"
# pip freeze

# - name: Hugging Face Login
# run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}

# - name: Run SDXL Training
# env:
# NVTE_FRAMEWORK: jax
# TF_FORCE_GPU_ALLOW_GROWTH: "true"
# run: |
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
# run_name="sdxl-ci-test-${{ github.run_id }}" \
# output_dir="/tmp/sdxl-output/" \
# max_train_steps=5 \
# hardware=gpu \
# attention="cudnn_flash_te" \
# resolution=512 \
# per_device_batch_size=1 \
# train_new_unet=true \
# train_text_encoder=false \
# cache_latents_text_encoder_outputs=true
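
The active job above (the part before the commented-out SDXL workload) reduces to a single MaxText/train.py invocation inside the prebuilt maxtext-gpu-custom image. When that step fails in CI, a quick check inside the same container can save a round-trip. The snippet below is a sketch, not part of the workflow: it only mirrors the environment variables exported in the Run MaxText Training step and uses standard JAX calls to confirm the CUDA backend is visible; everything beyond the jax package itself is assumed to be provided by the image.

import os

# Mirror the environment the workflow exports for Transformer Engine's JAX backend.
# Set these before importing jax / transformer_engine so they are read at import time.
os.environ.setdefault("NVTE_FRAMEWORK", "jax")
os.environ.setdefault("NVTE_FUSED_ATTN", "1")
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")

import jax  # noqa: E402  (deliberately imported after the env setup)

# Inside the GPU image this should print "gpu" and list the visible accelerators;
# "cpu" here usually means a CPU-only jaxlib or a driver the container cannot see.
print("JAX backend:", jax.default_backend())
for device in jax.devices():
    print(" ", device)

If the backend reports gpu, the next thing worth checking is that transformer_engine imports cleanly, since the attention='cudnn_flash_te' setting depends on it.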
2 changes: 1 addition & 1 deletion maxdiffusion_jax_ai_image_tpu.Dockerfile
@@ -19,4 +19,4 @@ COPY . .
RUN pip install -r /deps/requirements_with_jax_ai_image.txt

# Run the script available in JAX-AI-Image base image to generate the manifest file
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/attention_flax.py
@@ -395,7 +395,7 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
check_rep=False,
)
def wrap_flash_attention(query, key, value):
return jax.vmap(dpa_layer)(query, key, value, mask=None)
return dpa_layer(query, key, value, mask=None)

out = wrap_flash_attention(query, key, value)
return _reshape_data_from_cudnn_flash(out)
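
The one-line change above stops wrapping dpa_layer in jax.vmap inside the sharded wrap_flash_attention helper, presumably because the Transformer Engine attention layer already consumes batched [batch, seq, heads, head_dim] inputs, so the extra vmap strips the batch axis the layer expects. The snippet below is a stand-alone illustration of that failure mode using a plain-JAX stand-in function; it is not the TE layer and makes no claim about its exact signature.

import jax
import jax.numpy as jnp


def batched_attention(q, k, v):
    # Stand-in for dpa_layer: expects [batch, seq, heads, head_dim] inputs.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


q = k = v = jnp.ones((2, 128, 8, 64))  # [batch, seq, heads, head_dim]

# Calling the batched function directly keeps the batch axis intact.
print(batched_attention(q, k, v).shape)  # (2, 128, 8, 64)

# jax.vmap maps over the leading axis, so the callee only sees rank-3
# [seq, heads, head_dim] inputs and the four-subscript einsum fails.
try:
    jax.vmap(batched_attention)(q, k, v)
except Exception as err:
    print("vmap over an already-batched function:", type(err).__name__)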