Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/vm/src/arch/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub struct AotInstance<'a, F, Ctx> {
pre_compute_buf: AlignedBuf,
lib: Library,
pre_compute_insns_box: Box<[PreComputeInstruction<'a, F, Ctx>]>,
pc_start: u32,
pc_start: u32
}

type AsmRunFn = unsafe extern "C" fn(
Expand Down
64 changes: 63 additions & 1 deletion extensions/rv32im/circuit/src/jalr/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ use openvm_instructions::{
program::{DEFAULT_PC_STEP, PC_BITS},
riscv::RV32_REGISTER_AS,
};
#[cfg(feature = "aot")]
use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr};
#[cfg(feature = "aot")]
use openvm_instructions::{LocalOpcode, VmOpcode};
#[cfg(feature = "aot")]
use openvm_rv32im_transpiler::Rv32JalrOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

use super::core::Rv32JalrExecutor;
Expand Down Expand Up @@ -54,6 +60,19 @@ macro_rules! dispatch {
};
}

// x86 register names spliced into the assembly text emitted for JALR.
// 32-bit scratch register: receives rs1, then holds the computed jump target.
#[cfg(feature = "aot")]
const REG_B_W: &str = "eax";
// 32-bit scratch register: stages the return address before it is stored to rd.
#[cfg(feature = "aot")]
const REG_A_W: &str = "ecx";
// 64-bit register used as the pc-based offset into the `map_pc_base` table.
#[cfg(feature = "aot")]
const REG_PC: &str = "r13";
// 32-bit name of REG_PC; the jump target is moved here, which zero-extends into r13.
#[cfg(feature = "aot")]
const REG_PC_W: &str = "r13d";
// Register incremented by one for every emitted instruction.
#[cfg(feature = "aot")]
const REG_INSTRET: &str = "r14";
// 64-bit name of REG_A_W; receives the sign-extended `map_pc_base` entry for the jump.
#[cfg(feature = "aot")]
const REG_A: &str = "rcx"; // used when building jump address

impl<F, A> InterpreterExecutor<F> for Rv32JalrExecutor<A>
where
F: PrimeField32,
Expand Down Expand Up @@ -90,8 +109,51 @@ where
dispatch!(execute_e1_handler, enabled)
}
}

#[cfg(feature = "aot")]
impl<F, A> AotExecutor<F> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
{
    /// Returns `true` only when `inst` carries the `JALR` opcode.
    fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
        inst.opcode == Rv32JalrOpcode::JALR.global_opcode()
    }

    /// Builds the x86 assembly text for the `JALR` instruction `inst` located at `pc`.
    ///
    /// Operands as read here: `a` (rd) and `b` (rs1) are register byte offsets that are
    /// divided by 4 before being passed to the register helpers; `c + g * 0xffff0000` is
    /// the 32-bit immediate; a non-zero `f` enables the write of `pc + DEFAULT_PC_STEP`
    /// to rd.
    ///
    /// # Errors
    ///
    /// Returns [`AotError::InvalidInstruction`] when `a` or `b` is not a multiple of 4.
    fn generate_x86_asm(&self, inst: &Instruction<F>, pc: u32) -> Result<String, AotError> {
        let mut asm_str = String::new();
        // Keep the low 24 bits of the operand, sign-extend them, then truncate to i16.
        let to_i16 = |c: F| -> i16 {
            let c_u24 = (c.as_canonical_u64() & 0xFFFFFF) as u32;
            let c_i24 = ((c_u24 << 8) as i32) >> 8;
            c_i24 as i16
        };
        let a = to_i16(inst.a);
        let b = to_i16(inst.b);
        if a % 4 != 0 || b % 4 != 0 {
            return Err(AotError::InvalidInstruction);
        }
        let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000;
        let write_rd = !inst.f.is_zero();

        // The jump target is computed before rd is written, so rd == rs1 still uses the
        // old rs1 value.
        asm_str += &rv32_register_to_gpr((b / 4) as u8, REG_B_W);

        asm_str += &format!(" add {}, {}\n", REG_B_W, imm_extended);
        asm_str += &format!(" and {}, -2\n", REG_B_W); // clear bit 0 per RISC-V jalr
        asm_str += &format!(" mov {}, {}\n", REG_PC_W, REG_B_W); // zero-extend into r13

        if write_rd {
            let next_pc = pc.wrapping_add(DEFAULT_PC_STEP);
            asm_str += &format!(" mov {}, {}\n", REG_A_W, next_pc);
            asm_str += &gpr_to_rv32_register(REG_A_W, (a / 4) as u8);
        }

        asm_str += &format!(" add {}, 1\n", REG_INSTRET);
        // Dispatch: load the 32-bit entry at `map_pc_base + pc`, sign-extend it, add the
        // table base and jump to the result.
        asm_str += " lea rdx, [rip + map_pc_base]\n";
        asm_str += &format!(" movsxd {}, [rdx + {}]\n", REG_A, REG_PC);
        asm_str += &format!(" add {}, rdx\n", REG_A);
        asm_str += &format!(" jmp {}\n", REG_A);
        Ok(asm_str)
    }
}

