Skip to content

Commit d0160ea

Browse files
MooMooHorse and claude
committed
Simplify Qwen IREE example: deduplicate export/fallback, fix step numbering
- Extract export_mlir() helper to eliminate duplicate prefill/decode export blocks - Unify _to_jax_shape/_to_jax into shared _unwrap_torchax helper - Replace nested try/except fallback cascade with clean backend loop - Add backend validation via BACKEND_MAP - Use np.asarray instead of np.array to avoid unnecessary copies - Pre-build static_inputs (weights+buffers) outside decode loop - Free MLIR strings after compilation to reduce memory - Fix step numbering mismatch between comments and print statements Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 301df2b commit d0160ea

1 file changed

Lines changed: 72 additions & 111 deletions

File tree

runtime/examples/example_qwen_iree.py

Lines changed: 72 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,6 @@ def _patched_get_seq_length(self, layer_idx=0):
101101
# The original uses x.at[indexes].set(source) which generates stablehlo.scatter
102102
# that IREE can't compile (dimension map bug). For contiguous indices (like KV
103103
# cache updates), dynamic_update_slice is equivalent and IREE handles it fine.
104-
# Patch torchax's index_put to use dynamic_update_slice instead of scatter.
105-
# The original x.at[indexes].set(values) generates stablehlo.scatter which
106-
# IREE can't compile (dimension map bug). For contiguous slice updates (like
107-
# KV cache), dynamic_update_slice is equivalent and IREE handles it fine.
108104
import torchax.ops.jaten as _jaten # trigger op registration
109105
from torchax.ops.ops_registry import all_aten_ops as _all_ops
110106

@@ -199,7 +195,7 @@ def decode_one(weights, buffers, input_ids, cache_position, past_key_values):
199195
# ===================================================================
200196
# Step 4: Eager-mode validation (CPU via torchax)
201197
# ===================================================================
202-
print(f"\n--- Step 3: Eager-mode validation ---")
198+
print(f"\n--- Step 4: Eager-mode validation ---")
203199

