Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion .github/actions/gke-xpk/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ runs:

if [ $? -ne 0 ]; then
echo "The JobSet ${WORKLOAD_NAME} on ${{ inputs.GKE_CLUSTER }} did not complete as expected "
echo "XPK_EXIT_CODE=1" >> ${GITHUB_ENV}
exit 1
fi

Expand All @@ -262,11 +263,12 @@ runs:
ALL_EXIT_CODES=$(( ALL_EXIT_CODES + POD_EXIT_CODE ))
done

echo "XPK_EXIT_CODE=${ALL_EXIT_CODES}" >> ${GITHUB_ENV}
if [ ${ALL_EXIT_CODES} -gt 0 ]; then
exit 1
fi
exit 0

- name: Clean up JobSet from cluster
shell: bash -x -u {0}
if: ${{ always() }}
Expand All @@ -291,3 +293,38 @@ runs:
if: ${{ always() }}
run: |
sudo rm -rf ${WORKLOAD_NAME}

- name: Generate sitrep
  id: sitrep
  shell: bash -x -e {0}
  if: ${{ always() }}
  run: |
    source .github/workflows/scripts/to_json.sh

    # NOTE(review): badge_label was previously first assigned
    # "${{ matrix.test }}" and then unconditionally overwritten below; the
    # dead assignment has been removed — confirm which label was intended.
    summary="${{ inputs.WORKLOAD_NAME_PREFIX }}"
    outcome=success
    badge_label="${{ inputs.WORKLOAD_NAME_PREFIX }}"
    badge_color=brightgreen

    # XPK_EXIT_CODE is exported to GITHUB_ENV by the workload step; default
    # to 0 so the integer comparison cannot error out (this step runs under
    # always(), even when the workload step never reached the export).
    if [ "${XPK_EXIT_CODE:-0}" -gt 0 ]; then
      badge_color=red
      outcome=failed
      summary+=": fail"
    else
      summary+=": pass"
    fi

    to_json summary \
      badge_label \
      badge_color \
      outcome | \
      tee sitrep.json

# Persist the sitrep.json produced by the "Generate sitrep" step as a
# workflow artifact named "<WORKLOAD_NAME_PREFIX>-sitrep"; always() ensures
# it is uploaded even when earlier steps failed.
- name: Upload sitrep to GitHub Actions from runner
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: ${{ inputs.WORKLOAD_NAME_PREFIX }}-sitrep
path: |
sitrep.json
2 changes: 1 addition & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ RUN mkdir -p /builder/extra-targets && \
--src-path-xla ${SRC_PATH_XLA} \
--sm all \
--clean \
--release \
${EXTRA_BUILD_JAX_ARGS}

## Transformer engine: check out source and build wheel
Expand Down Expand Up @@ -97,7 +98,6 @@ ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV NCCL_NVLS_ENABLE=0

COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
Expand Down
13 changes: 12 additions & 1 deletion .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ INSTALL=1
SRC_PATH_JAX="/opt/jax"
SRC_PATH_XLA="/opt/xla"

args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,extra-targets:,extra-target-dest:,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
args=$(getopt -o h,r --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,release,cpu-arch:,debug,extra-targets:,extra-target-dest:,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
Expand Down Expand Up @@ -135,6 +135,10 @@ while [ : ]; do
EXTRA_TARGET_DEST="$2"
shift 2
;;
-r | --release)
IS_RELEASE=1
shift 1
;;
-h | --help)
usage 1
;;
Expand Down Expand Up @@ -225,6 +229,7 @@ print_var INSTALL
print_var PYTHON_VERSION
print_var SRC_PATH_JAX
print_var SRC_PATH_XLA
print_var IS_RELEASE

echo "=================================================="

Expand Down Expand Up @@ -268,6 +273,12 @@ for component in jaxlib "jax-cuda${CUDA_MAJOR_VERSION}-pjrt" "jax-cuda${CUDA_MAJ
# version, so nvidia-*-cu12 wheels disappear from the lock file
sed -i "s|^${component}.*$|${component} @ file://${BUILD_PATH_JAXLIB}/${component//-/_}|" build/requirements.in
done

