A high-speed Rust compiler for transforming PyTorch FX graphs into pure, stateless JAX functions.
🚀 Live Demo | Operators | Blueprint
OxideXLA reads ONNX computation graphs and emits runnable JAX Python code. It is a graph-level transpiler, not a runtime wrapper. The output is a standalone Python file that compiles under XLA with no framework dependencies beyond JAX itself.
```
PyTorch nn.Module --> torch.onnx.export --> .onnx file --> OxideXLA --> output.py --> jax.jit
```
- PyTorch dominates training. JAX dominates compilation and TPU deployment.
- Moving models between the two requires manual rewriting.
- OxideXLA automates that translation at the graph level.
A PyTorch .pth file is tightly coupled to Python's pickle serialization and PyTorch's internal object structure. JAX, which is built around XLA with a fundamentally different execution model, cannot natively execute .pth files.
OxideXLA bridges this gap by using ONNX as a universal intermediary:
- Extraction: PyTorch weights and network structure are exported to ONNX.
- Translation: OxideXLA parses the ONNX graph and compiles pure, native JAX code representing the exact same architecture.
- Execution: PyTorch tensors are converted to flat NumPy arrays and fed into the newly generated JAX functions.
We aren't running .pth files in JAX; we are rewriting the architecture natively in JAX and transferring the weights.
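The weight-transfer step can be sketched in a few lines; `state_dict_to_params` is a hypothetical helper name, and the snippet assumes the state-dict values are array-likes (real PyTorch tensors would first go through `.detach().numpy()`, emulated here with plain NumPy):

```python
import numpy as np

def state_dict_to_params(state_dict):
    # Convert each tensor-like value to a plain float32 NumPy array.
    # Keys keep their PyTorch names (e.g. 'layer0.weight'), matching
    # the params dict the generated JAX function expects.
    return {k: np.asarray(v, dtype=np.float32) for k, v in state_dict.items()}

# Stand-in for a real torch state_dict:
params = state_dict_to_params({"layer0.weight": [[1.0, 2.0]], "layer0.bias": [0.5]})
```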
```mermaid
flowchart LR
    subgraph Frontend ["Frontend (Python)"]
        A["PyTorch nn.Module"] --> B["torch.fx / torch.onnx"]
    end
    B -->|".onnx file"| C
    subgraph Core ["OxideXLA Core (Rust)"]
        C["Parse Protobuf"] --> D["Build Graph IR\n(petgraph DAG)"]
        D --> E["Shape Inference\n+ Op Mapping"]
        E --> F["Code Generation"]
    end
    F -->|"output.py"| G
    subgraph Backend ["Backend (JAX)"]
        G["@jax.jit compiled"] --> H["XLA"]
    end
    H --> I["TPU / GPU / CPU"]
    style Core fill:#b7410e,color:#fff,stroke:#ff6633;
    style Frontend fill:#1a56db,color:#fff,stroke:#5b8def;
    style Backend fill:#16a34a,color:#fff,stroke:#4ade80;
```
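The core walks the petgraph DAG in dependency order before emitting any code. A minimal Python sketch of that ordering step, using the standard library's `graphlib` (the names here are illustrative, not the Rust API):

```python
from graphlib import TopologicalSorter

def emission_order(predecessors):
    # predecessors: {node: set of nodes it depends on}.
    # Codegen must emit each node only after all of its inputs,
    # which is exactly a topological order of the DAG.
    return list(TopologicalSorter(predecessors).static_order())

order = emission_order({
    "matmul_0": {"input", "layer0.weight"},
    "add_0": {"matmul_0", "layer0.bias"},
    "relu_0": {"add_0"},
})
```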
OxideXLA converts object-oriented PyTorch graphs into clean, functional JAX DAGs.
```mermaid
graph TD
    subgraph Module ["nn.Module (stateful)"]
        direction TB
        self["self"] --> L0["self.linear0\nweight + bias bundled"]
        self --> ACT["self.relu\nstateful layer object"]
        self --> L1["self.linear1\nweight + bias bundled"]
    end
    X["input tensor"] --> Module
    Module --> Y["output tensor"]
    style Module fill:#dc2626,color:#fff,stroke:#f87171;
    style self fill:#7f1d1d,color:#fff;
    style L0 fill:#991b1b,color:#fff;
    style ACT fill:#991b1b,color:#fff;
    style L1 fill:#991b1b,color:#fff;
```
```mermaid
graph TD
    X["input\n[1, 4]"] --> MM0["MatMul\n[1, 8]"]
    W0["params 'layer0_weight'\n[4, 8]"] --> MM0
    MM0 --> ADD0["Add\n[1, 8]"]
    B0["params 'layer0_bias'\n[8]"] --> ADD0
    ADD0 --> RELU["ReLU\n[1, 8]"]
    RELU --> MM1["MatMul\n[1, 3]"]
    W1["params 'layer1_weight'\n[8, 3]"] --> MM1
    MM1 --> ADD1["Add\n[1, 3]"]
    B1["params 'layer1_bias'\n[3]"] --> ADD1
    ADD1 --> SM["Softmax\n[1, 3]"]
    SM --> OUT["output"]
    style X fill:#1e40af,color:#fff;
    style OUT fill:#16a34a,color:#fff;
    style W0 fill:#6b21a8,color:#fff;
    style B0 fill:#6b21a8,color:#fff;
    style W1 fill:#6b21a8,color:#fff;
    style B1 fill:#6b21a8,color:#fff;
    style MM0 fill:#0e7490,color:#fff;
    style ADD0 fill:#0e7490,color:#fff;
    style RELU fill:#0e7490,color:#fff;
    style MM1 fill:#0e7490,color:#fff;
    style ADD1 fill:#0e7490,color:#fff;
    style SM fill:#0e7490,color:#fff;
```
Input: two_layer_mlp.onnx (MatMul -> Add -> ReLU -> MatMul -> Add -> Softmax)
Process: OxideXLA transpiles in < 1ms
Output: Pure JAX function
Generated code (`sample_jax_model.py`):

```python
# Generated by OxideXLA
import jax
import jax.numpy as jnp
import jax.lax

@jax.jit
def forward(params, input):
    matmul_0 = jnp.matmul(input, params['layer0.weight'])
    add_0 = jnp.add(matmul_0, params['layer0.bias'])
    relu_0 = jax.nn.relu(add_0)
    matmul_1 = jnp.matmul(relu_0, params['layer1.weight'])
    add_1 = jnp.add(matmul_1, params['layer1.bias'])
    softmax_0 = jax.nn.softmax(add_1, axis=-1)
    return softmax_0
```

OxideXLA is built to be a high-fidelity compiler. Every transpiled model must produce results numerically equivalent to its source framework, within floating-point tolerance.
| Model Architecture | Input Shape | Torch Sum | JAX Sum | Mean Squared Error (MSE) | Speedup | Status |
|---|---|---|---|---|---|---|
| 3-Layer MLP | (1, 64) | 1.00000 | 1.00000 | 2.22e-17 | 0.51x | ✓ |
| CNN-Simple* | (1, 3, 16, 16) | 14.5023 | 14.5023 | 1.35e-16 | 0.82x | ✓ |
| ResNet-18 | (1, 3, 224, 224) | 0.8846 (Prob) | 0.8846 (Prob) | 6.70e-12 | - | ✓ |
| DeiT-Small | (1, 3, 224, 224) | 0.1608 (Prob) | 0.1608 (Prob) | 3.62e-11 | - | ✓ |
| Text Context (NLP) | (1, 15) tokens | 0.5003 (Prob) | 0.5003 (Prob) | 1.39e-17 | 0.35x | ✓ |
*CNN results measured on a standard Conv2D + ReLU + Linear block.
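The MSE column above can be reproduced with a few lines of NumPy; this is a generic check, not the project's actual verification script:

```python
import numpy as np

def mse(a, b):
    # Mean squared error, computed in float64 so the comparison
    # itself doesn't lose precision on float32 model outputs.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.mean((a - b) ** 2))

# Two outputs that agree to ~1e-8 elementwise give an MSE around 1e-16:
assert mse([0.25, 0.75], [0.25 + 1e-8, 0.75 - 1e-8]) < 1e-10
```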
To ensure long-term stability, we track four critical metrics:
- The "Identity" Test: Loading Torch weights into JAX dictionaries. Measured MSE must be < 10^-10.
- The "Shape Inference" Test: Validating that every XLA-compiled node has correct dimensions. Prevents runtime `ShapeMismatch` errors.
- The "Stateless" Test: Ensuring BatchNorm buffers (running mean/var) are correctly mapped to JAX `params`.
- The "End-to-End" Benchmark: Quantifying the execution speed of `jax.jit(oxide_fn)` vs PyTorch eager mode.
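The "Shape Inference" test guards per-op rules like the one for MatMul. A hypothetical Python rendering of such a rule (the real engine lives in `graph/shape.rs`):

```python
def matmul_shape(a, b):
    # [m, k] @ [k, n] -> [m, n]; a mismatch is reported at
    # compile time, before any JAX code is emitted.
    if a[-1] != b[0]:
        raise ValueError(f"MatMul shape mismatch: {a} x {b}")
    return [a[0], b[-1]]

assert matmul_shape([1, 4], [4, 8]) == [1, 8]
```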
OxideXLA supports traversing the complex attention mechanisms and layer normalizations used in modern Transformers. You can easily transpile a Data-efficient Image Transformer (DeiT) straight from Hugging Face:
```shell
# 1. Run the test script which downloads and exports DeiT-Small to ONNX
python3 tests/verify_deit.py

# Expected Output:
# PyTorch Predicted ID: 285, Prob: 0.1608
# JAX Predicted ID: 285, Prob: 0.1608
# --- Verification ---
# MSE: 3.62e-11
# Status: SUCCESS - Class match verified
```

The script automatically generates `deit_jax.py` in your working directory containing pure, unrolled `jax.numpy` instructions.
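The class-match line in that output amounts to comparing top-1 argmax indices between the two frameworks; a generic sketch, not the script's internals:

```python
import numpy as np

def class_match(torch_probs, jax_probs):
    # Both frameworks must pick the same top-1 class index,
    # even if the probabilities differ by floating-point noise.
    return int(np.argmax(torch_probs)) == int(np.argmax(jax_probs))

assert class_match([0.1, 0.7, 0.2], [0.1001, 0.6999, 0.2])
```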
```shell
# Build from source
cargo install --path .
cargo build --release

# Inspect an ONNX graph (ASCII)
oxide_xla inspect model.onnx

# Inspect in JSON format
oxide_xla inspect model.onnx --format json

# Transpile to JAX
oxide_xla compile model.onnx --output model_jax.py

# View the generated JAX code
cat model_jax.py
```

OxideXLA produces pure, stateless JAX code where parameters are separated from computation:
```python
# Generated by OxideXLA
import jax
import jax.numpy as jnp

@jax.jit
def forward(params, x):
    x = jnp.matmul(x, params['linear0']['weight'])
    x = jnp.add(x, params['linear0']['bias'])
    x = jax.nn.relu(x)
    return x
```

| ONNX Op | JAX Equivalent | Status |
|---|---|---|
| MatMul | `jnp.matmul` | Done |
| Add | `jnp.add` | Done |
| Mul | `jnp.multiply` | Done |
| Sub | `jnp.subtract` | Done |
| Relu | `jax.nn.relu` | Done |
| Sigmoid | `jax.nn.sigmoid` | Done |
| Tanh | `jnp.tanh` | Done |
| Softmax | `jax.nn.softmax` | Done |
| Reshape | `jnp.reshape` | Done |
| Transpose | `jnp.transpose` | Done |
| Conv | `jax.lax.conv_general_dilated` | Done |
| BatchNorm | Manual decomposition | Done |
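The BatchNorm row refers to emitting the inference-time formula directly, since JAX has no stateful BatchNorm layer. A NumPy sketch of that decomposition (generated code would use `jnp` instead):

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, running_mean, running_var, eps=1e-5):
    # Inference-time BatchNorm is a pure elementwise affine transform:
    #   y = gamma * (x - mean) / sqrt(var + eps) + beta
    # The running mean/var buffers become ordinary entries in params.
    inv_std = 1.0 / np.sqrt(np.asarray(running_var) + eps)
    return np.asarray(gamma) * (np.asarray(x) - np.asarray(running_mean)) * inv_std + np.asarray(beta)

y = batchnorm_inference([2.0, 4.0], gamma=[1.0, 1.0], beta=[0.0, 0.0],
                        running_mean=[2.0, 4.0], running_var=[1.0, 1.0])
```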
```mermaid
flowchart LR
    INPUT["ONNX op_type\nstring"] --> MATCH{"Match against\nknown ops"}
    MATCH -->|"found"| BUILD["Build JaxOp\nwith attributes"]
    MATCH -->|"not found"| ERR["UnsupportedOp\nerror"]
    BUILD --> EMIT["Emit JAX\nPython code"]
    style INPUT fill:#1e3a5f,color:#fff
    style MATCH fill:#92400e,color:#fff
    style BUILD fill:#14532d,color:#fff
    style EMIT fill:#4a1d96,color:#fff
    style ERR fill:#7f1d1d,color:#fff
```
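The match step can be sketched as a lookup table from ONNX `op_type` strings to code emitters; the names here (`EMITTERS`, `UnsupportedOp`, `emit`) are illustrative Python, not the Rust API:

```python
class UnsupportedOp(Exception):
    pass

# Each emitter turns (output name, input names) into one line of JAX code.
EMITTERS = {
    "MatMul": lambda out, ins: f"{out} = jnp.matmul({ins[0]}, {ins[1]})",
    "Relu": lambda out, ins: f"{out} = jax.nn.relu({ins[0]})",
}

def emit(op_type, out, ins):
    try:
        return EMITTERS[op_type](out, ins)
    except KeyError:
        raise UnsupportedOp(f"no JAX mapping for ONNX op '{op_type}'") from None

line = emit("MatMul", "matmul_0", ["input", "w0"])
```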
Every tensor operation in the graph is represented as an IrNode with its
associated JaxOp variant. This is the central data structure of the compiler.
```mermaid
classDiagram
    class IrNode {
        +JaxOp op
        +String name
        +Vec~i64~ output_shape
        +DType dtype
        +Option~String~ param_key
        +Vec~NodeIndex~ ordered_inputs
    }
    class JaxOp {
        <<enumeration>>
        Input
        Param
        MatMul
        Add
        Mul
        Sub
        Relu
        Sigmoid
        Tanh
        Softmax
        Reshape
        Transpose
        Conv
        BatchNorm
    }
    class DType {
        <<enumeration>>
        Float32
        Float64
        Int32
        Int64
        Unknown
    }
    IrNode --> JaxOp : contains
    IrNode --> DType : contains
```
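For readers more at home in Python, the same structure as an illustrative dataclass (the actual definition is the Rust struct in `graph/dag.rs`; this mirror exists only to show what each field holds):

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

class JaxOp(Enum):
    INPUT = "Input"
    PARAM = "Param"
    MATMUL = "MatMul"
    RELU = "Relu"
    # ... remaining variants follow the table above

@dataclass
class IrNode:
    op: JaxOp
    name: str
    output_shape: list                 # static shape, e.g. [1, 8]
    dtype: str = "Float32"
    param_key: Optional[str] = None    # e.g. 'layer0.weight' for Param nodes
    ordered_inputs: list = field(default_factory=list)  # predecessor node indices

node = IrNode(op=JaxOp.MATMUL, name="matmul_0", output_shape=[1, 8], ordered_inputs=[0, 1])
```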
```
oxide-xla/
  src/
    main.rs              CLI entry point (clap)
    lib.rs               Public API
    parser/
      mod.rs             Module root
      onnx_loader.rs     Protobuf -> OnnxModel (inline proto defs)
    graph/
      mod.rs             Module root
      dag.rs             IrGraph, IrNode, JaxOp (petgraph DAG)
      shape.rs           Static shape inference engine
    ops/
      mod.rs             Central op dispatch
      math.rs            MatMul, Add, Mul, Sub
      nn.rs              Relu, Sigmoid, Softmax, Conv, BatchNorm
      reshape.rs         Reshape, Transpose
    codegen/
      mod.rs             Module root
      module.rs          Top-level Python module assembly
      emit.rs            Per-node JAX code emission
  bridge/
    fx_export.py         Python FX -> ONNX exporter (Stage 3)
    requirements.txt     Python dependencies
  tests/
    test_parser.rs       Parser integration tests
    test_graph.rs        Graph + shape inference tests
    test_codegen.rs      Full pipeline tests
    generate_fixtures.py Generates .onnx test fixtures
    models/              .onnx fixture files
  docs/
    architecture.md      Detailed architecture with Mermaid diagrams
    operator_mapping.md  Operator reference and shape rules
```
```mermaid
gantt
    title OxideXLA Development Roadmap
    dateFormat YYYY-MM-DD
    axisFormat %b %d
    section Stage 1
    ONNX Graph Inspector CLI :done, s1, 2025-06-01, 14d
    section Stage 2
    ONNX to JAX Codegen :done, s2, after s1, 28d
    section Stage 3
    PyTorch FX Bridge :active, s3, after s2, 14d
    section Stretch
    WASM Compilation :s4, after s3, 14d
    StableHLO Backend :s5, after s4, 21d
    Operator Fusion :s6, after s4, 21d
```
| Stage | Deliverable | Status |
|---|---|---|
| 1 | Parse and display ONNX graphs (ASCII + JSON) | Done |
| 2 | Emit runnable JAX Python from ONNX graphs | Done |
| 3 | Python shim to trace PyTorch models end-to-end | In Progress |
```mermaid
sequenceDiagram
    participant User
    participant Bridge as fx_export.py
    participant Torch as torch.fx
    participant Export as torch.onnx
    participant OxideXLA as OxideXLA (Rust)
    User->>Bridge: python fx_export.py --model resnet18
    Bridge->>Torch: symbolic_trace(model)
    Torch-->>Bridge: FX Graph
    Bridge->>Export: export to ONNX
    Export-->>Bridge: .onnx bytes
    Bridge->>OxideXLA: pipe via stdin
    OxideXLA-->>User: output.py (pure JAX)
```
```shell
# Full pipeline (Stage 3)
python bridge/fx_export.py --model torchvision.models.resnet18 | oxide_xla compile - --output r18.py
```

13 integration tests covering the full pipeline:
```
test_parser (4 tests)
  parse_linear_model ..................... ok
  parse_linear_relu_model ................ ok
  parse_two_layer_mlp .................... ok
  parse_nonexistent_file_returns_error ... ok

test_graph (5 tests)
  build_ir_from_linear_model ............. ok
  build_ir_from_two_layer_mlp ............ ok
  shape_inference_linear_model ........... ok
  ascii_output_not_empty ................. ok
  json_output_is_valid ................... ok

test_codegen (4 tests)
  codegen_linear_model ................... ok
  codegen_linear_relu_model .............. ok
  codegen_two_layer_mlp .................. ok
  codegen_output_is_syntactically_complete ok
```

Run them with:

```shell
cargo test
```

| Concern | Rust Advantage |
|---|---|
| Speed | Graph traversal and code generation run in microseconds, not seconds |
| Safety | No null pointer crashes, no GIL -- graph operations are memory-safe by construction |
| WASM | The compiler can be compiled to WebAssembly for browser-based transpilation |
| Ecosystem | petgraph, prost, clap -- mature crates for graphs, protobuf, and CLI |
| Component | Technology |
|---|---|
| Core compiler | Rust |
| Graph IR | petgraph (DiGraph) |
| ONNX parsing | prost (protobuf) |
| CLI | clap 4 |
| Serialization | serde + serde_json |
| Shape inference | Custom (per-op rules) |
| Code generation | String-based emitter |
| Python bridge | torch.fx + torch.onnx |
| Testing | cargo test + .onnx fixtures |
Apache 2.0
OxideXLA -- Jeffrey Asante -- 2026.