Skip to content

Commit adb82aa

Browse files
MaillewUbuntunyunyunyunyu
authored
feat: aot lightweight e1 rv32 mul, mulh, jalr (#2189)
Closes INT-5263, INT-5262 --------- Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Xinding Wei <[email protected]>
1 parent 8256d66 commit adb82aa

File tree

7 files changed

+731
-2
lines changed

7 files changed

+731
-2
lines changed

crates/vm/src/arch/aot.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub struct AotInstance<'a, F, Ctx> {
5151
pre_compute_buf: AlignedBuf,
5252
lib: Library,
5353
pre_compute_insns_box: Box<[PreComputeInstruction<'a, F, Ctx>]>,
54-
pc_start: u32,
54+
pc_start: u32
5555
}
5656

5757
type AsmRunFn = unsafe extern "C" fn(

extensions/rv32im/circuit/src/jalr/execution.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ use openvm_instructions::{
1010
program::{DEFAULT_PC_STEP, PC_BITS},
1111
riscv::RV32_REGISTER_AS,
1212
};
13+
#[cfg(feature = "aot")]
14+
use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr};
15+
#[cfg(feature = "aot")]
16+
use openvm_instructions::{LocalOpcode, VmOpcode};
17+
#[cfg(feature = "aot")]
18+
use openvm_rv32im_transpiler::Rv32JalrOpcode;
1319
use openvm_stark_backend::p3_field::PrimeField32;
1420

1521
use super::core::Rv32JalrExecutor;
@@ -54,6 +60,19 @@ macro_rules! dispatch {
5460
};
5561
}
5662

63+
#[cfg(feature = "aot")]
64+
const REG_B_W: &str = "eax";
65+
#[cfg(feature = "aot")]
66+
const REG_A_W: &str = "ecx";
67+
#[cfg(feature = "aot")]
68+
const REG_PC: &str = "r13";
69+
#[cfg(feature = "aot")]
70+
const REG_PC_W: &str = "r13d";
71+
#[cfg(feature = "aot")]
72+
const REG_INSTRET: &str = "r14";
73+
#[cfg(feature = "aot")]
74+
const REG_A: &str = "rcx"; // used when building jump address
75+
5776
impl<F, A> InterpreterExecutor<F> for Rv32JalrExecutor<A>
5877
where
5978
F: PrimeField32,
@@ -90,8 +109,51 @@ where
90109
dispatch!(execute_e1_handler, enabled)
91110
}
92111
}
112+
93113
#[cfg(feature = "aot")]
94-
impl<F, A> AotExecutor<F> for Rv32JalrExecutor<A> where F: PrimeField32 {}
114+
impl<F, A> AotExecutor<F> for Rv32JalrExecutor<A>
115+
where
116+
F: PrimeField32,
117+
{
118+
fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
119+
inst.opcode == Rv32JalrOpcode::JALR.global_opcode()
120+
}
121+
122+
fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
123+
let mut asm_str = String::new();
124+
let to_i16 = |c: F| -> i16 {
125+
let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
126+
let c_i24 = ((c_u24 << 8) as i32) >> 8;
127+
c_i24 as i16
128+
};
129+
let a = to_i16(inst.a);
130+
let b = to_i16(inst.b);
131+
if a % 4 != 0 || b % 4 != 0 {
132+
return Err(AotError::InvalidInstruction);
133+
}
134+
let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000;
135+
let write_rd = !inst.f.is_zero();
136+
137+
asm_str += &rv32_register_to_gpr((b / 4) as u8, REG_B_W);
138+
139+
asm_str += &format!(" add {}, {}\n", REG_B_W, imm_extended);
140+
asm_str += &format!(" and {}, -2\n", REG_B_W); // clear bit 0 per RISC-V jalr
141+
asm_str += &format!(" mov {}, {}\n", REG_PC_W, REG_B_W); // zero-extend into r13
142+
143+
if write_rd {
144+
let next_pc = pc.wrapping_add(DEFAULT_PC_STEP);
145+
asm_str += &format!(" mov {}, {}\n", REG_A_W, next_pc);
146+
asm_str += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8);
147+
}
148+
149+
asm_str += &format!(" add {}, 1\n", REG_INSTRET);
150+
asm_str += " lea rdx, [rip + map_pc_base]\n";
151+
asm_str += &format!(" movsxd {}, [rdx + {}]\n", REG_A, REG_PC);
152+
asm_str += " add rcx, rdx\n";
153+
asm_str += " jmp rcx\n";
154+
Ok(asm_str)
155+
}
156+
}
95157

96158
impl<F, A> MeteredExecutor<F> for Rv32JalrExecutor<A>
97159
where