if [[ "${IS_RELEASE}" == "1" ]]; then
  # Release builds loosen the exact jaxlib pin in jax's setup.py
  # (jaxlib >= min, <= jax_version) to a plain floor constraint so the
  # released jax wheel can be paired with any jaxlib >= 0.5.0.
  # NOTE(review): the previous code also captured the installed jaxlib
  # version via `pip show jaxlib | grep Version | tr ':' '\n' | tail -1`
  # but never used the result; the dead assignment has been removed.
  sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.5.0',|" /opt/jax/setup.py
fi

# Bazel args to avoid cache invalidation
BAZEL_ARGS=(
--config=cuda_libraries_from_stubs
Expand Down
21 changes: 19 additions & 2 deletions .github/container/build-te.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ if [[ "$SM" == "all" ]]; then
SM_LIST=$(default_compute_capabilities)
elif [[ "$SM" == "local" ]]; then
SM_LIST=$("${SCRIPT_DIR}/local_cuda_arch")
if [[ -z "${SM_LIST}" ]]; then
echo "Could not determine the local GPU architecture."
echo "You should pass --sm when compiling on a machine without GPUs."
nvidia-smi || true
exit 1
fi
else
SM_LIST=${SM}
fi
Expand Down Expand Up @@ -131,8 +137,19 @@ export NVTE_FRAMEWORK=jax
export XLA_HOME=${SRC_PATH_XLA}

pushd ${SRC_PATH_TE}
# Install required packages that were removed in https://github.com/NVIDIA/TransformerEngine/pull/1852
pip install "pybind11[global]"
# Install some build dependencies, but avoid installing everything
# (jax, torch, ...) because we do not want to pull in a released version of
# JAX, or the wheel-based installation of CUDA. Note that when we build TE as
# part of building the JAX containers, JAX and XLA are not yet installed.
# Only the nvidia-mathdx and pybind11 entries of [build-system].requires in
# pyproject.toml are installed. check=True makes a pip failure abort the
# build (the embedded python previously exited 0 regardless, so `set -e`
# in the enclosing script could not see the error).
python - << EOF
import subprocess, sys, tomllib
with open("pyproject.toml", "rb") as ifile:
    data = tomllib.load(ifile)
subprocess.run(
    [sys.executable, "-m", "pip", "install"]
    + [r for r in data["build-system"]["requires"]
       if r.startswith("nvidia-mathdx") or r.startswith("pybind11")],
    check=True)
EOF

# The wheel filename includes the TE commit; if this has changed since the last
# incremental build then we would end up with multiple wheels.
Expand Down
7 changes: 7 additions & 0 deletions .github/container/git-clone.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ pushd ${DESTINATION}
git checkout ${GIT_REF}
COMMIT_SHA=$(git rev-parse HEAD)
git submodule update --init --recursive
if [[ "${GIT_REPO}" == *"gitlab"* ]]; then
  # Scrub GitLab CI credentials from the checkout: the origin remote URL and
  # some .git metadata files can embed a gitlab-ci-token.
  git remote remove origin
  if grep -q -r gitlab-ci-token .git; then
    # grep -l prints matching file names directly, avoiding the previous
    # second full grep pass piped through `awk -F: '{print $1}'`, which
    # mis-parses "file:match" output whenever a matched line contains ':'.
    grep -r -l gitlab-ci-token .git | xargs rm -f --
  fi
  # NOTE(review): this fails (and aborts under `set -e`) when no local
  # "main" branch exists — identical to the original behavior; confirm
  # "main" is always the branch created for these GitLab clones.
  git branch -D main
fi
popd

## update the manifest file
Expand Down
95 changes: 52 additions & 43 deletions .github/container/pip-finalize.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,60 @@ set -eoux pipefail

pushd /opt/pip-tools.d

# First pip-compile gathers all reqs, but we care only about VCS installs
# It's possible there are 2nd degree transitive dependencies that are VCS, so
# this is more robust to gather VCS requirements at the cost of pip-compiling
# twice
pip-compile -o requirements.pre $(ls requirements-*.in)
# If requirements-pinned.txt exists, skip compilation
if [[ -f "requirements-pinned.txt" ]]; then
sed -E 's/#sha256=[a-f0-9]+//g' requirements-pinned.txt > requirements.txt
else
# First pip-compile gathers all reqs, but we care only about VCS installs
# It's possible there are 2nd degree transitive dependencies that are VCS, so
# this is more robust to gather VCS requirements at the cost of pip-compiling
# twice
pip-compile -o requirements.pre $(ls requirements-*.in)

IFS=$'\n'
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
# VCS installs are of the form "PACKAGE @ git+..."
PACKAGE=$(echo "$line" | awk '{print $1}')
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
if [[ "$line" == *"#subdirectory="* ]]; then
# This is required b/c git-refs/commits cannot come after
# the subdirectory fragment.
# An example of an install that is of this form is:
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
else
echo "${line}@${ref}"
fi
done | tee requirements.vcs
unset IFS
IFS=$'\n'
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
# VCS installs are of the form "PACKAGE @ git+..."
PACKAGE=$(echo "$line" | awk '{print $1}')
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
if [[ "$line" == *"#subdirectory="* ]]; then
# This is required b/c git-refs/commits cannot come after
# the subdirectory fragment.
# An example of an install that is of this form is:
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
else
echo "${line}@${ref}"
fi
done | tee requirements.vcs
unset IFS

# Second pip-compile includes one more requirements file that pins all vcs installs
# Uses a special env var to let our custom pip impl know to treat the following as
# equivalent:
#
# fiddle @ git+https://github.com/google/fiddle
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
#
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
# that treats the above as equivalent and prefers the URI with the SHA
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)
# Second pip-compile includes one more requirements file that pins all vcs installs
# Uses a special env var to let our custom pip impl know to treat the following as
# equivalent:
#
# fiddle @ git+https://github.com/google/fiddle
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
#
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
# that treats the above as equivalent and prefers the URI with the SHA
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)

