JAX/XLA Compilation Error ("ptxas too old") on RTX5090D with CUDA 12.8 #9

@A4444z

Hello,

I'm encountering a persistent JAX/XLA compilation error while trying to run the AF3Score pipeline on a new hardware setup. The issue seems to be a forward-compatibility problem between the JAX software stack and my GPU.

My system is running Ubuntu 24.04 with an NVIDIA RTX 5090, using Driver Version 570.124.06, which supports CUDA 12.8.

The primary error I consistently receive is:

jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: .../bin/ptxas ptxas too old

This appears to come from a compiler fallback path. The error is independent of the environment setup: I've reproduced it both in a native Conda environment and within a pre-built Docker container (tungmed4/af3score:latest).

I've also attempted several common workarounds, including forcing the more portable XLA backend (--flash_attention_implementation=xla) and setting XLA_FLAGS to simplify the compilation process. Unfortunately, the same ptxas too old error persists in all of these scenarios.

Because the issue occurs across different environments and isn't resolved by standard JAX compatibility flags, I suspect that this version of JAX and its underlying tools cannot correctly recognize or support this new GPU architecture. I would appreciate any guidance or insight you might have.
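
In case it helps with triage, here is a minimal sketch of how the relevant versions can be collected. It rests on two assumptions of mine that are not confirmed anywhere in this report: that the ptxas found on PATH is the same binary JAX invokes (a pip-installed jaxlib may ship its own copy elsewhere), and that this GPU generation needs a ptxas from CUDA 12.8 or newer. The helper names (parse_ptxas_release, ptxas_is_new_enough, report) are purely illustrative.

```python
import re
import shutil
import subprocess

# Assumption: this GPU generation needs ptxas from CUDA 12.8 or newer.
MIN_RELEASE = (12, 8)


def parse_ptxas_release(version_text):
    """Return (major, minor) parsed from `ptxas --version` output, or None."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def ptxas_is_new_enough(version_text, minimum=MIN_RELEASE):
    """True if the reported CUDA release is at least `minimum`."""
    release = parse_ptxas_release(version_text)
    return release is not None and release >= minimum


def report():
    """Print the ptxas release found on PATH and the JAX/jaxlib versions."""
    ptxas = shutil.which("ptxas")  # may differ from the ptxas JAX actually uses
    if ptxas is None:
        print("ptxas not found on PATH")
    else:
        out = subprocess.run(
            [ptxas, "--version"], capture_output=True, text=True
        ).stdout
        print(ptxas, parse_ptxas_release(out), "new enough:", ptxas_is_new_enough(out))
    try:
        import jax
        import jaxlib

        print("jax", jax.__version__, "jaxlib", jaxlib.__version__)
        for device in jax.devices():
            # compute_capability is only present on GPU devices
            print(device.device_kind, getattr(device, "compute_capability", "n/a"))
    except Exception as exc:  # JAX missing or backend failed to initialize
        print("could not query JAX:", exc)


if __name__ == "__main__":
    report()
```

Running this inside both the Conda environment and the Docker container should show whether each one picks up an older ptxas than the driver's CUDA 12.8 would suggest.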

Thank you.
