11use std:: { array, borrow:: BorrowMut , sync:: Arc } ;
22
3+ #[ cfg( feature = "aot" ) ]
4+ use openvm_circuit:: arch:: { VmExecutor , VmState } ;
35use openvm_circuit:: {
46 arch:: {
57 testing:: { TestBuilder , TestChipHarness , VmChipTestBuilder , BITWISE_OP_LOOKUP_BUS } ,
@@ -14,7 +16,11 @@ use openvm_circuit_primitives::{
1416 } ,
1517 var_range:: VariableRangeCheckerChip ,
1618} ;
19+ #[ cfg( feature = "aot" ) ]
20+ use openvm_instructions:: { exe:: VmExe , program:: Program , riscv:: RV32_REGISTER_AS , SystemOpcode } ;
1721use openvm_instructions:: { instruction:: Instruction , program:: PC_BITS , LocalOpcode } ;
22+ #[ cfg( feature = "aot" ) ]
23+ use openvm_rv32im_transpiler:: BaseAluOpcode :: ADD ;
1824use openvm_rv32im_transpiler:: Rv32JalrOpcode :: { self , * } ;
1925use openvm_stark_backend:: {
2026 p3_air:: BaseAir ,
4046} ;
4147
4248use super :: Rv32JalrCoreAir ;
49+ #[ cfg( feature = "aot" ) ]
50+ use crate :: Rv32ImConfig ;
4351use crate :: {
4452 adapters:: {
4553 compose, Rv32JalrAdapterAir , Rv32JalrAdapterExecutor , Rv32JalrAdapterFiller ,
@@ -364,6 +372,109 @@ fn run_jalr_sanity_test() {
364372 assert_eq ! ( rd_data, [ 252 , 36 , 14 , 47 ] ) ;
365373}
366374
375+ #[ cfg( feature = "aot" ) ]
376+ fn run_jalr_program ( instructions : Vec < Instruction < F > > ) -> ( VmState < F > , VmState < F > ) {
377+ eprintln ! ( "run_jalr_program called" ) ;
378+ let program = Program :: from_instructions ( & instructions) ;
379+ let exe = VmExe :: new ( program) ;
380+ let config = Rv32ImConfig :: default ( ) ;
381+ let memory_dimensions = config. rv32i . system . memory_config . memory_dimensions ( ) ;
382+ let executor = VmExecutor :: new ( config. clone ( ) ) . expect ( "failed to create Rv32IM executor" ) ;
383+
384+ let interpreter = executor
385+ . interpreter_instance ( & exe)
386+ . expect ( "interpreter build must succeed" ) ;
387+ let interp_state = interpreter
388+ . execute ( vec ! [ ] , None )
389+ . expect ( "interpreter execution must succeed" ) ;
390+
391+ let mut aot_instance = executor. aot_instance ( & exe) . expect ( "AOT build must succeed" ) ;
392+ let aot_state = aot_instance
393+ . execute ( vec ! [ ] , None )
394+ . expect ( "AOT execution must succeed" ) ;
395+
396+ /// TODO: add this code to AOT utils file for testing purposes to check equivalence of VMStates
397+ assert_eq ! ( interp_state. instret( ) , aot_state. instret( ) ) ;
398+ assert_eq ! ( interp_state. pc( ) , aot_state. pc( ) ) ;
399+ use openvm_circuit:: {
400+ arch:: hasher:: poseidon2:: vm_poseidon2_hasher, system:: memory:: merkle:: MerkleTree ,
401+ } ;
402+
403+ let hasher = vm_poseidon2_hasher :: < BabyBear > ( ) ;
404+
405+ let tree1 = MerkleTree :: from_memory ( & interp_state. memory . memory , & memory_dimensions, & hasher) ;
406+ let tree2 = MerkleTree :: from_memory ( & aot_state. memory . memory , & memory_dimensions, & hasher) ;
407+
408+ assert_eq ! ( tree1. root( ) , tree2. root( ) , "Memory states differ" ) ;
409+ ( interp_state, aot_state)
410+ }
411+
412+ #[ cfg( feature = "aot" ) ]
413+ fn read_register ( state : & VmState < F > , offset : usize ) -> u32 {
414+ let bytes = unsafe { state. memory . read :: < u8 , 4 > ( RV32_REGISTER_AS , offset as u32 ) } ;
415+ u32:: from_le_bytes ( bytes)
416+ }
417+
418+ #[ cfg( feature = "aot" ) ]
419+ #[ test]
420+ fn test_jalr_aot_jump_forward ( ) {
421+ eprintln ! ( "test_jalr_aot_jump_forward called" ) ;
422+ let instructions = vec ! [
423+ Instruction :: from_usize( ADD . global_opcode( ) , [ 4 , 0 , 8 , RV32_REGISTER_AS as usize , 0 ] ) ,
424+ Instruction :: from_usize(
425+ JALR . global_opcode( ) ,
426+ [ 0 , 4 , 0 , RV32_REGISTER_AS as usize , 0 , 0 , 0 ] ,
427+ ) ,
428+ Instruction :: from_isize( SystemOpcode :: TERMINATE . global_opcode( ) , 0 , 0 , 0 , 0 , 0 ) ,
429+ ] ;
430+
431+ let ( interp_state, aot_state) = run_jalr_program ( instructions) ;
432+
433+ assert_eq ! ( interp_state. instret( ) , 3 ) ;
434+ assert_eq ! ( aot_state. instret( ) , 3 ) ;
435+ assert_eq ! ( interp_state. pc( ) , 8 ) ;
436+ assert_eq ! ( aot_state. pc( ) , 8 ) ;
437+
438+ let interp_x1 = read_register ( & interp_state, 4 ) ;
439+ let aot_x1 = read_register ( & aot_state, 4 ) ;
440+ assert_eq ! ( interp_x1, 8 ) ;
441+ assert_eq ! ( interp_x1, aot_x1) ;
442+ }
443+
444+ #[ cfg( feature = "aot" ) ]
445+ #[ test]
446+ fn test_jalr_aot_writes_return_address ( ) {
447+ eprintln ! ( "test_jalr_aot_writes_return_address called" ) ;
448+ let instructions = vec ! [
449+ Instruction :: from_usize(
450+ ADD . global_opcode( ) ,
451+ [ 4 , 0 , 12 , RV32_REGISTER_AS as usize , 0 ] ,
452+ ) ,
453+ Instruction :: from_usize(
454+ JALR . global_opcode( ) ,
455+ [ 12 , 4 , 0xfffc , RV32_REGISTER_AS as usize , 0 , 1 , 1 ] ,
456+ ) ,
457+ Instruction :: from_isize( SystemOpcode :: TERMINATE . global_opcode( ) , 0 , 0 , 0 , 0 , 0 ) ,
458+ ] ;
459+
460+ let ( interp_state, aot_state) = run_jalr_program ( instructions) ;
461+
462+ assert_eq ! ( interp_state. instret( ) , 3 ) ;
463+ assert_eq ! ( aot_state. instret( ) , 3 ) ;
464+ assert_eq ! ( interp_state. pc( ) , 8 ) ;
465+ assert_eq ! ( aot_state. pc( ) , 8 ) ;
466+
467+ let interp_x1 = read_register ( & interp_state, 4 ) ;
468+ let aot_x1 = read_register ( & aot_state, 4 ) ;
469+ assert_eq ! ( interp_x1, 12 ) ;
470+ assert_eq ! ( interp_x1, aot_x1) ;
471+
472+ let interp_x3 = read_register ( & interp_state, 12 ) ;
473+ let aot_x3 = read_register ( & aot_state, 12 ) ;
474+ assert_eq ! ( interp_x3, 8 ) ;
475+ assert_eq ! ( interp_x3, aot_x3) ;
476+ }
477+
367478// ////////////////////////////////////////////////////////////////////////////////////
368479// CUDA TESTS
369480//
0 commit comments