# If there are unpinned VCS dependencies, error since these should be included in the manifest
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
echo "$unpinned_vcs_dependencies"
exit 1
fi
# If there are unpinned VCS dependencies, error since these should be included in the manifest
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
echo "$unpinned_vcs_dependencies"
exit 1
fi

# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
if [ "$(uname -m)" = "x86_64" ]; then
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
else
echo "Skipping TF on $(uname -m)"
# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
if [[ "$(uname -m)" = "x86_64" ]]; then
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
else
echo "Skipping TF on $(uname -m)"
fi
fi

# --no-deps is required since conflicts can still appear during pip-sync
pip-sync --pip-args '--no-deps --src /opt' requirements.txt

Expand All @@ -63,3 +69,6 @@ for post_install in $(ls /opt/pip-tools-post-install.d/*); do
"${post_install}"
fi
done

# Log the fully-resolved package set for this image build.
printf '%s\n' "######## Frozen requirements ########"
pip freeze
10 changes: 9 additions & 1 deletion .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,15 @@ fi

readarray -t GPU_MEMORIES < <(nvidia-smi --query-gpu=memory.total --format=csv,noheader)
NGPUS="${#GPU_MEMORIES[@]}"
GPU_MEMORIES_MIB=("${GPU_MEMORIES[@]/ MiB/}")
if [[ " ${GPU_MEMORIES[*]} " =~ [[:space:]]\[N/A\][[:space:]] ]]; then
  # On iGPU devices, nvidia-smi reports [N/A] for memory.total; fall back to
  # splitting total system memory evenly across the visible GPUs.
  # int() is required: plain `$2 / 1024` makes awk print a float such as
  # "31997.5", which the bash integer arithmetic $(( ... / NGPUS )) below
  # would reject with a syntax error.
  SYSTEM_MEMORY_MIB=$(awk '/MemTotal/ {print int($2 / 1024)}' /proc/meminfo)
  declare -a GPU_MEMORIES_MIB
  for (( i = 0; i < NGPUS; i++ )); do
    GPU_MEMORIES_MIB+=($(( SYSTEM_MEMORY_MIB / NGPUS )))
  done
else
  # Strip the " MiB" suffix from each value, e.g. "81920 MiB" -> 81920.
  GPU_MEMORIES_MIB=("${GPU_MEMORIES[@]/ MiB/}")
fi

FLAGS=()

Expand Down
Loading
Loading