Hello,
I'm encountering a persistent JAX/XLA compilation error while trying to run the AF3Score pipeline on a new hardware setup. The issue seems to be a forward-compatibility problem between the JAX software stack and my GPU.
My system is running Ubuntu 24.04 with an NVIDIA RTX 5090, using Driver Version 570.124.06, which supports CUDA 12.8.
The error I consistently receive is:
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: .../bin/ptxas ptxas too old
This indicates that the ptxas binary being invoked is too old for this GPU and points to a compiler fallback. The error is independent of the environment setup: I've reproduced it both in a native Conda environment and within a pre-built Docker container (tungmed4/af3score:latest).
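For anyone trying to reproduce this, the following is a minimal sketch of how the ptxas release visible on PATH could be checked. It assumes that an RTX 5090 (Blackwell, sm_120) needs ptxas from CUDA 12.8 or newer, and the helper names `parse_ptxas_release` and `find_ptxas_release` are hypothetical. Note that the ptxas found on PATH is not necessarily the one at the elided `.../bin/ptxas` path in the error, so the result is only a hint.

```python
import re
import shutil
import subprocess

# Assumption: sm_120 GPUs such as the RTX 5090 need ptxas from CUDA >= 12.8.
MIN_RELEASE = (12, 8)


def parse_ptxas_release(version_text):
    """Extract (major, minor) from `ptxas --version` output, or None."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def find_ptxas_release():
    """Return (path, release) for the first ptxas on PATH, or (None, None)."""
    path = shutil.which("ptxas")
    if path is None:
        return None, None
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return path, parse_ptxas_release(result.stdout)


if __name__ == "__main__":
    path, release = find_ptxas_release()
    if path is None:
        print("No ptxas found on PATH")
    elif release is None:
        print(f"Could not parse the release reported by {path}")
    elif release < MIN_RELEASE:
        print(f"{path} is release {release}, older than {MIN_RELEASE}")
    else:
        print(f"{path} is release {release}, which should be new enough")
```

Running this inside both the Conda environment and the Docker container would show whether either one exposes a ptxas older than 12.8.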
I've also attempted several common workarounds, including forcing the more portable XLA attention implementation (--flash_attention_implementation=xla) and setting XLA_FLAGS to simplify the compilation process. Unfortunately, the same "ptxas too old" error persists in all of these scenarios.
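As an illustration of the XLA_FLAGS mechanism (not the exact flags tried above, which are not listed here), this sketch appends one flag to XLA_FLAGS from Python. The helper `with_xla_flag` and the path `/usr/local/cuda-12.8` are hypothetical. `--xla_gpu_cuda_data_dir` is an XLA flag that names the CUDA directory XLA consults when looking for tools such as ptxas; whether pointing it at a newer toolkit resolves this particular error is an assumption.

```python
import os


def with_xla_flag(flag, env=None):
    """Return an XLA_FLAGS string with `flag` added.

    Flags already present are kept, except an earlier setting of the
    same flag name, which is replaced.
    """
    env = os.environ if env is None else env
    existing = env.get("XLA_FLAGS", "").split()
    name = flag.split("=", 1)[0]
    kept = [f for f in existing if f.split("=", 1)[0] != name]
    return " ".join(kept + [flag])


# Hypothetical CUDA 12.8 install location; adjust to the real path.
CUDA_DIR = "/usr/local/cuda-12.8"

# Set this before JAX initializes its backend; the safest place is
# before `import jax` appears anywhere in the process.
os.environ["XLA_FLAGS"] = with_xla_flag(f"--xla_gpu_cuda_data_dir={CUDA_DIR}")
```

Merging into the existing value rather than overwriting it keeps any flags already exported by the shell or the container image.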
Because the issue occurs across different environments and isn't resolved by standard JAX compatibility flags, I strongly suspect that this version of JAX and its underlying tools (such as ptxas) cannot correctly recognize or support this new GPU architecture. I would appreciate any guidance or insight you might have.
Thank you.