extensions/rv32im/circuit/src/jalr/tests.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::{array, borrow::BorrowMut, sync::Arc};
22

3+
#[cfg(feature = "aot")]
4+
use openvm_circuit::arch::{VmExecutor, VmState};
35
use openvm_circuit::{
46
arch::{
57
testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS},
@@ -14,7 +16,11 @@ use openvm_circuit_primitives::{
1416
},
1517
var_range::VariableRangeCheckerChip,
1618
};
19+
#[cfg(feature = "aot")]
20+
use openvm_instructions::{exe::VmExe, program::Program, riscv::RV32_REGISTER_AS, SystemOpcode};
1721
use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode};
22+
#[cfg(feature = "aot")]
23+
use openvm_rv32im_transpiler::BaseAluOpcode::ADD;
1824
use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
1925
use openvm_stark_backend::{
2026
p3_air::BaseAir,
@@ -40,6 +46,8 @@ use {
4046
};
4147

4248
use super::Rv32JalrCoreAir;
49+
#[cfg(feature = "aot")]
50+
use crate::Rv32ImConfig;
4351
use crate::{
4452
adapters::{
4553
compose, Rv32JalrAdapterAir, Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller,
@@ -364,6 +372,109 @@ fn run_jalr_sanity_test() {
364372
assert_eq!(rd_data, [252, 36, 14, 47]);
365373
}
366374

375+
#[cfg(feature = "aot")]
376+
fn run_jalr_program(instructions: Vec<Instruction<F>>) -> (VmState<F>, VmState<F>) {
377+
eprintln!("run_jalr_program called");
378+
let program = Program::from_instructions(&instructions);
379+
let exe = VmExe::new(program);
380+
let config = Rv32ImConfig::default();
381+
let memory_dimensions = config.rv32i.system.memory_config.memory_dimensions();
382+
let executor = VmExecutor::new(config.clone()).expect("failed to create Rv32IM executor");
383+
384+
let interpreter = executor
385+
.interpreter_instance(&exe)
386+
.expect("interpreter build must succeed");
387+
let interp_state = interpreter
388+
.execute(vec![], None)
389+
.expect("interpreter execution must succeed");
390+
391+
let mut aot_instance = executor.aot_instance(&exe).expect("AOT build must succeed");
392+
let aot_state = aot_instance
393+
.execute(vec![], None)
394+
.expect("AOT execution must succeed");
395+
396+
/// TODO: add this code to AOT utils file for testing purposes to check equivalence of VMStates
397+
assert_eq!(interp_state.instret(), aot_state.instret());
398+
assert_eq!(interp_state.pc(), aot_state.pc());
399+
use openvm_circuit::{
400+
arch::hasher::poseidon2::vm_poseidon2_hasher, system::memory::merkle::MerkleTree,
401+
};
402+
403+
let hasher = vm_poseidon2_hasher::<BabyBear>();
404+
405+
let tree1 = MerkleTree::from_memory(&interp_state.memory.memory, &memory_dimensions, &hasher);
406+
let tree2 = MerkleTree::from_memory(&aot_state.memory.memory, &memory_dimensions, &hasher);
407+
408+
assert_eq!(tree1.root(), tree2.root(), "Memory states differ");
409+
(interp_state, aot_state)
410+
}
411+
412+
#[cfg(feature = "aot")]
413+
fn read_register(state: &VmState<F>, offset: usize) -> u32 {
414+
let bytes = unsafe { state.memory.read::<u8, 4>(RV32_REGISTER_AS, offset as u32) };
415+
u32::from_le_bytes(bytes)
416+
}
417+
418+
#[cfg(feature = "aot")]
419+
#[test]
420+
fn test_jalr_aot_jump_forward() {
421+
eprintln!("test_jalr_aot_jump_forward called");
422+
let instructions = vec![
423+
Instruction::from_usize(ADD.global_opcode(), [4, 0, 8, RV32_REGISTER_AS as usize, 0]),
424+
Instruction::from_usize(
425+
JALR.global_opcode(),
426+
[0, 4, 0, RV32_REGISTER_AS as usize, 0, 0, 0],
427+
),
428+
Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
429+
];
430+
431+
let (interp_state, aot_state) = run_jalr_program(instructions);
432+
433+
assert_eq!(interp_state.instret(), 3);
434+
assert_eq!(aot_state.instret(), 3);
435+
assert_eq!(interp_state.pc(), 8);
436+
assert_eq!(aot_state.pc(), 8);
437+
438+
let interp_x1 = read_register(&interp_state, 4);
439+
let aot_x1 = read_register(&aot_state, 4);
440+
assert_eq!(interp_x1, 8);
441+
assert_eq!(interp_x1, aot_x1);
442+
}
443+
444+
#[cfg(feature = "aot")]
445+
#[test]
446+
fn test_jalr_aot_writes_return_address() {
447+
eprintln!("test_jalr_aot_writes_return_address called");
448+
let instructions = vec![
449+
Instruction::from_usize(
450+
ADD.global_opcode(),
451+
[4, 0, 12, RV32_REGISTER_AS as usize, 0],
452+
),
453+
Instruction::from_usize(
454+
JALR.global_opcode(),
455+
[12, 4, 0xfffc, RV32_REGISTER_AS as usize, 0, 1, 1],
456+
),
457+
Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
458+
];
459+
460+
let (interp_state, aot_state) = run_jalr_program(instructions);
461+
462+
assert_eq!(interp_state.instret(), 3);
463+
assert_eq!(aot_state.instret(), 3);
464+
assert_eq!(interp_state.pc(), 8);
465+
assert_eq!(aot_state.pc(), 8);
466+
467+
let interp_x1 = read_register(&interp_state, 4);
468+
let aot_x1 = read_register(&aot_state, 4);
469+
assert_eq!(interp_x1, 12);
470+
assert_eq!(interp_x1, aot_x1);
471+
472+
let interp_x3 = read_register(&interp_state, 12);
473+
let aot_x3 = read_register(&aot_state, 12);
474+
assert_eq!(interp_x3, 8);
475+
assert_eq!(interp_x3, aot_x3);
476+
}
477+
367478
// ////////////////////////////////////////////////////////////////////////////////////
368479
// CUDA TESTS
369480
//

extensions/rv32im/circuit/src/mul/execution.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use std::{
55

66
use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
77
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
8+
#[cfg(feature = "aot")]
9+
use openvm_instructions::VmOpcode;
810
use openvm_instructions::{
911
instruction::Instruction,
1012
program::DEFAULT_PC_STEP,
@@ -14,6 +16,8 @@ use openvm_instructions::{
1416
use openvm_rv32im_transpiler::MulOpcode;
1517
use openvm_stark_backend::p3_field::PrimeField32;
1618

19+
#[cfg(feature = "aot")]
20+
use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr};
1721
use crate::MultiplicationExecutor;
1822

1923
#[derive(AlignedBytesBorrow, Clone)]
@@ -48,6 +52,22 @@ impl<A, const LIMB_BITS: usize> MultiplicationExecutor<A, { RV32_REGISTER_NUM_LI
4852
}
4953
}
5054

55+
// Callee saved registers
56+
#[cfg(feature = "aot")]
57+
const REG_PC: &str = "r13";
58+
#[cfg(feature = "aot")]
59+
const REG_INSTRET: &str = "r14";
60+
61+
// Caller saved registers
62+
#[cfg(feature = "aot")]
63+
const REG_A: &str = "rax";
64+
#[cfg(feature = "aot")]
65+
const REG_A_W: &str = "eax";
66+
#[cfg(feature = "aot")]
67+
const REG_B: &str = "rcx";
68+
#[cfg(feature = "aot")]
69+
const REG_B_W: &str = "ecx";
70+
5171
impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
5272
for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
5373
where
@@ -93,6 +113,35 @@ impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
93113
where
94114
F: PrimeField32,
95115
{
116+
fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
117+
inst.opcode == MulOpcode::MUL.global_opcode()
118+
}
119+
120+
fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
121+
let to_i16 = |c: F| -> i16 {
122+
let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
123+
let c_i24 = ((c_u24 << 8) as i32) >> 8;
124+
c_i24 as i16
125+
};
126+
let a = to_i16(inst.a);
127+
let b = to_i16(inst.b);
128+
let c = to_i16(inst.c);
129+
130+
if a % 4 != 0 || b % 4 != 0 || c % 4 != 0 {
131+
return Err(AotError::InvalidInstruction);
132+
}
133+
134+
let mut asm_str = String::new();
135+
136+
asm_str += &rv32_register_to_gpr((b / 4) as u8, REG_B_W);
137+
asm_str += &rv32_register_to_gpr((c / 4) as u8, REG_A_W);
138+
asm_str += &format!(" imul {}, {}\n", REG_A_W, REG_B_W);
139+
asm_str += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8);
140+
asm_str += &format!(" add {}, 4\n", REG_PC);
141+
asm_str += &format!(" add {}, 1\n", REG_INSTRET);
142+
143+
Ok(asm_str)
144+
}
96145
}
97146

98147
impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>

0 commit comments

Comments
 (0)