diff --git a/crates/vm/src/arch/aot.rs b/crates/vm/src/arch/aot.rs index 4381448bb0..5a27df4dd5 100644 --- a/crates/vm/src/arch/aot.rs +++ b/crates/vm/src/arch/aot.rs @@ -51,7 +51,7 @@ pub struct AotInstance<'a, F, Ctx> { pre_compute_buf: AlignedBuf, lib: Library, pre_compute_insns_box: Box<[PreComputeInstruction<'a, F, Ctx>]>, - pc_start: u32, + pc_start: u32 } type AsmRunFn = unsafe extern "C" fn( diff --git a/extensions/rv32im/circuit/src/jalr/execution.rs b/extensions/rv32im/circuit/src/jalr/execution.rs index e99010d285..7a5ce4bd0b 100644 --- a/extensions/rv32im/circuit/src/jalr/execution.rs +++ b/extensions/rv32im/circuit/src/jalr/execution.rs @@ -10,6 +10,12 @@ use openvm_instructions::{ program::{DEFAULT_PC_STEP, PC_BITS}, riscv::RV32_REGISTER_AS, }; +#[cfg(feature = "aot")] +use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr}; +#[cfg(feature = "aot")] +use openvm_instructions::{LocalOpcode, VmOpcode}; +#[cfg(feature = "aot")] +use openvm_rv32im_transpiler::Rv32JalrOpcode; use openvm_stark_backend::p3_field::PrimeField32; use super::core::Rv32JalrExecutor; @@ -54,6 +60,19 @@ macro_rules! dispatch { }; } +#[cfg(feature = "aot")] +const REG_B_W: &str = "eax"; +#[cfg(feature = "aot")] +const REG_A_W: &str = "ecx"; +#[cfg(feature = "aot")] +const REG_PC: &str = "r13"; +#[cfg(feature = "aot")] +const REG_PC_W: &str = "r13d"; +#[cfg(feature = "aot")] +const REG_INSTRET: &str = "r14"; +#[cfg(feature = "aot")] +const REG_A: &str = "rcx"; // used when building jump address + impl InterpreterExecutor for Rv32JalrExecutor where F: PrimeField32, @@ -90,8 +109,51 @@ where dispatch!(execute_e1_handler, enabled) } } + #[cfg(feature = "aot")] -impl AotExecutor for Rv32JalrExecutor where F: PrimeField32 {} +impl AotExecutor for Rv32JalrExecutor +where + F: PrimeField32, +{ + fn is_aot_supported(&self, inst: &Instruction) -> bool { + inst.opcode == Rv32JalrOpcode::JALR.global_opcode() + } + + fn generate_x86_asm(&self, inst: &Instruction, pc: u32) -> Result { + let mut asm_str = String::new(); + let to_i16 = |c: F| -> i16 { + let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32; + let c_i24 = ((c_u24 << 8) as i32) >> 8; + c_i24 as i16 + }; + let a = to_i16(inst.a); + let b = to_i16(inst.b); + if a % 4 != 0 || b % 4 != 0 { + return Err(AotError::InvalidInstruction); + } + let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000; + let write_rd = !inst.f.is_zero(); + + asm_str += &rv32_register_to_gpr((b / 4) as u8, REG_B_W); + + asm_str += &format!(" add {}, {}\n", REG_B_W, imm_extended); + asm_str += &format!(" and {}, -2\n", REG_B_W); // clear bit 0 per RISC-V jalr + asm_str += &format!(" mov {}, {}\n", REG_PC_W, REG_B_W); // zero-extend into r13 + + if write_rd { + let next_pc = pc.wrapping_add(DEFAULT_PC_STEP); + asm_str += &format!(" mov {}, {}\n", REG_A_W, next_pc); + asm_str += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8); + } + + asm_str += &format!(" add {}, 1\n", REG_INSTRET); + asm_str += " lea rdx, [rip + map_pc_base]\n"; + asm_str += &format!(" movsxd {}, [rdx + {}]\n", REG_A, REG_PC); + asm_str += " add rcx, rdx\n"; + asm_str += " jmp rcx\n"; + Ok(asm_str) + } +} impl MeteredExecutor for Rv32JalrExecutor where diff --git a/extensions/rv32im/circuit/src/jalr/tests.rs b/extensions/rv32im/circuit/src/jalr/tests.rs index e9968c23ea..99a6605584 100644 --- a/extensions/rv32im/circuit/src/jalr/tests.rs +++ b/extensions/rv32im/circuit/src/jalr/tests.rs @@ -1,5 +1,7 @@ use std::{array, borrow::BorrowMut, sync::Arc}; +#[cfg(feature = "aot")] +use openvm_circuit::arch::{VmExecutor, VmState}; use openvm_circuit::{ arch::{ testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, @@ -14,7 +16,11 @@ use openvm_circuit_primitives::{ }, var_range::VariableRangeCheckerChip, }; +#[cfg(feature = "aot")] +use openvm_instructions::{exe::VmExe, program::Program, riscv::RV32_REGISTER_AS, SystemOpcode}; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; +#[cfg(feature = "aot")] +use openvm_rv32im_transpiler::BaseAluOpcode::ADD; use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, @@ -40,6 +46,8 @@ use { }; use super::Rv32JalrCoreAir; +#[cfg(feature = "aot")] +use crate::Rv32ImConfig; use crate::{ adapters::{ compose, Rv32JalrAdapterAir, Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller, @@ -364,6 +372,109 @@ fn run_jalr_sanity_test() { assert_eq!(rd_data, [252, 36, 14, 47]); } +#[cfg(feature = "aot")] +fn run_jalr_program(instructions: Vec>) -> (VmState, VmState) { + eprintln!("run_jalr_program called"); + let program = Program::from_instructions(&instructions); + let exe = VmExe::new(program); + let config = Rv32ImConfig::default(); + let memory_dimensions = config.rv32i.system.memory_config.memory_dimensions(); + let executor = VmExecutor::new(config.clone()).expect("failed to create Rv32IM executor"); + + let interpreter = executor + .interpreter_instance(&exe) + .expect("interpreter build must succeed"); + let interp_state = interpreter + .execute(vec![], None) + .expect("interpreter execution must succeed"); + + let mut aot_instance = executor.aot_instance(&exe).expect("AOT build must succeed"); + let aot_state = aot_instance + .execute(vec![], None) + .expect("AOT execution must succeed"); + + /// TODO: add this code to AOT utils file for testing purposes to check equivalence of VMStates + assert_eq!(interp_state.instret(), aot_state.instret()); + assert_eq!(interp_state.pc(), aot_state.pc()); + use openvm_circuit::{ + arch::hasher::poseidon2::vm_poseidon2_hasher, system::memory::merkle::MerkleTree, + }; + + let hasher = vm_poseidon2_hasher::(); + + let tree1 = MerkleTree::from_memory(&interp_state.memory.memory, &memory_dimensions, &hasher); + let tree2 = MerkleTree::from_memory(&aot_state.memory.memory, &memory_dimensions, &hasher); + + assert_eq!(tree1.root(), tree2.root(), "Memory states differ"); + (interp_state, aot_state) +} + +#[cfg(feature = "aot")] +fn read_register(state: &VmState, offset: usize) -> u32 { + let bytes = unsafe { state.memory.read::(RV32_REGISTER_AS, offset as u32) }; + u32::from_le_bytes(bytes) +} + +#[cfg(feature = "aot")] +#[test] +fn test_jalr_aot_jump_forward() { + eprintln!("test_jalr_aot_jump_forward called"); + let instructions = vec![ + Instruction::from_usize(ADD.global_opcode(), [4, 0, 8, RV32_REGISTER_AS as usize, 0]), + Instruction::from_usize( + JALR.global_opcode(), + [0, 4, 0, RV32_REGISTER_AS as usize, 0, 0, 0], + ), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_jalr_program(instructions); + + assert_eq!(interp_state.instret(), 3); + assert_eq!(aot_state.instret(), 3); + assert_eq!(interp_state.pc(), 8); + assert_eq!(aot_state.pc(), 8); + + let interp_x1 = read_register(&interp_state, 4); + let aot_x1 = read_register(&aot_state, 4); + assert_eq!(interp_x1, 8); + assert_eq!(interp_x1, aot_x1); +} + +#[cfg(feature = "aot")] +#[test] +fn test_jalr_aot_writes_return_address() { + eprintln!("test_jalr_aot_writes_return_address called"); + let instructions = vec![ + Instruction::from_usize( + ADD.global_opcode(), + [4, 0, 12, RV32_REGISTER_AS as usize, 0], + ), + Instruction::from_usize( + JALR.global_opcode(), + [12, 4, 0xfffc, RV32_REGISTER_AS as usize, 0, 1, 1], + ), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_jalr_program(instructions); + + assert_eq!(interp_state.instret(), 3); + assert_eq!(aot_state.instret(), 3); + assert_eq!(interp_state.pc(), 8); + assert_eq!(aot_state.pc(), 8); + + let interp_x1 = read_register(&interp_state, 4); + let aot_x1 = read_register(&aot_state, 4); + assert_eq!(interp_x1, 12); + assert_eq!(interp_x1, aot_x1); + + let interp_x3 = read_register(&interp_state, 12); + let aot_x3 = read_register(&aot_state, 12); + assert_eq!(interp_x3, 8); + assert_eq!(interp_x3, aot_x3); +} + // //////////////////////////////////////////////////////////////////////////////////// // CUDA TESTS // diff --git a/extensions/rv32im/circuit/src/mul/execution.rs b/extensions/rv32im/circuit/src/mul/execution.rs index 63e7159a1b..d3cc7620fd 100644 --- a/extensions/rv32im/circuit/src/mul/execution.rs +++ b/extensions/rv32im/circuit/src/mul/execution.rs @@ -5,6 +5,8 @@ use std::{ use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; +#[cfg(feature = "aot")] +use openvm_instructions::VmOpcode; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, @@ -14,6 +16,8 @@ use openvm_instructions::{ use openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::p3_field::PrimeField32; +#[cfg(feature = "aot")] +use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr}; use crate::MultiplicationExecutor; #[derive(AlignedBytesBorrow, Clone)] @@ -48,6 +52,22 @@ impl MultiplicationExecutor InterpreterExecutor for MultiplicationExecutor where @@ -93,6 +113,35 @@ impl AotExecutor where F: PrimeField32, { + fn is_aot_supported(&self, inst: &Instruction) -> bool { + inst.opcode == MulOpcode::MUL.global_opcode() + } + + fn generate_x86_asm(&self, inst: &Instruction, _pc: u32) -> Result { + let to_i16 = |c: F| -> i16 { + let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32; + let c_i24 = ((c_u24 << 8) as i32) >> 8; + c_i24 as i16 + }; + let a = to_i16(inst.a); + let b = to_i16(inst.b); + let c = to_i16(inst.c); + + if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 { + return Err(AotError::InvalidInstruction); + } + + let mut asm_str = String::new(); + + asm_str += &rv32_register_to_gpr((b / 4) as u8, REG_B_W); + asm_str += &rv32_register_to_gpr((c / 4) as u8, REG_A_W); + asm_str += &format!(" imul {}, {}\n", REG_A_W, REG_B_W); + asm_str += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8); + asm_str += &format!(" add {}, 4\n", REG_PC); + asm_str += &format!(" add {}, 1\n", REG_INSTRET); + + Ok(asm_str) + } } impl MeteredExecutor diff --git a/extensions/rv32im/circuit/src/mul/tests.rs b/extensions/rv32im/circuit/src/mul/tests.rs index b05250f53c..b5875f4008 100644 --- a/extensions/rv32im/circuit/src/mul/tests.rs +++ b/extensions/rv32im/circuit/src/mul/tests.rs @@ -1,5 +1,13 @@ +#[cfg(feature = "aot")] +use std::collections::HashMap; use std::{array, borrow::BorrowMut, sync::Arc}; +#[cfg(feature = "aot")] +use openvm_circuit::arch::{VmExecutor, VmState}; +#[cfg(feature = "aot")] +use openvm_circuit::{ + arch::hasher::poseidon2::vm_poseidon2_hasher, system::memory::merkle::MerkleTree, +}; use openvm_circuit::{ arch::{ testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, @@ -11,6 +19,16 @@ use openvm_circuit_primitives::range_tuple::{ RangeTupleCheckerAir, RangeTupleCheckerBus, RangeTupleCheckerChip, SharedRangeTupleCheckerChip, }; use openvm_instructions::LocalOpcode; +#[cfg(feature = "aot")] +use openvm_instructions::{ + exe::VmExe, + instruction::Instruction, + program::Program, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS}, + SystemOpcode, +}; +#[cfg(feature = "aot")] +use openvm_rv32im_transpiler::BaseAluOpcode::ADD; use openvm_rv32im_transpiler::MulOpcode::{self, MUL}; use openvm_stark_backend::{ p3_air::BaseAir, @@ -33,6 +51,8 @@ use { }; use super::core::run_mul; +#[cfg(feature = "aot")] +use crate::Rv32ImConfig; use crate::{ adapters::{ Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller, RV32_CELL_BITS, @@ -259,6 +279,198 @@ fn run_mul_sanity_test() { } } +#[cfg(feature = "aot")] +fn run_mul_program(instructions: Vec>) -> (VmState, VmState) { + let program = Program::from_instructions(&instructions); + let exe = VmExe::new(program); + let config = Rv32ImConfig::default(); + let memory_dimensions = config.rv32i.system.memory_config.memory_dimensions(); + let executor = VmExecutor::new(config.clone()).expect("failed to create Rv32IM executor"); + + let interpreter = executor + .interpreter_instance(&exe) + .expect("interpreter build must succeed"); + let interp_state = interpreter + .execute(vec![], None) + .expect("interpreter execution must succeed"); + + let mut aot_instance = executor.aot_instance(&exe).expect("AOT build must succeed"); + let aot_state = aot_instance + .execute(vec![], None) + .expect("AOT execution must succeed"); + + assert_eq!(interp_state.instret(), aot_state.instret()); + assert_eq!(interp_state.pc(), aot_state.pc()); + + let hasher = vm_poseidon2_hasher::(); + let tree1 = MerkleTree::from_memory(&interp_state.memory.memory, &memory_dimensions, &hasher); + let tree2 = MerkleTree::from_memory(&aot_state.memory.memory, &memory_dimensions, &hasher); + assert_eq!(tree1.root(), tree2.root(), "Memory states differ"); + + (interp_state, aot_state) +} + +#[cfg(feature = "aot")] +fn read_register(state: &VmState, offset: usize) -> u32 { + let bytes = unsafe { state.memory.read::(RV32_REGISTER_AS, offset as u32) }; + u32::from_le_bytes(bytes) +} + +#[cfg(feature = "aot")] +fn add_immediate(rd: usize, imm: u32) -> Instruction { + Instruction::from_usize( + ADD.global_opcode(), + [ + rd, + 0, + imm as usize, + RV32_REGISTER_AS as usize, + RV32_IMM_AS as usize, + ], + ) +} + +#[cfg(feature = "aot")] +fn mul_register(rd: usize, rs1: usize, rs2: usize) -> Instruction { + Instruction::from_usize( + MulOpcode::MUL.global_opcode(), + [ + rd, + rs1, + rs2, + RV32_REGISTER_AS as usize, + RV32_REGISTER_AS as usize, + ], + ) +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mul_basic() { + let instructions = vec![ + add_immediate(4, 7), + add_immediate(8, 11), + mul_register(12, 4, 8), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), 4); + assert_eq!(aot_state.instret(), 4); + + let interp_x3 = read_register(&interp_state, 12); + let aot_x3 = read_register(&aot_state, 12); + assert_eq!(interp_x3, 77); + assert_eq!(interp_x3, aot_x3); +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mul_upper_xmm() { + let instructions = vec![ + add_immediate(4, 5), + add_immediate(12, 9), + mul_register(4, 4, 12), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), 4); + assert_eq!(aot_state.instret(), 4); + + let interp_x1 = read_register(&interp_state, 4); + let aot_x1 = read_register(&aot_state, 4); + assert_eq!(interp_x1, 45); + assert_eq!(interp_x1, aot_x1); +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mul_randomized_pairs() { + let offsets: [usize; 12] = [4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]; + let mut rng = create_seeded_rng(); + let mut instructions = Vec::new(); + let mut expected = HashMap::new(); + + for &offset in &offsets { + let value_i32 = rng.gen_range(-(1i32 << 11)..(1i32 << 11)); + let imm_field = (value_i32 as u32) & 0x00FF_FFFF; + instructions.push(add_immediate(offset, imm_field)); + expected.insert(offset, value_i32 as u32); + } + + for (i, &rd_offset) in offsets.iter().enumerate() { + let rs1_offset = offsets[i]; + let rs2_offset = offsets[(i + 3) % offsets.len()]; + instructions.push(mul_register(rd_offset, rs1_offset, rs2_offset)); + + let rs1_val = *expected.get(&rs1_offset).unwrap(); + let rs2_val = *expected.get(&rs2_offset).unwrap(); + expected.insert(rd_offset, rs1_val.wrapping_mul(rs2_val)); + } + + instructions.push(Instruction::from_isize( + SystemOpcode::TERMINATE.global_opcode(), + 0, + 0, + 0, + 0, + 0, + )); + + let mul_count = offsets.len(); + let total_insts = offsets.len() + mul_count + 1; + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), total_insts as u64); + assert_eq!(aot_state.instret(), total_insts as u64); + + for (offset, expected_val) in expected { + let interp_val = read_register(&interp_state, offset); + let aot_val = read_register(&aot_state, offset); + assert_eq!( + interp_val, expected_val, + "unexpected value at offset {offset}" + ); + assert_eq!(interp_val, aot_val, "AOT mismatch at offset {offset}"); + } +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mul_chained_dependencies() { + let instructions = vec![ + add_immediate(4, 3), + add_immediate(8, 5), + mul_register(12, 4, 8), + mul_register(4, 12, 8), + mul_register(8, 4, 12), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), 6); + assert_eq!(aot_state.instret(), 6); + + let interp_x3 = read_register(&interp_state, 12); + let aot_x3 = read_register(&aot_state, 12); + assert_eq!(interp_x3, 15); + assert_eq!(interp_x3, aot_x3); + + let interp_x1 = read_register(&interp_state, 4); + let aot_x1 = read_register(&aot_state, 4); + assert_eq!(interp_x1, 75); + assert_eq!(interp_x1, aot_x1); + + let interp_x2 = read_register(&interp_state, 8); + let aot_x2 = read_register(&aot_state, 8); + assert_eq!(interp_x2, 1125); + assert_eq!(interp_x2, aot_x2); +} + // //////////////////////////////////////////////////////////////////////////////////// // CUDA TESTS // diff --git a/extensions/rv32im/circuit/src/mulh/execution.rs b/extensions/rv32im/circuit/src/mulh/execution.rs index 60dbc7ded5..aacf5237b0 100644 --- a/extensions/rv32im/circuit/src/mulh/execution.rs +++ b/extensions/rv32im/circuit/src/mulh/execution.rs @@ -5,6 +5,8 @@ use std::{ use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives_derive::AlignedBytesBorrow; +#[cfg(feature = "aot")] +use openvm_instructions::VmOpcode; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, @@ -14,6 +16,8 @@ use openvm_instructions::{ use openvm_rv32im_transpiler::MulHOpcode; use openvm_stark_backend::p3_field::PrimeField32; +#[cfg(feature = "aot")] +use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr}; use crate::MulHExecutor; #[derive(AlignedBytesBorrow, Clone)] @@ -24,6 +28,20 @@ struct MulHPreCompute { c: u8, } +// Callee saved registers (shared with MUL AOT) +#[cfg(feature = "aot")] +const REG_PC: &str = "r13"; +#[cfg(feature = "aot")] +const REG_INSTRET: &str = "r14"; + +// Caller saved registers +#[cfg(feature = "aot")] +const REG_A_W: &str = "eax"; +#[cfg(feature = "aot")] +const REG_B_W: &str = "ecx"; +#[cfg(feature = "aot")] +const REG_TMP_W: &str = "r8d"; + impl MulHExecutor { #[inline(always)] fn pre_compute_impl( @@ -97,6 +115,58 @@ impl AotExecutor where F: PrimeField32, { + fn is_aot_supported(&self, inst: &Instruction) -> bool { + inst.opcode == MulHOpcode::MULH.global_opcode() + || inst.opcode == MulHOpcode::MULHSU.global_opcode() + || inst.opcode == MulHOpcode::MULHU.global_opcode() + } + + fn generate_x86_asm(&self, inst: &Instruction, _pc: u32) -> Result { + let to_i16 = |c: F| -> i16 { + let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32; + let c_i24 = ((c_u24 << 8) as i32) >> 8; + c_i24 as i16 + }; + + let a = to_i16(inst.a); + let b = to_i16(inst.b); + let c = to_i16(inst.c); + + if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 { + return Err(AotError::InvalidInstruction); + } + + let opcode = MulHOpcode::from_usize(inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); + + let mut asm = String::new(); + + asm += &rv32_register_to_gpr((b / 4) as u8, REG_A_W); + asm += &rv32_register_to_gpr((c / 4) as u8, REG_B_W); + + match opcode { + MulHOpcode::MULH => { + asm += " imul ecx\n"; + asm += " mov eax, edx\n"; + } + MulHOpcode::MULHSU => { + asm += &format!(" mov {REG_TMP_W}, {REG_A_W}\n"); + asm += " imul ecx\n"; + asm += " mov eax, edx\n"; + asm += " mov edx, ecx\n"; + asm += " sar edx, 31\n"; + asm += &format!(" and edx, {REG_TMP_W}\n"); + asm += " add eax, edx\n"; + } + MulHOpcode::MULHU => { + asm += " mul ecx\n"; + asm += " mov eax, edx\n"; + } + } + asm += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8); + asm += &format!(" add {}, 4\n", REG_PC); + asm += &format!(" add {}, 1\n", REG_INSTRET); + Ok(asm) + } } impl MeteredExecutor diff --git a/extensions/rv32im/circuit/src/mulh/tests.rs b/extensions/rv32im/circuit/src/mulh/tests.rs index 3edd8cee87..fe5c23f6a7 100644 --- a/extensions/rv32im/circuit/src/mulh/tests.rs +++ b/extensions/rv32im/circuit/src/mulh/tests.rs @@ -1,5 +1,13 @@ +#[cfg(feature = "aot")] +use std::collections::HashMap; use std::{borrow::BorrowMut, sync::Arc}; +#[cfg(feature = "aot")] +use openvm_circuit::arch::{VmExecutor, VmState}; +#[cfg(feature = "aot")] +use openvm_circuit::{ + arch::hasher::poseidon2::vm_poseidon2_hasher, system::memory::merkle::MerkleTree, +}; use openvm_circuit::{ arch::{ testing::{ @@ -21,7 +29,16 @@ use openvm_circuit_primitives::{ SharedRangeTupleCheckerChip, }, }; +#[cfg(feature = "aot")] +use openvm_instructions::{ + exe::VmExe, + program::Program, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS}, + SystemOpcode, +}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; +#[cfg(feature = "aot")] +use openvm_rv32im_transpiler::BaseAluOpcode::ADD; use openvm_rv32im_transpiler::MulHOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, @@ -34,6 +51,8 @@ use openvm_stark_backend::{ }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::rngs::StdRng; +#[cfg(feature = "aot")] +use rand::Rng; use test_case::test_case; #[cfg(feature = "cuda")] use { @@ -45,6 +64,8 @@ use { }; use super::core::run_mulh; +#[cfg(feature = "aot")] +use crate::Rv32ImConfig; use crate::{ adapters::{ Rv32MultAdapterAir, Rv32MultAdapterExecutor, Rv32MultAdapterFiller, RV32_CELL_BITS, @@ -488,6 +509,210 @@ fn run_mulhsu_neg_sanity_test() { assert_eq!(y_ext, 0); } +////////////////////////////////////////////////////////////////////////////////////// +// AOT TESTS +////////////////////////////////////////////////////////////////////////////////////// + +#[cfg(feature = "aot")] +fn run_mul_program(instructions: Vec>) -> (VmState, VmState) { + let program = Program::from_instructions(&instructions); + let exe = VmExe::new(program); + let config = Rv32ImConfig::default(); + let memory_dimensions = config.rv32i.system.memory_config.memory_dimensions(); + let executor = VmExecutor::new(config.clone()).expect("failed to create Rv32IM executor"); + + let interpreter = executor + .interpreter_instance(&exe) + .expect("interpreter build must succeed"); + let interp_state = interpreter + .execute(vec![], None) + .expect("interpreter execution must succeed"); + + let mut aot_instance = executor.aot_instance(&exe).expect("AOT build must succeed"); + let aot_state = aot_instance + .execute(vec![], None) + .expect("AOT execution must succeed"); + + assert_eq!(interp_state.instret(), aot_state.instret()); + assert_eq!(interp_state.pc(), aot_state.pc()); + + let hasher = vm_poseidon2_hasher::(); + let tree1 = MerkleTree::from_memory(&interp_state.memory.memory, &memory_dimensions, &hasher); + let tree2 = MerkleTree::from_memory(&aot_state.memory.memory, &memory_dimensions, &hasher); + assert_eq!(tree1.root(), tree2.root(), "Memory states differ"); + + (interp_state, aot_state) +} + +#[cfg(feature = "aot")] +fn read_register(state: &VmState, offset: usize) -> u32 { + let bytes = unsafe { state.memory.read::(RV32_REGISTER_AS, offset as u32) }; + u32::from_le_bytes(bytes) +} + +#[cfg(feature = "aot")] +fn add_immediate(rd: usize, imm: u32) -> Instruction { + Instruction::from_usize( + ADD.global_opcode(), + [ + rd, + 0, + imm as usize, + RV32_REGISTER_AS as usize, + RV32_IMM_AS as usize, + ], + ) +} + +#[cfg(feature = "aot")] +fn mulh_register(op: MulHOpcode, rd: usize, rs1: usize, rs2: usize) -> Instruction { + Instruction::from_usize( + op.global_opcode(), + [ + rd, + rs1, + rs2, + RV32_REGISTER_AS as usize, + RV32_REGISTER_AS as usize, + ], + ) +} + +#[cfg(feature = "aot")] +fn mulh_signed(rs1: u32, rs2: u32) -> u32 { + let prod = (rs1 as i32 as i64) * (rs2 as i32 as i64); // have to cast in this order, to sign extend properly + (prod >> 32) as u32 +} + +#[cfg(feature = "aot")] +fn mulh_signed_unsigned(rs1: u32, rs2: u32) -> u32 { + let prod = (rs1 as i32 as i128) * (rs2 as u64 as i128); + (prod >> 32) as u32 +} + +#[cfg(feature = "aot")] +fn mulh_unsigned(rs1: u32, rs2: u32) -> u32 { + let prod = (rs1 as u64) * (rs2 as u64); + (prod >> 32) as u32 +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mulh_variants_basic() { + let instructions = vec![ + add_immediate(4, 1234), + add_immediate(8, 200), + mulh_register(MulHOpcode::MULH, 12, 4, 8), + add_immediate(16, 800), + add_immediate(20, 12345), + mulh_register(MulHOpcode::MULHSU, 24, 16, 20), + add_immediate(28, 1200), + add_immediate(32, 200), + mulh_register(MulHOpcode::MULHU, 36, 28, 32), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), 10); + assert_eq!(aot_state.instret(), 10); + + let x3 = read_register(&interp_state, 12); + assert_eq!(x3, mulh_signed(1234, 200)); + assert_eq!(x3, read_register(&aot_state, 12)); + + let x6 = read_register(&interp_state, 24); + assert_eq!(x6, mulh_signed_unsigned(800, 12345)); + assert_eq!(x6, read_register(&aot_state, 24)); + + let x9 = read_register(&interp_state, 36); + assert_eq!(x9, mulh_unsigned(1200, 200)); + assert_eq!(x9, read_register(&aot_state, 36)); +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mulh_upper_lane() { + let instructions = vec![ + add_immediate(4, 0x0000_000F), + add_immediate(12, 0x0000_0002), + mulh_register(MulHOpcode::MULH, 16, 4, 12), + Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), 4); + assert_eq!(aot_state.instret(), 4); + + let expected = mulh_signed(0x0000_000F, 0x0000_0002); + let interp_val = read_register(&interp_state, 16); + let aot_val = read_register(&aot_state, 16); + assert_eq!(interp_val, expected); + assert_eq!(interp_val, aot_val); +} + +#[cfg(feature = "aot")] +#[test] +fn test_aot_mulh_randomized() { + let offsets: [usize; 12] = [4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]; + let mut rng = create_seeded_rng(); + let mut instructions = Vec::new(); + let mut expected = HashMap::new(); + + for &offset in &offsets { + let value_i32 = rng.gen_range(-(1i32 << 11)..(1i32 << 11)); + let imm_field = (value_i32 as u32) & 0x00FF_FFFF; + instructions.push(add_immediate(offset, imm_field)); + expected.insert(offset, value_i32 as u32); + } + + for (i, &rd_offset) in offsets.iter().enumerate() { + let rs1_offset = offsets[i]; + let rs2_offset = offsets[(i + 4) % offsets.len()]; + let opcode = match i % 3 { + 0 => MulHOpcode::MULH, + 1 => MulHOpcode::MULHSU, + _ => MulHOpcode::MULHU, + }; + instructions.push(mulh_register(opcode, rd_offset, rs1_offset, rs2_offset)); + + let rs1_val = *expected.get(&rs1_offset).unwrap(); + let rs2_val = *expected.get(&rs2_offset).unwrap(); + let result = match opcode { + MulHOpcode::MULH => mulh_signed(rs1_val, rs2_val), + MulHOpcode::MULHSU => mulh_signed_unsigned(rs1_val, rs2_val), + MulHOpcode::MULHU => mulh_unsigned(rs1_val, rs2_val), + }; + expected.insert(rd_offset, result); + } + + instructions.push(Instruction::from_isize( + SystemOpcode::TERMINATE.global_opcode(), + 0, + 0, + 0, + 0, + 0, + )); + + let total_insts = offsets.len() + offsets.len() + 1; + let (interp_state, aot_state) = run_mul_program(instructions); + + assert_eq!(interp_state.instret(), total_insts as u64); + assert_eq!(aot_state.instret(), total_insts as u64); + + for (offset, expected_val) in expected { + let interp_val = read_register(&interp_state, offset); + let aot_val = read_register(&aot_state, offset); + assert_eq!( + interp_val, expected_val, + "unexpected value at offset {offset}" + ); + assert_eq!(interp_val, aot_val, "AOT mismatch at offset {offset}"); + } +} + // //////////////////////////////////////////////////////////////////////////////////// // CUDA TESTS //