diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 2c588b43..30de3722 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -15,54 +15,101 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-name: Unit Test
+
+# This workflow runs a small MaxText training workload on a GPU runner,
+# using a custom Docker image with all code and dependencies pre-installed
+# and a pinned version of Transformer Engine.
+
+name: MaxText Custom Image with Pinned TE
 
 on:
   pull_request:
   push:
     branches: [ "main" ]
   workflow_dispatch:
-  schedule:
-    # Run the job every 12 hours
-    - cron: '0 */12 * * *'
 
 jobs:
-  build:
-    strategy:
-      fail-fast: false
-      matrix:
-        tpu-type: ["v5p-8"]
-    name: "TPU test (${{ matrix.tpu-type }})"
-    runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
+  maxtext_training_workload:
+    name: "Run MaxText Training Workload"
+    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
+    container:
+      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
+
     steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python 3.12
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.12'
-    - name: Install dependencies
-      run: |
-        pip install -e .
-        pip uninstall jax jaxlib libtpu-nightly libtpu -y
-        bash setup.sh MODE=stable
-        export PATH=$PATH:$HOME/.local/bin
-        pip install ruff
-        pip install isort
-        pip install pytest
-    - name: Analysing the code with ruff
-      run: |
-        ruff check .
-    - name: version check
-      run: |
-        python --version
-        pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
-    - name: PyTest
-      run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-# add_pull_ready:
-#   if: github.ref != 'refs/heads/main'
-#   permissions:
-#     checks: read
-#     pull-requests: write
-#   needs: build
-#   uses: ./.github/workflows/AddLabel.yml
+    - name: List Installed Libraries
+      run: |
+        echo "--- Installed Python packages ---"
+        pip freeze
+    - name: Run MaxText Training
+      working-directory: /deps/src
+      env:
+        NVTE_FRAMEWORK: jax
+        TF_FORCE_GPU_ALLOW_GROWTH: "true"
+        NVTE_FUSED_ATTN: 1
+      run: |
+        python MaxText/train.py MaxText/configs/base.yml \
+          run_name="maxtext-ci-test-${{ github.run_id }}" \
+          steps=5 \
+          enable_checkpointing=false \
+          attention='cudnn_flash_te' \
+          dataset_type='synthetic'
+
+# name: SDXL Workload Training on GPU
+
+# on:
+#   pull_request:
+#   push:
+#     branches: [ "main" ]
+#   workflow_dispatch:
+
+# jobs:
+#   sdxl_training_workload:
+#     name: "Run SDXL Training Workload"
+#     # IMPORTANT: Replace with the label for your specific GPU runner if different
+#     runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
+#     container:
+#       image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
+
+#     steps:
+#       - name: Verify Environment
+#         run: |
+#           echo "--- Verifying free space ---"
+#           free -h
+#           echo "--- Verifying shared memory size ---"
+#           df -h /dev/shm
+
+#       - name: Checkout Repository
+#         uses: actions/checkout@v4
+
+#       - name: Install Dependencies
+#         run: |
+#           pip install -r requirements.txt
+#           pip uninstall -y tensorflow
+#           pip install tensorflow-cpu
+#           pip install --upgrade torch torchvision
+#           pip install .
+
+#       - name: List Installed Libraries
+#         run: |
+#           echo "--- Installed Python packages ---"
+#           pip freeze
+
+#       - name: Hugging Face Login
+#         run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
+
+#       - name: Run SDXL Training
+#         env:
+#           NVTE_FRAMEWORK: jax
+#           TF_FORCE_GPU_ALLOW_GROWTH: "true"
+#         run: |
+#           python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
+#             run_name="sdxl-ci-test-${{ github.run_id }}" \
+#             output_dir="/tmp/sdxl-output/" \
+#             max_train_steps=5 \
+#             hardware=gpu \
+#             attention="cudnn_flash_te" \
+#             resolution=512 \
+#             per_device_batch_size=1 \
+#             train_new_unet=true \
+#             train_text_encoder=false \
+#             cache_latents_text_encoder_outputs=true
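
Note: to reproduce the training smoke test above outside of CI, the step's env
block and command can be mirrored directly. A minimal sketch, assuming the same
container image with MaxText checked out under /deps/src (the run_name value
below is arbitrary):

    import os
    import subprocess

    # Same env vars as the workflow's "Run MaxText Training" step.
    env = dict(
        os.environ,
        NVTE_FRAMEWORK="jax",              # Transformer Engine: use the JAX backend
        NVTE_FUSED_ATTN="1",               # enable TE's fused attention kernels
        TF_FORCE_GPU_ALLOW_GROWTH="true",  # don't let TF pre-allocate all GPU memory
    )
    # Same command and flags as the workflow step, run from /deps/src.
    subprocess.run(
        [
            "python", "MaxText/train.py", "MaxText/configs/base.yml",
            "run_name=maxtext-local-smoke-test",
            "steps=5",
            "enable_checkpointing=false",
            "attention=cudnn_flash_te",
            "dataset_type=synthetic",
        ],
        cwd="/deps/src",
        env=env,
        check=True,
    )
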
diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile
index cab50fee..301f9b88 100644
--- a/maxdiffusion_jax_ai_image_tpu.Dockerfile
+++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile
@@ -19,4 +19,4 @@
 COPY . .
 RUN pip install -r /deps/requirements_with_jax_ai_image.txt
 # Run the script available in JAX-AI-Image base image to generate the manifest file
-RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
\ No newline at end of file
+RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
\ No newline at end of file
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 5df5f334..9e32661c 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -395,7 +395,7 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-    return jax.vmap(dpa_layer)(query, key, value, mask=None)
+    return dpa_layer(query, key, value, mask=None)
 
   out = wrap_flash_attention(query, key, value)
   return _reshape_data_from_cudnn_flash(out)
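
Note on the attention_flax.py change: dpa_layer (presumably Transformer
Engine's fused dot-product attention, given the cudnn_flash_te and NVTE
settings above) already consumes batched [batch, seq, heads, head_dim] arrays,
so the extra jax.vmap mapped over the batch axis a second time. A minimal
self-contained sketch of the shape contract, using a toy softmax attention as
a stand-in for dpa_layer (not the TE API):

    import jax
    import jax.numpy as jnp

    def batched_attention(query, key, value, mask=None):
        # Toy stand-in for a fused attention layer that, like the real
        # dpa_layer, expects batched [batch, seq, heads, head_dim] inputs.
        logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) / jnp.sqrt(query.shape[-1])
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhd->bqhd", weights, value)

    batch, seq, heads, head_dim = 2, 16, 4, 8
    q = k = v = jnp.ones((batch, seq, heads, head_dim))

    # Correct: the layer is already batched, so call it directly, as the patch does.
    out = batched_attention(q, k, v)
    assert out.shape == (batch, seq, heads, head_dim)

    # Wrong for a batched layer: jax.vmap strips the leading batch axis, so the
    # layer would be traced with unbatched [seq, heads, head_dim] operands and
    # the rank-4 einsum above would fail.
    # out = jax.vmap(batched_attention)(q, k, v)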