impl<F, A> MeteredExecutor<F> for Rv32JalrExecutor<A>
where
Expand Down
111 changes: 111 additions & 0 deletions extensions/rv32im/circuit/src/jalr/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{array, borrow::BorrowMut, sync::Arc};

#[cfg(feature = "aot")]
use openvm_circuit::arch::{VmExecutor, VmState};
use openvm_circuit::{
arch::{
testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS},
Expand All @@ -14,7 +16,11 @@ use openvm_circuit_primitives::{
},
var_range::VariableRangeCheckerChip,
};
#[cfg(feature = "aot")]
use openvm_instructions::{exe::VmExe, program::Program, riscv::RV32_REGISTER_AS, SystemOpcode};
use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode};
#[cfg(feature = "aot")]
use openvm_rv32im_transpiler::BaseAluOpcode::ADD;
use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
use openvm_stark_backend::{
p3_air::BaseAir,
Expand All @@ -40,6 +46,8 @@ use {
};

use super::Rv32JalrCoreAir;
#[cfg(feature = "aot")]
use crate::Rv32ImConfig;
use crate::{
adapters::{
compose, Rv32JalrAdapterAir, Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller,
Expand Down Expand Up @@ -364,6 +372,109 @@ fn run_jalr_sanity_test() {
assert_eq!(rd_data, [252, 36, 14, 47]);
}

/// Executes `instructions` with both the interpreter and the AOT backend and returns the
/// two final states as `(interpreter_state, aot_state)`.
///
/// Before returning, asserts that both runs agree on `instret`, on `pc`, and on the
/// Merkle root computed over their final memory.
///
/// # Panics
///
/// Panics if either executor fails to build or run, or if any equivalence check fails.
#[cfg(feature = "aot")]
fn run_jalr_program(instructions: Vec<Instruction<F>>) -> (VmState<F>, VmState<F>) {
    use openvm_circuit::{
        arch::hasher::poseidon2::vm_poseidon2_hasher, system::memory::merkle::MerkleTree,
    };

    let program = Program::from_instructions(&instructions);
    let exe = VmExe::new(program);
    let config = Rv32ImConfig::default();
    let memory_dimensions = config.rv32i.system.memory_config.memory_dimensions();
    let executor = VmExecutor::new(config.clone()).expect("failed to create Rv32IM executor");

    let interpreter = executor
        .interpreter_instance(&exe)
        .expect("interpreter build must succeed");
    let interp_state = interpreter
        .execute(vec![], None)
        .expect("interpreter execution must succeed");

    let mut aot_instance = executor.aot_instance(&exe).expect("AOT build must succeed");
    let aot_state = aot_instance
        .execute(vec![], None)
        .expect("AOT execution must succeed");

    // TODO: move this VmState equivalence check into a shared AOT test-utils file.
    assert_eq!(interp_state.instret(), aot_state.instret());
    assert_eq!(interp_state.pc(), aot_state.pc());

    // Compare memory through Merkle roots built with the same hasher and dimensions.
    let hasher = vm_poseidon2_hasher::<BabyBear>();

    let tree1 = MerkleTree::from_memory(&interp_state.memory.memory, &memory_dimensions, &hasher);
    let tree2 = MerkleTree::from_memory(&aot_state.memory.memory, &memory_dimensions, &hasher);

    assert_eq!(tree1.root(), tree2.root(), "Memory states differ");
    (interp_state, aot_state)
}

/// Reads four bytes at `offset` in the register address space of `state` and decodes
/// them as a little-endian `u32`.
#[cfg(feature = "aot")]
fn read_register(state: &VmState<F>, offset: usize) -> u32 {
    // NOTE(review): the safety contract of `read` is not visible from here; confirm that
    // callers only pass offsets that are valid for a 4-byte register read.
    let register_bytes = unsafe { state.memory.read::<u8, 4>(RV32_REGISTER_AS, offset as u32) };
    u32::from_le_bytes(register_bytes)
}

/// JALR without an rd write: the interpreter and AOT runs must both finish at pc 8.
///
/// Program: an `ADD` that leaves 8 in the register at byte offset 4 (asserted below), a
/// `JALR` through that register with a zero immediate and the rd-write flag `f` cleared,
/// then `TERMINATE`.
#[cfg(feature = "aot")]
#[test]
fn test_jalr_aot_jump_forward() {
    let instructions = vec![
        Instruction::from_usize(ADD.global_opcode(), [4, 0, 8, RV32_REGISTER_AS as usize, 0]),
        Instruction::from_usize(
            JALR.global_opcode(),
            [0, 4, 0, RV32_REGISTER_AS as usize, 0, 0, 0],
        ),
        Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
    ];

    let (interp_state, aot_state) = run_jalr_program(instructions);

    assert_eq!(interp_state.instret(), 3);
    assert_eq!(aot_state.instret(), 3);
    assert_eq!(interp_state.pc(), 8);
    assert_eq!(aot_state.pc(), 8);

    // The base register must still hold the value written by ADD.
    let interp_x1 = read_register(&interp_state, 4);
    let aot_x1 = read_register(&aot_state, 4);
    assert_eq!(interp_x1, 8);
    assert_eq!(interp_x1, aot_x1);
}

/// JALR with an rd write and a negative immediate.
///
/// Program: an `ADD` that leaves 12 in the register at byte offset 4, a `JALR` with
/// `c = 0xfffc` and `g = 1` (immediate `0xfffffffc`, i.e. -4 in 32-bit wrapping
/// arithmetic) and `f = 1`, then `TERMINATE`. The jump target is 12 - 4 = 8, and the
/// register at byte offset 12 is asserted to receive the return address 8.
#[cfg(feature = "aot")]
#[test]
fn test_jalr_aot_writes_return_address() {
    let instructions = vec![
        Instruction::from_usize(
            ADD.global_opcode(),
            [4, 0, 12, RV32_REGISTER_AS as usize, 0],
        ),
        Instruction::from_usize(
            JALR.global_opcode(),
            [12, 4, 0xfffc, RV32_REGISTER_AS as usize, 0, 1, 1],
        ),
        Instruction::from_isize(SystemOpcode::TERMINATE.global_opcode(), 0, 0, 0, 0, 0),
    ];

    let (interp_state, aot_state) = run_jalr_program(instructions);

    assert_eq!(interp_state.instret(), 3);
    assert_eq!(aot_state.instret(), 3);
    assert_eq!(interp_state.pc(), 8);
    assert_eq!(aot_state.pc(), 8);

    // The base register is untouched by the jump.
    let interp_x1 = read_register(&interp_state, 4);
    let aot_x1 = read_register(&aot_state, 4);
    assert_eq!(interp_x1, 12);
    assert_eq!(interp_x1, aot_x1);

    // The destination register holds the return address.
    let interp_x3 = read_register(&interp_state, 12);
    let aot_x3 = read_register(&aot_state, 12);
    assert_eq!(interp_x3, 8);
    assert_eq!(interp_x3, aot_x3);
}

// ////////////////////////////////////////////////////////////////////////////////////
// CUDA TESTS
//
Expand Down
49 changes: 49 additions & 0 deletions extensions/rv32im/circuit/src/mul/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::{

use openvm_circuit::{arch::*, system::memory::online::GuestMemory};
use openvm_circuit_primitives_derive::AlignedBytesBorrow;
#[cfg(feature = "aot")]
use openvm_instructions::VmOpcode;
use openvm_instructions::{
instruction::Instruction,
program::DEFAULT_PC_STEP,
Expand All @@ -14,6 +16,8 @@ use openvm_instructions::{
use openvm_rv32im_transpiler::MulOpcode;
use openvm_stark_backend::p3_field::PrimeField32;

#[cfg(feature = "aot")]
use crate::common::{gpr_to_rv32_register, rv32_register_to_gpr};
use crate::MultiplicationExecutor;

#[derive(AlignedBytesBorrow, Clone)]
Expand Down Expand Up @@ -48,6 +52,22 @@ impl<A, const LIMB_BITS: usize> MultiplicationExecutor<A, { RV32_REGISTER_NUM_LI
}
}

// Callee saved registers
// Advanced by 4 after each emitted MUL.
#[cfg(feature = "aot")]
const REG_PC: &str = "r13";
// Incremented by 1 after each emitted MUL.
#[cfg(feature = "aot")]
const REG_INSTRET: &str = "r14";

// Caller saved registers
// 64-bit name of REG_A_W; not referenced by the MUL code generation visible here.
#[cfg(feature = "aot")]
const REG_A: &str = "rax";
// 32-bit scratch register: receives operand `c`, then holds the `imul` result written to `a`.
#[cfg(feature = "aot")]
const REG_A_W: &str = "eax";
// 64-bit name of REG_B_W; not referenced by the MUL code generation visible here.
#[cfg(feature = "aot")]
const REG_B: &str = "rcx";
// 32-bit scratch register: receives operand `b`, the second `imul` operand.
#[cfg(feature = "aot")]
const REG_B_W: &str = "ecx";

impl<F, A, const LIMB_BITS: usize> InterpreterExecutor<F>
for MultiplicationExecutor<A, { RV32_REGISTER_NUM_LIMBS }, LIMB_BITS>
where
Expand Down Expand Up @@ -93,6 +113,35 @@ impl<F, A, const LIMB_BITS: usize> AotExecutor<F>
where
F: PrimeField32,
{
/// Returns `true` only when `inst` carries the `MUL` opcode.
fn is_aot_supported(&self, inst: &Instruction<F>) -> bool {
    let mul_opcode = MulOpcode::MUL.global_opcode();
    inst.opcode == mul_opcode
}

/// Builds the x86 assembly text for one `MUL` instruction.
///
/// Operands `a`, `b` and `c` are treated as register byte offsets: each is divided by 4
/// before being passed to the register helpers. The emitted code multiplies the values
/// of `b` and `c` with a 32-bit `imul`, stores the result to `a`, then advances the pc
/// register by 4 and the instruction counter by 1.
///
/// # Errors
///
/// Returns [`AotError::InvalidInstruction`] when any of `a`, `b`, `c` is not a multiple
/// of 4.
fn generate_x86_asm(&self, inst: &Instruction<F>, _pc: u32) -> Result<String, AotError> {
    // Keep the low 24 bits of the operand, sign-extend them, then truncate to i16.
    let operand_offset = |operand: F| -> i16 {
        let low24 = (operand.as_canonical_u64() & 0xFFFFFF) as u32;
        let signed24 = ((low24 << 8) as i32) >> 8;
        signed24 as i16
    };
    let rd = operand_offset(inst.a);
    let rs1 = operand_offset(inst.b);
    let rs2 = operand_offset(inst.c);

    let misaligned = [rd, rs1, rs2].iter().any(|offset| offset % 4 != 0);
    if misaligned {
        return Err(AotError::InvalidInstruction);
    }

    let mut code = String::new();
    code.push_str(&rv32_register_to_gpr((rs1 / 4) as u8, REG_B_W));
    code.push_str(&rv32_register_to_gpr((rs2 / 4) as u8, REG_A_W));
    code.push_str(&format!(" imul {}, {}\n", REG_A_W, REG_B_W));
    code.push_str(&gpr_to_rv32_register(REG_A_W, (rd / 4) as u8));
    code.push_str(&format!(" add {}, 4\n", REG_PC));
    code.push_str(&format!(" add {}, 1\n", REG_INSTRET));

    Ok(code)
}
}

impl<F, A, const LIMB_BITS: usize> MeteredExecutor<F>
Expand Down
Loading
Loading