204200
messages = [{"role": "user", "content": PROMPT}]
205201
prompt_ids = tokenizer.apply_chat_template(
@@ -247,65 +243,61 @@ def decode_one(weights, buffers, input_ids, cache_position, past_key_values):
247243
# ===================================================================
248244
# Step 5: Export to StableHLO via jax.export
249245
# ===================================================================
250-
print(f"\n--- Step 4: Exporting to StableHLO ---")
246+
print(f"\n--- Step 5: Exporting to StableHLO ---")
247+
248+
# torchax tensors wrap JAX arrays in ._elem; extract the underlying JAX array.
249+
def _unwrap_torchax(x):
250+
return x._elem if hasattr(x, '_elem') else x
251251

252-
# Build abstract shape specs for jax.export.
253-
# torchax tensors wrap JAX arrays in ._elem; extract to get JAX shapes/dtypes.
254252
def _to_jax_shape(x):
255-
if hasattr(x, '_elem'):
256-
j = x._elem
257-
return jax.ShapeDtypeStruct(j.shape, j.dtype)
258-
return jax.ShapeDtypeStruct(x.shape, x.dtype)
253+
j = _unwrap_torchax(x)
254+
return jax.ShapeDtypeStruct(j.shape, j.dtype)
259255

260256
weights_shapes = jax.tree.map(_to_jax_shape, model_weights)
261257
buffers_shapes = jax.tree.map(_to_jax_shape, model_buffers)
262258
kv_shapes = jax.tree.map(_to_jax_shape, cache)
263259

264-
# Export prefill (full prompt)
265-
prefill_mlir = None
266-
try:
267-
prefill_shapes = (
260+
jitted_jax_decode = jax.jit(jax_decode)
261+
262+
def export_mlir(name, seq_len, filename):
263+
"""Export a single StableHLO module and write to file."""
264+
shapes = (
268265
weights_shapes,
269266
buffers_shapes,
270-
jax.ShapeDtypeStruct((1, SEQ_LEN), jnp.int32),
271-
jax.ShapeDtypeStruct((SEQ_LEN,), jnp.int32),
267+
jax.ShapeDtypeStruct((1, seq_len), jnp.int32),
268+
jax.ShapeDtypeStruct((seq_len,), jnp.int32),
272269
kv_shapes,
273270
)
274-
exported = jax_export.export(jax.jit(jax_decode))(*prefill_shapes)
275-
prefill_mlir = str(exported.mlir_module())
276-
with open("qwen_prefill.mlir", "w") as f:
277-
f.write(prefill_mlir)
278-
print(f"Prefill : {len(prefill_mlir):,} chars -> qwen_prefill.mlir")
271+
exported = jax_export.export(jitted_jax_decode)(*shapes)
272+
mlir = str(exported.mlir_module())
273+
with open(filename, "w") as f:
274+
f.write(mlir)
275+
        print(f"{name:9s}: {len(mlir):,} chars -> {filename}")
276+
return mlir
277+
278+
prefill_mlir = decode_mlir = None
279+
try:
280+
prefill_mlir = export_mlir("Prefill", SEQ_LEN, "qwen_prefill.mlir")
279281
except Exception as e:
280282
print(f"Prefill export failed: {e}")
281283
import traceback; traceback.print_exc()
282-
283-
# Export decode (single token)
284-
decode_mlir = None
285284
try:
286-
decode_shapes = (
287-
weights_shapes,
288-
buffers_shapes,
289-
jax.ShapeDtypeStruct((1, 1), jnp.int32),
290-
jax.ShapeDtypeStruct((1,), jnp.int32),
291-
kv_shapes,
292-
)
293-
exported = jax_export.export(jax.jit(jax_decode))(*decode_shapes)
294-
decode_mlir = str(exported.mlir_module())
295-
with open("qwen_decode.mlir", "w") as f:
296-
f.write(decode_mlir)
297-
print(f"Decode : {len(decode_mlir):,} chars -> qwen_decode.mlir")
285+
decode_mlir = export_mlir("Decode", 1, "qwen_decode.mlir")
298286
except Exception as e:
299287
print(f"Decode export failed: {e}")
300288
import traceback; traceback.print_exc()
301289

290+
export_ok = (prefill_mlir is not None, decode_mlir is not None)
291+
302292
# ===================================================================
303293
# Step 6: Compile to IREE and load
304294
# ===================================================================
305295
import approx_runtime as ar
306296

307-
iree_backend = "cuda" if BACKEND == "cuda" else "llvm-cpu"
308-
runtime_backend = BACKEND if BACKEND == "cuda" else "cpu"
297+
BACKEND_MAP = {"cuda": ("cuda", "cuda"), "llvm-cpu": ("llvm-cpu", "cpu"), "cpu": ("llvm-cpu", "cpu")}
298+
if BACKEND not in BACKEND_MAP:
299+
raise ValueError(f"Unknown backend {BACKEND!r}, expected one of {list(BACKEND_MAP)}")
300+
iree_backend, runtime_backend = BACKEND_MAP[BACKEND]
309301

310302

311303
def compile_and_load(mlir_text, name, backend, rt_backend):
@@ -326,112 +318,81 @@ def compile_and_load(mlir_text, name, backend, rt_backend):
326318

327319

328320
prefill_mod = decode_mod = None
329-
if prefill_mlir or decode_mlir:
330-
print(f"\n--- Step 5: Compiling with IREE (backend={iree_backend}) ---")
331-
332-
# Track which backend each module uses (must match for buffer sharing)
333-
actual_backend = iree_backend
334-
actual_rt_backend = runtime_backend
335-
336-
if prefill_mlir:
337-
try:
338-
prefill_mod = compile_and_load(prefill_mlir, "prefill", iree_backend, runtime_backend)
339-
except Exception as e:
340-
print(f" Prefill IREE {iree_backend} failed, falling back to llvm-cpu...")
341-
actual_backend = "llvm-cpu"
342-
actual_rt_backend = "cpu"
321+
if prefill_mlir and decode_mlir:
322+
print(f"\n--- Step 6: Compiling with IREE (backend={iree_backend}) ---")
323+
324+
# Both modules must share a backend for buffer compatibility.
325+
# Try preferred backend first, fall back to llvm-cpu.
326+
backends_to_try = [(iree_backend, runtime_backend)]
327+
if iree_backend != "llvm-cpu":
328+
backends_to_try.append(("llvm-cpu", "cpu"))
329+
330+
for be, rt in backends_to_try:
343331
try:
344-
prefill_mod = compile_and_load(prefill_mlir, "prefill", "llvm-cpu", "cpu")
345-
except Exception as e2:
346-
print(f" Prefill IREE llvm-cpu also failed: {e2}")
347-
348-
if decode_mlir:
349-
try:
350-
decode_mod = compile_and_load(decode_mlir, "decode", actual_backend, actual_rt_backend)
351-
except Exception as e:
352-
if actual_backend != "llvm-cpu":
353-
print(f" Decode IREE {actual_backend} failed, falling back to llvm-cpu...")
354-
# Both must use same backend for buffer sharing. Re-compile prefill too.
355-
actual_backend = "llvm-cpu"
356-
actual_rt_backend = "cpu"
357-
if prefill_mlir:
358-
print(f" Re-compiling prefill on llvm-cpu for buffer compatibility...")
359-
try:
360-
prefill_mod = compile_and_load(prefill_mlir, "prefill", "llvm-cpu", "cpu")
361-
except Exception as e3:
362-
print(f" Prefill IREE llvm-cpu failed: {e3}")
363-
try:
364-
decode_mod = compile_and_load(decode_mlir, "decode", "llvm-cpu", "cpu")
365-
except Exception as e2:
366-
print(f" Decode IREE llvm-cpu also failed: {e2}")
367-
else:
368-
print(f" Decode IREE llvm-cpu failed: {e}")
332+
prefill_mod = compile_and_load(prefill_mlir, "prefill", be, rt)
333+
decode_mod = compile_and_load(decode_mlir, "decode", be, rt)
334+
break
335+
except Exception as e:
336+
print(f" IREE {be} failed: {e}")
337+
if be != "llvm-cpu":
338+
print(f" Falling back to llvm-cpu...")
339+
prefill_mod = decode_mod = None
340+
341+
# Free MLIR strings after compilation
342+
prefill_mlir = decode_mlir = None
369343

370344
# ===================================================================
371345
# Step 7: Run full inference via IREE-compiled modules
372346
# ===================================================================
373347
if prefill_mod and decode_mod:
374-
print(f"\n--- Step 6: IREE inference ---")
375-
376-
def get_main(modules):
377-
for k, mod in modules.items():
378-
if k != "hal":
379-
return mod["main"]
348+
print(f"\n--- Step 7: IREE inference ---")
380349

381-
iree_prefill = get_main(prefill_mod)
382-
iree_decode = get_main(decode_mod)
350+
iree_prefill = prefill_mod.jit_call_torch["main"]
351+
iree_decode = decode_mod.jit_call_torch["main"]
383352

384353
# Flatten all inputs to numpy arrays in pytree order.
385-
# Extract underlying JAX arrays from torchax tensors first.
386-
def _to_jax(x):
387-
return x._elem if hasattr(x, '_elem') else x
354+
weights_flat, _ = jax.tree.flatten(jax.tree.map(_unwrap_torchax, model_weights))
355+
buffers_flat, _ = jax.tree.flatten(jax.tree.map(_unwrap_torchax, model_buffers))
356+
cache_flat, _ = jax.tree.flatten(jax.tree.map(_unwrap_torchax, cache))
388357

389-
weights_flat, _ = jax.tree.flatten(jax.tree.map(_to_jax, model_weights))
390-
buffers_flat, _ = jax.tree.flatten(jax.tree.map(_to_jax, model_buffers))
391-
cache_flat, _ = jax.tree.flatten(jax.tree.map(_to_jax, cache))
358+
weights_np = [np.asarray(w) for w in weights_flat]
359+
buffers_np = [np.asarray(b) for b in buffers_flat]
360+
cache_np = [np.asarray(c) for c in cache_flat]
392361

393-
weights_np = [np.array(w) for w in weights_flat]
394-
buffers_np = [np.array(b) for b in buffers_flat]
395-
cache_np = [np.array(c) for c in cache_flat]
396-
397-
input_ids_np = np.array(prompt_ids, dtype=np.int32)
362+
input_ids_np = np.asarray(prompt_ids, dtype=np.int32)
398363
cache_pos_np = np.arange(SEQ_LEN, dtype=np.int32)
399364

400365
# -- IREE Prefill (warmup + timed) --
401366
all_prefill_inputs = weights_np + buffers_np + [input_ids_np, cache_pos_np] + cache_np
402-
prefill_out = iree_prefill(*all_prefill_inputs) # warmup
367+
_ = iree_prefill(*all_prefill_inputs) # warmup
368+
del _
403369

404370
t0 = time.perf_counter()
405371
prefill_out = iree_prefill(*all_prefill_inputs)
406372
t1 = time.perf_counter()
407373

408374
# Output structure: (logits, kv_cache_leaves...)
409-
# logits is first, then StaticCache key/value tensors
410-
iree_logits = np.array(prefill_out[0].to_host())
375+
iree_logits = np.asarray(prefill_out[0].to_host())
411376
print(f"Prefill : {t1-t0:.4f}s logits.shape={iree_logits.shape}")
412377

413378
first_tok = int(np.argmax(iree_logits[0, -1, :]))
414379
print(f"First tok: '{tokenizer.decode([first_tok])}'")
415380

416381
# -- IREE Decode loop --
417-
# StaticCache has fixed shape, so we can reuse the decode module
418-
# for every step — only the token and cache_position change.
382+
# StaticCache has fixed shape — only the token and cache_position change.
419383
kv_buffers = [prefill_out[i] for i in range(1, len(prefill_out))]
420384
iree_generated = [first_tok]
385+
static_inputs = weights_np + buffers_np # constant across steps
421386

422387
t_decode_start = time.perf_counter()
423388
for step in range(MAX_NEW_TOKENS - 1):
424389
next_tok_np = np.array([[iree_generated[-1]]], dtype=np.int32)
425390
pos_np = np.array([SEQ_LEN + step], dtype=np.int32)
426391

427-
all_decode_inputs = weights_np + buffers_np + [next_tok_np, pos_np] + kv_buffers
428-
decode_out = iree_decode(*all_decode_inputs)
392+
decode_out = iree_decode(*(static_inputs + [next_tok_np, pos_np] + kv_buffers))
429393

430-
decode_logits = np.array(decode_out[0].to_host())
431-
tok = int(np.argmax(decode_logits[0, -1, :]))
394+
tok = int(np.argmax(np.asarray(decode_out[0].to_host())[0, -1, :]))
432395
iree_generated.append(tok)
433-
434-
# Update KV buffers for next step
435396
kv_buffers = [decode_out[i] for i in range(1, len(decode_out))]
436397

437398
if tok in stop_token_ids:
@@ -451,8 +412,8 @@ def _to_jax(x):
451412
print(f"Eager output : {text}")
452413
if prefill_mod and decode_mod:
453414
print(f"IREE output : {iree_text}")
454-
print(f"Prefill MLIR : {'OK' if prefill_mlir else 'FAILED'}")
455-
print(f"Decode MLIR : {'OK' if decode_mlir else 'FAILED'}")
415+
print(f"Prefill MLIR : {'OK' if export_ok[0] else 'FAILED'}")
416+
print(f"Decode MLIR : {'OK' if export_ok[1] else 'FAILED'}")
456417
print(f"Prefill IREE : {'OK' if prefill_mod else 'FAILED'}")
457418
print(f"Decode IREE : {'OK' if decode_mod else 'FAILED'}")
458419

0 commit comments

Comments
 (0)