@@ -101,10 +101,6 @@ def _patched_get_seq_length(self, layer_idx=0):
101101# The original uses x.at[indexes].set(source) which generates stablehlo.scatter
102102# that IREE can't compile (dimension map bug). For contiguous indices (like KV
103103# cache updates), dynamic_update_slice is equivalent and IREE handles it fine.
104- # Patch torchax's index_put to use dynamic_update_slice instead of scatter.
105- # The original x.at[indexes].set(values) generates stablehlo.scatter which
106- # IREE can't compile (dimension map bug). For contiguous slice updates (like
107- # KV cache), dynamic_update_slice is equivalent and IREE handles it fine.
108104import torchax .ops .jaten as _jaten # trigger op registration
109105from torchax .ops .ops_registry import all_aten_ops as _all_ops
110106
@@ -199,7 +195,7 @@ def decode_one(weights, buffers, input_ids, cache_position, past_key_values):
199195# ===================================================================
200196# Step 4: Eager-mode validation (CPU via torchax)
201197# ===================================================================
202- print (f"\n --- Step 3 : Eager-mode validation ---" )
198+ print (f"\n --- Step 4 : Eager-mode validation ---" )
203199
204200messages = [{"role" : "user" , "content" : PROMPT }]
205201prompt_ids = tokenizer .apply_chat_template (
@@ -247,65 +243,61 @@ def decode_one(weights, buffers, input_ids, cache_position, past_key_values):
247243# ===================================================================
248244# Step 5: Export to StableHLO via jax.export
249245# ===================================================================
250- print (f"\n --- Step 4: Exporting to StableHLO ---" )
246+ print (f"\n --- Step 5: Exporting to StableHLO ---" )
247+
248+ # torchax tensors wrap JAX arrays in ._elem; extract the underlying JAX array.
249+ def _unwrap_torchax (x ):
250+ return x ._elem if hasattr (x , '_elem' ) else x
251251
252- # Build abstract shape specs for jax.export.
253- # torchax tensors wrap JAX arrays in ._elem; extract to get JAX shapes/dtypes.
254252def _to_jax_shape (x ):
255- if hasattr (x , '_elem' ):
256- j = x ._elem
257- return jax .ShapeDtypeStruct (j .shape , j .dtype )
258- return jax .ShapeDtypeStruct (x .shape , x .dtype )
253+ j = _unwrap_torchax (x )
254+ return jax .ShapeDtypeStruct (j .shape , j .dtype )
259255
260256weights_shapes = jax .tree .map (_to_jax_shape , model_weights )
261257buffers_shapes = jax .tree .map (_to_jax_shape , model_buffers )
262258kv_shapes = jax .tree .map (_to_jax_shape , cache )
263259
264- # Export prefill (full prompt)
265- prefill_mlir = None
266- try :
267- prefill_shapes = (
260+ jitted_jax_decode = jax .jit (jax_decode )
261+
262+ def export_mlir (name , seq_len , filename ):
263+ """Export a single StableHLO module and write to file."""
264+ shapes = (
268265 weights_shapes ,
269266 buffers_shapes ,
270- jax .ShapeDtypeStruct ((1 , SEQ_LEN ), jnp .int32 ),
271- jax .ShapeDtypeStruct ((SEQ_LEN ,), jnp .int32 ),
267+ jax .ShapeDtypeStruct ((1 , seq_len ), jnp .int32 ),
268+ jax .ShapeDtypeStruct ((seq_len ,), jnp .int32 ),
272269 kv_shapes ,
273270 )
274- exported = jax_export .export (jax .jit (jax_decode ))(* prefill_shapes )
275- prefill_mlir = str (exported .mlir_module ())
276- with open ("qwen_prefill.mlir" , "w" ) as f :
277- f .write (prefill_mlir )
278- print (f"Prefill : { len (prefill_mlir ):,} chars -> qwen_prefill.mlir" )
271+ exported = jax_export .export (jitted_jax_decode )(* shapes )
272+ mlir = str (exported .mlir_module ())
273+ with open (filename , "w" ) as f :
274+ f .write (mlir )
275+ print (f"{ name :9s} : { len (mlir ):,} chars -> { filename } " )
276+ return mlir
277+
278+ prefill_mlir = decode_mlir = None
279+ try :
280+ prefill_mlir = export_mlir ("Prefill" , SEQ_LEN , "qwen_prefill.mlir" )
279281except Exception as e :
280282 print (f"Prefill export failed: { e } " )
281283 import traceback ; traceback .print_exc ()
282-
283- # Export decode (single token)
284- decode_mlir = None
285284try :
286- decode_shapes = (
287- weights_shapes ,
288- buffers_shapes ,
289- jax .ShapeDtypeStruct ((1 , 1 ), jnp .int32 ),
290- jax .ShapeDtypeStruct ((1 ,), jnp .int32 ),
291- kv_shapes ,
292- )
293- exported = jax_export .export (jax .jit (jax_decode ))(* decode_shapes )
294- decode_mlir = str (exported .mlir_module ())
295- with open ("qwen_decode.mlir" , "w" ) as f :
296- f .write (decode_mlir )
297- print (f"Decode : { len (decode_mlir ):,} chars -> qwen_decode.mlir" )
285+ decode_mlir = export_mlir ("Decode" , 1 , "qwen_decode.mlir" )
298286except Exception as e :
299287 print (f"Decode export failed: { e } " )
300288 import traceback ; traceback .print_exc ()
301289
290+ export_ok = (prefill_mlir is not None , decode_mlir is not None )
291+
302292# ===================================================================
303293# Step 6: Compile to IREE and load
304294# ===================================================================
305295import approx_runtime as ar
306296
307- iree_backend = "cuda" if BACKEND == "cuda" else "llvm-cpu"
308- runtime_backend = BACKEND if BACKEND == "cuda" else "cpu"
297+ BACKEND_MAP = {"cuda" : ("cuda" , "cuda" ), "llvm-cpu" : ("llvm-cpu" , "cpu" ), "cpu" : ("llvm-cpu" , "cpu" )}
298+ if BACKEND not in BACKEND_MAP :
299+ raise ValueError (f"Unknown backend { BACKEND !r} , expected one of { list (BACKEND_MAP )} " )
300+ iree_backend , runtime_backend = BACKEND_MAP [BACKEND ]
309301
310302
311303def compile_and_load (mlir_text , name , backend , rt_backend ):
@@ -326,112 +318,81 @@ def compile_and_load(mlir_text, name, backend, rt_backend):
326318
327319
328320prefill_mod = decode_mod = None
329- if prefill_mlir or decode_mlir :
330- print (f"\n --- Step 5: Compiling with IREE (backend={ iree_backend } ) ---" )
331-
332- # Track which backend each module uses (must match for buffer sharing)
333- actual_backend = iree_backend
334- actual_rt_backend = runtime_backend
335-
336- if prefill_mlir :
337- try :
338- prefill_mod = compile_and_load (prefill_mlir , "prefill" , iree_backend , runtime_backend )
339- except Exception as e :
340- print (f" Prefill IREE { iree_backend } failed, falling back to llvm-cpu..." )
341- actual_backend = "llvm-cpu"
342- actual_rt_backend = "cpu"
321+ if prefill_mlir and decode_mlir :
322+ print (f"\n --- Step 6: Compiling with IREE (backend={ iree_backend } ) ---" )
323+
324+ # Both modules must share a backend for buffer compatibility.
325+ # Try preferred backend first, fall back to llvm-cpu.
326+ backends_to_try = [(iree_backend , runtime_backend )]
327+ if iree_backend != "llvm-cpu" :
328+ backends_to_try .append (("llvm-cpu" , "cpu" ))
329+
330+ for be , rt in backends_to_try :
343331 try :
344- prefill_mod = compile_and_load (prefill_mlir , "prefill" , "llvm-cpu" , "cpu" )
345- except Exception as e2 :
346- print (f" Prefill IREE llvm-cpu also failed: { e2 } " )
347-
348- if decode_mlir :
349- try :
350- decode_mod = compile_and_load (decode_mlir , "decode" , actual_backend , actual_rt_backend )
351- except Exception as e :
352- if actual_backend != "llvm-cpu" :
353- print (f" Decode IREE { actual_backend } failed, falling back to llvm-cpu..." )
354- # Both must use same backend for buffer sharing. Re-compile prefill too.
355- actual_backend = "llvm-cpu"
356- actual_rt_backend = "cpu"
357- if prefill_mlir :
358- print (f" Re-compiling prefill on llvm-cpu for buffer compatibility..." )
359- try :
360- prefill_mod = compile_and_load (prefill_mlir , "prefill" , "llvm-cpu" , "cpu" )
361- except Exception as e3 :
362- print (f" Prefill IREE llvm-cpu failed: { e3 } " )
363- try :
364- decode_mod = compile_and_load (decode_mlir , "decode" , "llvm-cpu" , "cpu" )
365- except Exception as e2 :
366- print (f" Decode IREE llvm-cpu also failed: { e2 } " )
367- else :
368- print (f" Decode IREE llvm-cpu failed: { e } " )
332+ prefill_mod = compile_and_load (prefill_mlir , "prefill" , be , rt )
333+ decode_mod = compile_and_load (decode_mlir , "decode" , be , rt )
334+ break
335+ except Exception as e :
336+ print (f" IREE { be } failed: { e } " )
337+ if be != "llvm-cpu" :
338+ print (f" Falling back to llvm-cpu..." )
339+ prefill_mod = decode_mod = None
340+
341+ # Free MLIR strings after compilation
342+ prefill_mlir = decode_mlir = None
369343
370344# ===================================================================
371345# Step 7: Run full inference via IREE-compiled modules
372346# ===================================================================
373347if prefill_mod and decode_mod :
374- print (f"\n --- Step 6: IREE inference ---" )
375-
376- def get_main (modules ):
377- for k , mod in modules .items ():
378- if k != "hal" :
379- return mod ["main" ]
348+ print (f"\n --- Step 7: IREE inference ---" )
380349
381- iree_prefill = get_main ( prefill_mod )
382- iree_decode = get_main ( decode_mod )
350+ iree_prefill = prefill_mod . jit_call_torch [ "main" ]
351+ iree_decode = decode_mod . jit_call_torch [ "main" ]
383352
384353 # Flatten all inputs to numpy arrays in pytree order.
385- # Extract underlying JAX arrays from torchax tensors first.
386- def _to_jax ( x ):
387- return x . _elem if hasattr ( x , '_elem' ) else x
354+ weights_flat , _ = jax . tree . flatten ( jax . tree . map ( _unwrap_torchax , model_weights ))
355+ buffers_flat , _ = jax . tree . flatten ( jax . tree . map ( _unwrap_torchax , model_buffers ))
356+ cache_flat , _ = jax . tree . flatten ( jax . tree . map ( _unwrap_torchax , cache ))
388357
389- weights_flat , _ = jax . tree . flatten ( jax . tree . map ( _to_jax , model_weights ))
390- buffers_flat , _ = jax . tree . flatten ( jax . tree . map ( _to_jax , model_buffers ))
391- cache_flat , _ = jax . tree . flatten ( jax . tree . map ( _to_jax , cache ))
358+ weights_np = [ np . asarray ( w ) for w in weights_flat ]
359+ buffers_np = [ np . asarray ( b ) for b in buffers_flat ]
360+ cache_np = [ np . asarray ( c ) for c in cache_flat ]
392361
393- weights_np = [np .array (w ) for w in weights_flat ]
394- buffers_np = [np .array (b ) for b in buffers_flat ]
395- cache_np = [np .array (c ) for c in cache_flat ]
396-
397- input_ids_np = np .array (prompt_ids , dtype = np .int32 )
362+ input_ids_np = np .asarray (prompt_ids , dtype = np .int32 )
398363 cache_pos_np = np .arange (SEQ_LEN , dtype = np .int32 )
399364
400365 # -- IREE Prefill (warmup + timed) --
401366 all_prefill_inputs = weights_np + buffers_np + [input_ids_np , cache_pos_np ] + cache_np
402- prefill_out = iree_prefill (* all_prefill_inputs ) # warmup
367+ _ = iree_prefill (* all_prefill_inputs ) # warmup
368+ del _
403369
404370 t0 = time .perf_counter ()
405371 prefill_out = iree_prefill (* all_prefill_inputs )
406372 t1 = time .perf_counter ()
407373
408374 # Output structure: (logits, kv_cache_leaves...)
409- # logits is first, then StaticCache key/value tensors
410- iree_logits = np .array (prefill_out [0 ].to_host ())
375+ iree_logits = np .asarray (prefill_out [0 ].to_host ())
411376 print (f"Prefill : { t1 - t0 :.4f} s logits.shape={ iree_logits .shape } " )
412377
413378 first_tok = int (np .argmax (iree_logits [0 , - 1 , :]))
414379 print (f"First tok: '{ tokenizer .decode ([first_tok ])} '" )
415380
416381 # -- IREE Decode loop --
417- # StaticCache has fixed shape, so we can reuse the decode module
418- # for every step — only the token and cache_position change.
382+ # StaticCache has fixed shape — only the token and cache_position change.
419383 kv_buffers = [prefill_out [i ] for i in range (1 , len (prefill_out ))]
420384 iree_generated = [first_tok ]
385+ static_inputs = weights_np + buffers_np # constant across steps
421386
422387 t_decode_start = time .perf_counter ()
423388 for step in range (MAX_NEW_TOKENS - 1 ):
424389 next_tok_np = np .array ([[iree_generated [- 1 ]]], dtype = np .int32 )
425390 pos_np = np .array ([SEQ_LEN + step ], dtype = np .int32 )
426391
427- all_decode_inputs = weights_np + buffers_np + [next_tok_np , pos_np ] + kv_buffers
428- decode_out = iree_decode (* all_decode_inputs )
392+ decode_out = iree_decode (* (static_inputs + [next_tok_np , pos_np ] + kv_buffers ))
429393
430- decode_logits = np .array (decode_out [0 ].to_host ())
431- tok = int (np .argmax (decode_logits [0 , - 1 , :]))
394+ tok = int (np .argmax (np .asarray (decode_out [0 ].to_host ())[0 , - 1 , :]))
432395 iree_generated .append (tok )
433-
434- # Update KV buffers for next step
435396 kv_buffers = [decode_out [i ] for i in range (1 , len (decode_out ))]
436397
437398 if tok in stop_token_ids :
@@ -451,8 +412,8 @@ def _to_jax(x):
451412print (f"Eager output : { text } " )
452413if prefill_mod and decode_mod :
453414 print (f"IREE output : { iree_text } " )
454- print (f"Prefill MLIR : { 'OK' if prefill_mlir else 'FAILED' } " )
455- print (f"Decode MLIR : { 'OK' if decode_mlir else 'FAILED' } " )
415+ print (f"Prefill MLIR : { 'OK' if export_ok [ 0 ] else 'FAILED' } " )
416+ print (f"Decode MLIR : { 'OK' if export_ok [ 1 ] else 'FAILED' } " )
456417print (f"Prefill IREE : { 'OK' if prefill_mod else 'FAILED' } " )
457418print (f"Decode IREE : { 'OK' if decode_mod else 'FAILED' } " )
458419
0 commit comments